diff --git a/.lintrunner.toml b/.lintrunner.toml
index e1b24b2955b03..be46ba0baabdb 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -127,7 +127,6 @@ include_patterns = [
 ]
 exclude_patterns = [
     'java/**', # FIXME: Enable clang-format for java
-    'js/**',
     'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**', # Contains data chunks
     'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
     'onnxruntime/test/flatbuffers/*.fbs.h', # Generated code
diff --git a/js/.clang-format b/js/.clang-format
deleted file mode 100644
index 596eec15a995f..0000000000000
--- a/js/.clang-format
+++ /dev/null
@@ -1,16 +0,0 @@
----
-Language: JavaScript
-BasedOnStyle: Google
-ColumnLimit: 120
----
-Language: Cpp
-BasedOnStyle: LLVM
-ColumnLimit: 120
----
-Language: ObjC
-BasedOnStyle: LLVM
-ColumnLimit: 120
----
-Language: Java
-BasedOnStyle: LLVM
-ColumnLimit: 120
diff --git a/js/.eslintrc.js b/js/.eslintrc.js
index 77aced2d4bde0..bd1e9061355f5 100644
--- a/js/.eslintrc.js
+++ b/js/.eslintrc.js
@@ -14,42 +14,47 @@ module.exports = {
     'test/data/',
     'dist/',
   ],
-  env: { 'es6': true },
+  env: { es6: true },
   parser: '@typescript-eslint/parser',
-  parserOptions: { 'project': true, 'sourceType': 'module' },
+  parserOptions: { project: true, sourceType: 'module' },
   plugins: ['@typescript-eslint', 'prefer-arrow', 'header', 'import', 'unicorn', 'jsdoc'],
   rules: {
     'unicorn/filename-case': 'error',
     'header/header': [
-      2, 'line', [
-        ' Copyright (c) Microsoft Corporation. All rights reserved.',
-        ' Licensed under the MIT License.'
-      ], 2
+      2,
+      'line',
+      [' Copyright (c) Microsoft Corporation. All rights reserved.', ' Licensed under the MIT License.'],
+      2,
+    ],
+    'import/no-extraneous-dependencies': ['error', { devDependencies: false }],
+    'import/no-internal-modules': [
+      'error',
+      {
+        allow: ['**/lib/**'],
+      },
     ],
-    'import/no-extraneous-dependencies': ['error', { 'devDependencies': false }],
-    'import/no-internal-modules': ['error', {
-      'allow': ['**/lib/**'],
-    }],
     'import/no-unassigned-import': 'error',
-    '@typescript-eslint/array-type': ['error', { 'default': 'array-simple' }],
+    '@typescript-eslint/array-type': ['error', { default: 'array-simple' }],
     '@typescript-eslint/await-thenable': 'error',
     '@typescript-eslint/ban-types': [
-      'error', {
-        'types': {
-          'Object': { 'message': 'Use {} instead.' },
-          'String': { 'message': 'Use \'string\' instead.' },
-          'Number': { 'message': 'Use \'number\' instead.' },
-          'Boolean': { 'message': 'Use \'boolean\' instead.' }
-        }
-      }
+      'error',
+      {
+        types: {
+          Object: { message: 'Use {} instead.' },
+          String: { message: "Use 'string' instead." },
+          Number: { message: "Use 'number' instead." },
+          Boolean: { message: "Use 'boolean' instead." },
+        },
+      },
     ],
     '@typescript-eslint/naming-convention': 'error',
     '@typescript-eslint/consistent-type-assertions': 'error',
     '@typescript-eslint/member-delimiter-style': [
-      'error', {
-        'multiline': { 'delimiter': 'semi', 'requireLast': true },
-        'singleline': { 'delimiter': 'semi', 'requireLast': false }
-      }
+      'error',
+      {
+        multiline: { delimiter: 'semi', requireLast: true },
+        singleline: { delimiter: 'semi', requireLast: false },
+      },
     ],
     '@typescript-eslint/no-empty-function': 'error',
     '@typescript-eslint/no-explicit-any': 'error',
@@ -57,28 +62,25 @@ module.exports = {
     '@typescript-eslint/no-for-in-array': 'error',
     '@typescript-eslint/no-inferrable-types': 'error',
     '@typescript-eslint/no-misused-new': 'error',
-    '@typescript-eslint/no-namespace': ['error', { 'allowDeclarations': true }],
+    '@typescript-eslint/no-namespace': ['error', { allowDeclarations: true }],
     '@typescript-eslint/no-non-null-assertion': 'off',
-    '@typescript-eslint/no-require-imports': ['error', { 'allow': ['^node:']}],
-    '@typescript-eslint/no-var-requires': ['error', { 'allow': ['^node:']}],
+    '@typescript-eslint/no-require-imports': ['error', { allow: ['^node:'] }],
+    '@typescript-eslint/no-var-requires': ['error', { allow: ['^node:'] }],
     '@typescript-eslint/no-unnecessary-type-assertion': 'error',
-    '@typescript-eslint/no-unused-vars': ['error', { 'argsIgnorePattern': '^_' }],
+    '@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_' }],
     '@typescript-eslint/promise-function-async': 'error',
-    '@typescript-eslint/quotes': ['error', 'single'],
     '@typescript-eslint/restrict-plus-operands': 'error',
     '@typescript-eslint/semi': ['error', 'always'],
-    '@typescript-eslint/triple-slash-reference':
-        ['error', { 'path': 'always', 'types': 'prefer-import', 'lib': 'always' }],
+    '@typescript-eslint/triple-slash-reference': ['error', { path: 'always', types: 'prefer-import', lib: 'always' }],
     'arrow-body-style': 'error',
-    'camelcase': 'error',
+    camelcase: 'error',
     'constructor-super': 'error',
-    'curly': 'error',
+    curly: 'error',
     'default-case': 'error',
     'dot-notation': 'error',
-    'eqeqeq': ['error', 'smart'],
+    eqeqeq: ['error', 'smart'],
     'guard-for-in': 'error',
     'id-match': 'error',
-    'max-len': ['error', { 'code': 120, 'ignorePattern': '^import\\s.+\\sfrom\\s.+;$' }],
     'new-parens': 'error',
     'no-bitwise': 'error',
     'no-caller': 'error',
@@ -117,136 +119,159 @@ module.exports = {
     'object-shorthand': 'error',
     'prefer-arrow/prefer-arrow-functions': 'error',
     'prefer-const': 'error',
-    'radix': 'error',
-    'use-isnan': 'error'
+    radix: 'error',
+    'use-isnan': 'error',
   },
-  overrides: [{
-    files: ['node/**/*.ts'],
-    env: { 'es6': true, 'node': true }
-  }, {
-    files: ['common/lib/**/*.ts', 'node/lib/**/*.ts'],
-    rules: {
-      'jsdoc/check-alignment': 'error',
-      'jsdoc/check-indentation': 'error',
-    }
-  }, {
-    files: ['common/test/**/*.ts'],
-    rules: {
-      '@typescript-eslint/naming-convention': 'off',
-      'import/no-extraneous-dependencies': 'off',
-    }
-  }, {
-    files: ['node/script/**/*.ts', 'node/test/**/*.ts', 'web/script/**/*.ts', 'web/test/**/*.ts'], rules: {
-      '@typescript-eslint/naming-convention': 'off',
-      '@typescript-eslint/no-empty-function': 'off',
-      '@typescript-eslint/no-explicit-any': 'off',
-      '@typescript-eslint/no-require-imports': 'off',
-      '@typescript-eslint/no-var-requires': 'off',
-      '@typescript-eslint/no-unnecessary-type-assertion': 'off',
-      'camelcase': 'off',
-      'prefer-arrow/prefer-arrow-functions': 'off',
-      'import/no-extraneous-dependencies': 'off',
-      'import/no-unassigned-import': 'off',
-
'import/no-internal-modules': 'off', - 'no-console': 'off', - 'no-empty': 'off', - 'no-unused-expressions': 'off', - } - }, { - files: ['web/lib/**/*.ts'], rules: { - 'no-underscore-dangle': ['error', { - 'allow': [ - '_free', - '_malloc', - '_JsepGetNodeName', - '_JsepOutput', - '_OrtAddFreeDimensionOverride', - '_OrtAddRunConfigEntry', - '_OrtAddSessionConfigEntry', - '_OrtAppendExecutionProvider', - '_OrtBindInput', - '_OrtBindOutput', - '_OrtClearBoundOutputs', - '_OrtCreateBinding', - '_OrtCreateRunOptions', - '_OrtCreateSession', - '_OrtCreateSessionOptions', - '_OrtCreateTensor', - '_OrtEndProfiling', - '_OrtFree', - '_OrtGetInputName', - '_OrtGetInputOutputCount', - '_OrtGetLastError', - '_OrtGetOutputName', - '_OrtGetTensorData', - '_OrtInit', - '_OrtReleaseBinding', - '_OrtReleaseRunOptions', - '_OrtReleaseSession', - '_OrtReleaseSessionOptions', - '_OrtReleaseTensor', - '_OrtRun', - '_OrtRunWithBinding', - '_OrtTrainingCopyParametersFromBuffer', - '_OrtTrainingCopyParametersToBuffer', - '_OrtTrainingCreateSession', - '_OrtTrainingEvalStep', - '_OrtTrainingGetModelInputOutputCount', - '_OrtTrainingGetModelInputOutputName', - '_OrtTrainingGetParametersSize', - '_OrtTrainingLazyResetGrad', - '_OrtTrainingLoadCheckpoint', - '_OrtTrainingOptimizerStep', - '_OrtTrainingReleaseCheckpoint', - '_OrtTrainingReleaseSession', - '_OrtTrainingRunTrainStep' - ] - }] - } - }, { - files: ['web/lib/onnxjs/**/*.ts'], rules: { - // TODO: those rules are useful. should turn on them in future (webgl refactor) - '@typescript-eslint/no-empty-function': 'off', - '@typescript-eslint/explicit-module-boundary-types': 'off', - '@typescript-eslint/no-use-before-define': 'off', - '@typescript-eslint/no-unnecessary-type-assertion': 'off', - '@typescript-eslint/restrict-plus-operands': 'off', - 'import/no-internal-modules': 'off', - 'prefer-arrow/prefer-arrow-functions': 'off', - 'no-param-reassign': 'off', - 'no-underscore-dangle': 'off', - 'guard-for-in': 'off' - } - }, { - files: ['react_native/e2e/src/**/*.ts', 'react_native/e2e/src/**/*.tsx'], rules: { - '@typescript-eslint/no-non-null-assertion': 'off', - '@typescript-eslint/no-unnecessary-type-assertion': 'off', - 'unicorn/filename-case': 'off', - 'no-invalid-this': 'off', - 'no-console': 'off' - } - }, { - files: ['react_native/lib/**/*.ts'], rules: { - '@typescript-eslint/naming-convention': 'off' - } - }, { - files: ['react_native/scripts/**/*.ts'], rules: { - 'import/no-extraneous-dependencies': 'off', - 'prefer-arrow/prefer-arrow-functions': 'off', - 'no-console': 'off' - } - }, { - files: ['scripts/**/*.ts'], rules: { - 'import/no-extraneous-dependencies': 'off', - 'no-console': 'off' - } - }, { - files: ['web/lib/**/3rd-party/**/*.ts'], rules: { - 'header/header': 'off', - 'unicorn/filename-case': 'off', - '@typescript-eslint/explicit-module-boundary-types': 'off', - } - }], + overrides: [ + { + files: ['node/**/*.ts'], + env: { es6: true, node: true }, + }, + { + files: ['common/lib/**/*.ts', 'node/lib/**/*.ts'], + rules: { + 'jsdoc/check-alignment': 'error', + 'jsdoc/check-indentation': 'error', + }, + }, + { + files: ['common/test/**/*.ts'], + rules: { + '@typescript-eslint/naming-convention': 'off', + 'import/no-extraneous-dependencies': 'off', + }, + }, + { + files: ['node/script/**/*.ts', 'node/test/**/*.ts', 'web/script/**/*.ts', 'web/test/**/*.ts'], + rules: { + '@typescript-eslint/naming-convention': 'off', + '@typescript-eslint/no-empty-function': 'off', + '@typescript-eslint/no-explicit-any': 'off', + 
'@typescript-eslint/no-require-imports': 'off', + '@typescript-eslint/no-var-requires': 'off', + '@typescript-eslint/no-unnecessary-type-assertion': 'off', + camelcase: 'off', + 'prefer-arrow/prefer-arrow-functions': 'off', + 'import/no-extraneous-dependencies': 'off', + 'import/no-unassigned-import': 'off', + 'import/no-internal-modules': 'off', + 'no-console': 'off', + 'no-empty': 'off', + 'no-unused-expressions': 'off', + }, + }, + { + files: ['web/lib/**/*.ts'], + rules: { + 'no-underscore-dangle': [ + 'error', + { + allow: [ + '_free', + '_malloc', + '_JsepGetNodeName', + '_JsepOutput', + '_OrtAddFreeDimensionOverride', + '_OrtAddRunConfigEntry', + '_OrtAddSessionConfigEntry', + '_OrtAppendExecutionProvider', + '_OrtBindInput', + '_OrtBindOutput', + '_OrtClearBoundOutputs', + '_OrtCreateBinding', + '_OrtCreateRunOptions', + '_OrtCreateSession', + '_OrtCreateSessionOptions', + '_OrtCreateTensor', + '_OrtEndProfiling', + '_OrtFree', + '_OrtGetInputName', + '_OrtGetInputOutputCount', + '_OrtGetLastError', + '_OrtGetOutputName', + '_OrtGetTensorData', + '_OrtInit', + '_OrtReleaseBinding', + '_OrtReleaseRunOptions', + '_OrtReleaseSession', + '_OrtReleaseSessionOptions', + '_OrtReleaseTensor', + '_OrtRun', + '_OrtRunWithBinding', + '_OrtTrainingCopyParametersFromBuffer', + '_OrtTrainingCopyParametersToBuffer', + '_OrtTrainingCreateSession', + '_OrtTrainingEvalStep', + '_OrtTrainingGetModelInputOutputCount', + '_OrtTrainingGetModelInputOutputName', + '_OrtTrainingGetParametersSize', + '_OrtTrainingLazyResetGrad', + '_OrtTrainingLoadCheckpoint', + '_OrtTrainingOptimizerStep', + '_OrtTrainingReleaseCheckpoint', + '_OrtTrainingReleaseSession', + '_OrtTrainingRunTrainStep', + ], + }, + ], + }, + }, + { + files: ['web/lib/onnxjs/**/*.ts'], + rules: { + // TODO: those rules are useful. 
should turn on them in future (webgl refactor)
+        '@typescript-eslint/no-empty-function': 'off',
+        '@typescript-eslint/explicit-module-boundary-types': 'off',
+        '@typescript-eslint/no-use-before-define': 'off',
+        '@typescript-eslint/no-unnecessary-type-assertion': 'off',
+        '@typescript-eslint/restrict-plus-operands': 'off',
+        'import/no-internal-modules': 'off',
+        'prefer-arrow/prefer-arrow-functions': 'off',
+        'no-param-reassign': 'off',
+        'no-underscore-dangle': 'off',
+        'guard-for-in': 'off',
+      },
+    },
+    {
+      files: ['react_native/e2e/src/**/*.ts', 'react_native/e2e/src/**/*.tsx'],
+      rules: {
+        '@typescript-eslint/no-non-null-assertion': 'off',
+        '@typescript-eslint/no-unnecessary-type-assertion': 'off',
+        'unicorn/filename-case': 'off',
+        'no-invalid-this': 'off',
+        'no-console': 'off',
+      },
+    },
+    {
+      files: ['react_native/lib/**/*.ts'],
+      rules: {
+        '@typescript-eslint/naming-convention': 'off',
+      },
+    },
+    {
+      files: ['react_native/scripts/**/*.ts'],
+      rules: {
+        'import/no-extraneous-dependencies': 'off',
+        'prefer-arrow/prefer-arrow-functions': 'off',
+        'no-console': 'off',
+      },
+    },
+    {
+      files: ['scripts/**/*.ts'],
+      rules: {
+        'import/no-extraneous-dependencies': 'off',
+        'no-console': 'off',
+      },
+    },
+    {
+      files: ['web/lib/**/3rd-party/**/*.ts'],
+      rules: {
+        'header/header': 'off',
+        'unicorn/filename-case': 'off',
+        '@typescript-eslint/explicit-module-boundary-types': 'off',
+      },
+    },
+  ],
   extends: [
     'eslint:recommended',
     'plugin:@typescript-eslint/eslint-recommended',
diff --git a/js/.prettierignore b/js/.prettierignore
index 5571721a7a4fd..dee8c1944e3fb 100644
--- a/js/.prettierignore
+++ b/js/.prettierignore
@@ -11,13 +11,6 @@ dist/
 **/*.cc
 **/*.cpp
 **/*.h
-**/*.js
-**/*.mjs
-**/*.cjs
-**/*.jsx
-**/*.ts
-**/*.mts
-**/*.cts
-**/*.tsx
+**/*.hpp
 **/*.java
 **/*.mm
diff --git a/js/.prettierrc b/js/.prettierrc
index 0b909ca02d823..852d08d130193 100644
--- a/js/.prettierrc
+++ b/js/.prettierrc
@@ -1 +1,13 @@
-{ "printWidth": 120, "endOfLine": "auto", "singleQuote": false }
+{
+  "printWidth": 120,
+  "endOfLine": "auto",
+  "singleQuote": true,
+  "overrides": [
+    {
+      "files": "*.jsonc",
+      "options": {
+        "trailingComma": "none"
+      }
+    }
+  ]
+}
diff --git a/js/.vscode/settings.json b/js/.vscode/settings.json
index 9c2fe646d728d..0d67d6f9aa044 100644
--- a/js/.vscode/settings.json
+++ b/js/.vscode/settings.json
@@ -1,8 +1,4 @@
 {
-  "[cpp]": {
-    "editor.formatOnSave": true,
-    "editor.defaultFormatter": "xaver.clang-format"
-  },
   "[json]": {
     "editor.formatOnSave": true,
     "editor.defaultFormatter": "esbenp.prettier-vscode"
@@ -17,14 +13,13 @@
   },
   "[javascript]": {
     "editor.formatOnSave": true,
-    "editor.defaultFormatter": "xaver.clang-format"
+    "editor.defaultFormatter": "esbenp.prettier-vscode"
   },
   "[typescript]": {
     "editor.formatOnSave": true,
-    "editor.defaultFormatter": "xaver.clang-format"
+    "editor.defaultFormatter": "esbenp.prettier-vscode"
   },
-  "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format",
-  "clang-format.style": "file",
+  "prettier.prettierPath": "./node_modules/prettier",
   "editor.detectIndentation": false,
   "editor.insertSpaces": true,
   "editor.rulers": [120],
diff --git a/js/common/build.js b/js/common/build.js
index b0956c608b350..39d535823400c 100644
--- a/js/common/build.js
+++ b/js/common/build.js
@@ -3,18 +3,18 @@
 'use strict';

-import {execSync} from 'node:child_process';
-import {writeFileSync} from 'node:fs';
-import {resolve, dirname} from 'node:path';
-import {fileURLToPath} from 'node:url';
+import { execSync } from 'node:child_process';
+import { writeFileSync } from 'node:fs'; +import { resolve, dirname } from 'node:path'; +import { fileURLToPath } from 'node:url'; const __dirname = dirname(fileURLToPath(import.meta.url)); // build the following folders: // - dist/cjs // - dist/esm -execSync('npm run build:cjs', {shell: true, stdio: 'inherit', cwd: __dirname}); -execSync('npm run build:esm', {shell: true, stdio: 'inherit', cwd: __dirname}); +execSync('npm run build:cjs', { shell: true, stdio: 'inherit', cwd: __dirname }); +execSync('npm run build:esm', { shell: true, stdio: 'inherit', cwd: __dirname }); // generate package.json files under each of the dist folders for commonJS and ESModule // this trick allows typescript to import this package as different module type diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index e90efd7b97c29..3a7bfd0fab5f6 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend} from './backend.js'; -import {InferenceSession} from './inference-session.js'; +import { Backend } from './backend.js'; +import { InferenceSession } from './inference-session.js'; interface BackendInfo { backend: Backend; @@ -31,7 +31,7 @@ export const registerBackend = (name: string, backend: Backend, priority: number if (backend && typeof backend.init === 'function' && typeof backend.createInferenceSessionHandler === 'function') { const currentBackend = backends.get(name); if (currentBackend === undefined) { - backends.set(name, {backend, priority}); + backends.set(name, { backend, priority }); } else if (currentBackend.priority > priority) { // same name is already registered with a higher priority. skip registeration. return; @@ -67,7 +67,7 @@ export const registerBackend = (name: string, backend: Backend, priority: number * @param backendName - the name of the backend. * @returns the backend instance if resolved and initialized successfully, or an error message if failed. */ -const tryResolveAndInitializeBackend = async(backendName: string): Promise => { +const tryResolveAndInitializeBackend = async (backendName: string): Promise => { const backendInfo = backends.get(backendName); if (!backendInfo) { return 'backend not found.'; @@ -107,55 +107,58 @@ const tryResolveAndInitializeBackend = async(backendName: string): Promise => { - // extract backend hints from session options - const eps = options.executionProviders || []; - const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); - const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; - - // try to resolve and initialize all requested backends - let backend: Backend|undefined; - const errors = []; - const availableBackendNames = new Set(); - for (const backendName of backendNames) { - const resolveResult = await tryResolveAndInitializeBackend(backendName); - if (typeof resolveResult === 'string') { - errors.push({name: backendName, err: resolveResult}); - } else { - if (!backend) { - backend = resolveResult; - } - if (backend === resolveResult) { - availableBackendNames.add(backendName); - } - } - } - - // if no backend is available, throw error. 
+export const resolveBackendAndExecutionProviders = async ( + options: InferenceSession.SessionOptions, +): Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => { + // extract backend hints from session options + const eps = options.executionProviders || []; + const backendHints = eps.map((i) => (typeof i === 'string' ? i : i.name)); + const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; + + // try to resolve and initialize all requested backends + let backend: Backend | undefined; + const errors = []; + const availableBackendNames = new Set(); + for (const backendName of backendNames) { + const resolveResult = await tryResolveAndInitializeBackend(backendName); + if (typeof resolveResult === 'string') { + errors.push({ name: backendName, err: resolveResult }); + } else { if (!backend) { - throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`); + backend = resolveResult; } - - // for each explicitly requested backend, if it's not available, output warning message. - for (const {name, err} of errors) { - if (backendHints.includes(name)) { - // eslint-disable-next-line no-console - console.warn(`removing requested execution provider "${ - name}" from session options because it is not available: ${err}`); - } + if (backend === resolveResult) { + availableBackendNames.add(backendName); } + } + } - const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name)); - - return [ - backend, new Proxy(options, { - get: (target, prop) => { - if (prop === 'executionProviders') { - return filteredEps; - } - return Reflect.get(target, prop); - } - }) - ]; - }; + // if no backend is available, throw error. + if (!backend) { + throw new Error(`no available backend found. ERR: ${errors.map((e) => `[${e.name}] ${e.err}`).join(', ')}`); + } + + // for each explicitly requested backend, if it's not available, output warning message. + for (const { name, err } of errors) { + if (backendHints.includes(name)) { + // eslint-disable-next-line no-console + console.warn( + `removing requested execution provider "${name}" from session options because it is not available: ${err}`, + ); + } + } + + const filteredEps = eps.filter((i) => availableBackendNames.has(typeof i === 'string' ? i : i.name)); + + return [ + backend, + new Proxy(options, { + get: (target, prop) => { + if (prop === 'executionProviders') { + return filteredEps; + } + return Reflect.get(target, prop); + }, + }), + ]; +}; diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 8c07bdd5c5c4a..e27e67622aa82 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -1,17 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {InferenceSession} from './inference-session.js'; -import {OnnxValue} from './onnx-value.js'; -import {TrainingSession} from './training-session.js'; +import { InferenceSession } from './inference-session.js'; +import { OnnxValue } from './onnx-value.js'; +import { TrainingSession } from './training-session.js'; /** * @ignore */ export declare namespace SessionHandler { - type FeedsType = {[name: string]: OnnxValue}; - type FetchesType = {[name: string]: OnnxValue | null}; - type ReturnType = {[name: string]: OnnxValue}; + type FeedsType = { [name: string]: OnnxValue }; + type FetchesType = { [name: string]: OnnxValue | null }; + type ReturnType = { [name: string]: OnnxValue }; } /** @@ -35,8 +35,11 @@ export interface InferenceSessionHandler extends SessionHandler { startProfiling(): void; endProfiling(): void; - run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions): Promise; + run( + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise; } /** @@ -50,12 +53,16 @@ export interface TrainingSessionHandler extends SessionHandler { lazyResetGrad(): Promise; runTrainStep( - feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions): Promise; + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise; runOptimizerStep(options: InferenceSession.RunOptions): Promise; runEvalStep( - feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions): Promise; + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise; getParametersSize(trainableOnly: boolean): Promise; loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; @@ -73,13 +80,18 @@ export interface Backend { */ init(backendName: string): Promise; - createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise; + createInferenceSessionHandler( + uriOrBuffer: string | Uint8Array, + options?: InferenceSession.SessionOptions, + ): Promise; - createTrainingSessionHandler? - (checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, trainModelUriOrBuffer: TrainingSession.UriOrBuffer, - evalModelUriOrBuffer: TrainingSession.UriOrBuffer, optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, - options: InferenceSession.SessionOptions): Promise; + createTrainingSessionHandler?( + checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, + trainModelUriOrBuffer: TrainingSession.UriOrBuffer, + evalModelUriOrBuffer: TrainingSession.UriOrBuffer, + optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, + options: InferenceSession.SessionOptions, + ): Promise; } -export {registerBackend} from './backend-impl.js'; +export { registerBackend } from './backend-impl.js'; diff --git a/js/common/lib/env-impl.ts b/js/common/lib/env-impl.ts index c3e96d864dcfe..98a2fe1dc0c1c 100644 --- a/js/common/lib/env-impl.ts +++ b/js/common/lib/env-impl.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Env} from './env.js'; -import {version} from './version.js'; +import { Env } from './env.js'; +import { version } from './version.js'; type LogLevelType = Env['logLevel']; @@ -12,7 +12,7 @@ export const env: Env = { wasm: {} as Env.WebAssemblyFlags, webgl: {} as Env.WebGLFlags, webgpu: {} as Env.WebGpuFlags, - versions: {common: version}, + versions: { common: version }, set logLevel(value: LogLevelType) { if (value === undefined) { @@ -29,4 +29,4 @@ export const env: Env = { }; // set property 'logLevel' so that they can be correctly transferred to worker by `postMessage()`. -Object.defineProperty(env, 'logLevel', {enumerable: true}); +Object.defineProperty(env, 'logLevel', { enumerable: true }); diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 1a87569a115a6..642a897a90d26 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env as envImpl} from './env-impl.js'; +import { env as envImpl } from './env-impl.js'; export declare namespace Env { export type WasmPathPrefix = string; @@ -16,7 +16,7 @@ export declare namespace Env { * - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN) * - `ort-training-wasm-simd-threaded.wasm` for training build */ - wasm?: URL|string; + wasm?: URL | string; /** * Specify the override path for the main .mjs file. * @@ -27,9 +27,9 @@ export declare namespace Env { * - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN) * - `ort-training-wasm-simd-threaded.mjs` for training build */ - mjs?: URL|string; + mjs?: URL | string; } - export type WasmPrefixOrFilePaths = WasmPathPrefix|WasmFilePaths; + export type WasmPrefixOrFilePaths = WasmPathPrefix | WasmFilePaths; export interface WebAssemblyFlags { /** * set or get number of thread(s). If omitted or set to 0, number of thread(s) will be determined by system. If set @@ -78,7 +78,7 @@ export declare namespace Env { * Set a custom buffer which contains the WebAssembly binary. If this property is set, the `wasmPaths` property will * be ignored. */ - wasmBinary?: ArrayBufferLike|Uint8Array; + wasmBinary?: ArrayBufferLike | Uint8Array; /** * Set or get a boolean value indicating whether to proxy the execution of main thread to a worker thread. @@ -94,7 +94,7 @@ export declare namespace Env { * * @defaultValue `'webgl2'` */ - contextId?: 'webgl'|'webgl2'; + contextId?: 'webgl' | 'webgl2'; /** * Get the WebGL rendering context. */ @@ -110,7 +110,7 @@ export declare namespace Env { * * @defaultValue `'full'` */ - textureCacheMode?: 'initializerOnly'|'full'; + textureCacheMode?: 'initializerOnly' | 'full'; /** * Set or get the packed texture mode * @@ -150,7 +150,7 @@ export declare namespace Env { * @deprecated Use `env.webgpu.profiling.mode` instead. If `env.webgpu.profiling.mode` is set, this property will be * ignored. */ - profilingMode?: 'off'|'default'; + profilingMode?: 'off' | 'default'; /** * Set or get the profiling configuration. */ @@ -160,7 +160,7 @@ export declare namespace Env { * * @defaultValue `'off'` */ - mode?: 'off'|'default'; + mode?: 'off' | 'default'; /** * Set or get a callback function when a profiling data is received. If not set, the profiling data will be @@ -178,7 +178,7 @@ export declare namespace Env { * * @defaultValue `undefined` */ - powerPreference?: 'low-power'|'high-performance'; + powerPreference?: 'low-power' | 'high-performance'; /** * Set or get the force fallback adapter flag. 
* @@ -231,7 +231,7 @@ export interface Env { * * @defaultValue `'warning'` */ - logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'; + logLevel?: 'verbose' | 'info' | 'warning' | 'error' | 'fatal'; /** * Indicate whether run in debug mode. diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index ab4c6a3e0c46b..d47ed7a331045 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {resolveBackendAndExecutionProviders} from './backend-impl.js'; -import {InferenceSessionHandler} from './backend.js'; -import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; -import {OnnxValue} from './onnx-value.js'; -import {Tensor} from './tensor.js'; -import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from './trace.js'; +import { resolveBackendAndExecutionProviders } from './backend-impl.js'; +import { InferenceSessionHandler } from './backend.js'; +import { InferenceSession as InferenceSessionInterface } from './inference-session.js'; +import { OnnxValue } from './onnx-value.js'; +import { Tensor } from './tensor.js'; +import { TRACE_FUNC_BEGIN, TRACE_FUNC_END } from './trace.js'; type SessionOptions = InferenceSessionInterface.SessionOptions; type RunOptions = InferenceSessionInterface.RunOptions; @@ -20,14 +20,15 @@ export class InferenceSession implements InferenceSessionInterface { } run(feeds: FeedsType, options?: RunOptions): Promise; run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; - async run(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + async run(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { TRACE_FUNC_BEGIN(); - const fetches: {[name: string]: OnnxValue|null} = {}; + const fetches: { [name: string]: OnnxValue | null } = {}; let options: RunOptions = {}; // check inputs if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { throw new TypeError( - '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.'); + "'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.", + ); } let isFetchesEmpty = true; @@ -37,18 +38,18 @@ export class InferenceSession implements InferenceSessionInterface { throw new TypeError('Unexpected argument[1]: cannot be null.'); } if (arg1 instanceof Tensor) { - throw new TypeError('\'fetches\' cannot be a Tensor'); + throw new TypeError("'fetches' cannot be a Tensor"); } if (Array.isArray(arg1)) { if (arg1.length === 0) { - throw new TypeError('\'fetches\' cannot be an empty array.'); + throw new TypeError("'fetches' cannot be an empty array."); } isFetchesEmpty = false; // output names for (const name of arg1) { if (typeof name !== 'string') { - throw new TypeError('\'fetches\' must be a string array or an object.'); + throw new TypeError("'fetches' must be a string array or an object."); } if (this.outputNames.indexOf(name) === -1) { throw new RangeError(`'fetches' contains invalid output name: ${name}.`); @@ -59,7 +60,7 @@ export class InferenceSession implements InferenceSessionInterface { if (typeof arg2 === 'object' && arg2 !== null) { options = arg2; } else if (typeof arg2 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } } else { // decide 
whether arg1 is fetches or options @@ -81,14 +82,14 @@ export class InferenceSession implements InferenceSessionInterface { if (typeof arg2 === 'object' && arg2 !== null) { options = arg2; } else if (typeof arg2 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } } else { options = arg1 as RunOptions; } } } else if (typeof arg1 !== 'undefined') { - throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.'); + throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'."); } // check if all inputs are in feed @@ -108,7 +109,7 @@ export class InferenceSession implements InferenceSessionInterface { // feeds, fetches and options are prepared const results = await this.handler.run(feeds, fetches, options); - const returnValue: {[name: string]: OnnxValue} = {}; + const returnValue: { [name: string]: OnnxValue } = {}; for (const key in results) { if (Object.hasOwnProperty.call(results, key)) { const result = results[key]; @@ -129,15 +130,22 @@ export class InferenceSession implements InferenceSessionInterface { static create(path: string, options?: SessionOptions): Promise; static create(buffer: ArrayBufferLike, options?: SessionOptions): Promise; - static create(buffer: ArrayBufferLike, byteOffset: number, byteLength?: number, options?: SessionOptions): - Promise; + static create( + buffer: ArrayBufferLike, + byteOffset: number, + byteLength?: number, + options?: SessionOptions, + ): Promise; static create(buffer: Uint8Array, options?: SessionOptions): Promise; static async create( - arg0: string|ArrayBufferLike|Uint8Array, arg1?: SessionOptions|number, arg2?: number, - arg3?: SessionOptions): Promise { + arg0: string | ArrayBufferLike | Uint8Array, + arg1?: SessionOptions | number, + arg2?: number, + arg3?: SessionOptions, + ): Promise { TRACE_FUNC_BEGIN(); // either load from a file or buffer - let filePathOrUint8Array: string|Uint8Array; + let filePathOrUint8Array: string | Uint8Array; let options: SessionOptions = {}; if (typeof arg0 === 'string') { @@ -145,18 +153,19 @@ export class InferenceSession implements InferenceSessionInterface { if (typeof arg1 === 'object' && arg1 !== null) { options = arg1; } else if (typeof arg1 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } } else if (arg0 instanceof Uint8Array) { filePathOrUint8Array = arg0; if (typeof arg1 === 'object' && arg1 !== null) { options = arg1; } else if (typeof arg1 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } } else if ( - arg0 instanceof ArrayBuffer || - (typeof SharedArrayBuffer !== 'undefined' && arg0 instanceof SharedArrayBuffer)) { + arg0 instanceof ArrayBuffer || + (typeof SharedArrayBuffer !== 'undefined' && arg0 instanceof SharedArrayBuffer) + ) { const buffer = arg0; let byteOffset = 0; let byteLength = arg0.byteLength; @@ -165,7 +174,7 @@ export class InferenceSession implements InferenceSessionInterface { } else if (typeof arg1 === 'number') { byteOffset = arg1; if (!Number.isSafeInteger(byteOffset)) { - throw new RangeError('\'byteOffset\' must be an integer.'); + throw new RangeError("'byteOffset' must be an integer."); } if (byteOffset < 0 || byteOffset >= buffer.byteLength) { throw new RangeError(`'byteOffset' is out of range [0, ${buffer.byteLength}).`); @@ -174,7 +183,7 @@ export class InferenceSession implements 
InferenceSessionInterface { if (typeof arg2 === 'number') { byteLength = arg2; if (!Number.isSafeInteger(byteLength)) { - throw new RangeError('\'byteLength\' must be an integer.'); + throw new RangeError("'byteLength' must be an integer."); } if (byteLength <= 0 || byteOffset + byteLength > buffer.byteLength) { throw new RangeError(`'byteLength' is out of range (0, ${buffer.byteLength - byteOffset}].`); @@ -182,17 +191,17 @@ export class InferenceSession implements InferenceSessionInterface { if (typeof arg3 === 'object' && arg3 !== null) { options = arg3; } else if (typeof arg3 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } } else if (typeof arg2 !== 'undefined') { - throw new TypeError('\'byteLength\' must be a number.'); + throw new TypeError("'byteLength' must be a number."); } } else if (typeof arg1 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } filePathOrUint8Array = new Uint8Array(buffer, byteOffset, byteLength); } else { - throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.'); + throw new TypeError("Unexpected argument[0]: must be 'path' or 'buffer'."); } // resolve backend, update session options with validated EPs, and create session handler diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 069fd9b49e484..af8a8c76c8fe4 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -1,17 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession as InferenceSessionImpl} from './inference-session-impl.js'; -import {OnnxModelOptions} from './onnx-model.js'; -import {OnnxValue, OnnxValueDataLocation} from './onnx-value.js'; +import { InferenceSession as InferenceSessionImpl } from './inference-session-impl.js'; +import { OnnxModelOptions } from './onnx-model.js'; +import { OnnxValue, OnnxValueDataLocation } from './onnx-value.js'; /* eslint-disable @typescript-eslint/no-redeclare */ export declare namespace InferenceSession { // #region input/output types - type OnnxValueMapType = {readonly [name: string]: OnnxValue}; - type NullableOnnxValueMapType = {readonly [name: string]: OnnxValue | null}; + type OnnxValueMapType = { readonly [name: string]: OnnxValue }; + type NullableOnnxValueMapType = { readonly [name: string]: OnnxValue | null }; /** * A feeds (model inputs) is an object that uses input names as keys and OnnxValue as corresponding values. @@ -30,7 +30,7 @@ export declare namespace InferenceSession { * used as a pre-allocated value by the inference engine; if omitted, inference engine will allocate buffer * internally. */ - type FetchesType = readonly string[]|NullableOnnxValueMapType; + type FetchesType = readonly string[] | NullableOnnxValueMapType; /** * A inferencing return type is an object that uses output names as keys and OnnxValue as corresponding values. @@ -72,14 +72,14 @@ export declare namespace InferenceSession { * * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend */ - freeDimensionOverrides?: {readonly [dimensionName: string]: number}; + freeDimensionOverrides?: { readonly [dimensionName: string]: number }; /** * The optimization level. 
* * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend */ - graphOptimizationLevel?: 'disabled'|'basic'|'extended'|'all'; + graphOptimizationLevel?: 'disabled' | 'basic' | 'extended' | 'all'; /** * Whether enable CPU memory arena. @@ -100,7 +100,7 @@ export declare namespace InferenceSession { * * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend */ - executionMode?: 'sequential'|'parallel'; + executionMode?: 'sequential' | 'parallel'; /** * Optimized model file path. @@ -137,7 +137,7 @@ export declare namespace InferenceSession { * * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend */ - logSeverityLevel?: 0|1|2|3|4; + logSeverityLevel?: 0 | 1 | 2 | 3 | 4; /** * Log verbosity level. @@ -152,7 +152,7 @@ export declare namespace InferenceSession { * * This setting is available only in ONNXRuntime Web for WebGL and WebGPU EP. */ - preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + preferredOutputLocation?: OnnxValueDataLocation | { readonly [outputName: string]: OnnxValueDataLocation }; /** * Whether enable graph capture. @@ -207,7 +207,10 @@ export declare namespace InferenceSession { type ExecutionProviderName = keyof ExecutionProviderOptionMap; type ExecutionProviderConfig = - ExecutionProviderOptionMap[ExecutionProviderName]|ExecutionProviderOption|ExecutionProviderName|string; + | ExecutionProviderOptionMap[ExecutionProviderName] + | ExecutionProviderOption + | ExecutionProviderName + | string; export interface ExecutionProviderOption { readonly name: string; @@ -240,7 +243,7 @@ export declare namespace InferenceSession { } export interface WebGpuExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webgpu'; - preferredLayout?: 'NCHW'|'NHWC'; + preferredLayout?: 'NCHW' | 'NHWC'; } // #region WebNN options @@ -255,9 +258,9 @@ export declare namespace InferenceSession { * @see https://www.w3.org/TR/webnn/#dictdef-mlcontextoptions */ export interface WebNNContextOptions { - deviceType?: 'cpu'|'gpu'|'npu'; + deviceType?: 'cpu' | 'gpu' | 'npu'; numThreads?: number; - powerPreference?: 'default'|'low-power'|'high-performance'; + powerPreference?: 'default' | 'low-power' | 'high-performance'; } /** @@ -275,9 +278,10 @@ export declare namespace InferenceSession { * * @see https://www.w3.org/TR/webnn/#dom-ml-createcontext */ - export interface WebNNOptionsWithMLContext extends WebNNExecutionProviderName, - Omit, - Required> { + export interface WebNNOptionsWithMLContext + extends WebNNExecutionProviderName, + Omit, + Required> { context: unknown /* MLContext */; } @@ -294,7 +298,10 @@ export declare namespace InferenceSession { /** * Options for WebNN execution provider. */ - export type WebNNExecutionProviderOption = WebNNOptionsWithoutMLContext|WebNNOptionsWithMLContext|WebNNOptionsWebGpu; + export type WebNNExecutionProviderOption = + | WebNNOptionsWithoutMLContext + | WebNNOptionsWithMLContext + | WebNNOptionsWebGpu; // #endregion @@ -362,7 +369,7 @@ export declare namespace InferenceSession { * * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend */ - logSeverityLevel?: 0|1|2|3|4; + logSeverityLevel?: 0 | 1 | 2 | 3 | 4; /** * Log verbosity level. @@ -441,8 +448,11 @@ export interface InferenceSession { * @param options - Optional. A set of options that controls the behavior of model inference. 
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. */ - run(feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions): Promise; + run( + feeds: InferenceSession.FeedsType, + fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions, + ): Promise; // #endregion @@ -524,8 +534,12 @@ export interface InferenceSessionFactory { * @param options - specify configuration for creating a new inference session. * @returns A promise that resolves to an InferenceSession object. */ - create(buffer: ArrayBufferLike, byteOffset: number, byteLength?: number, options?: InferenceSession.SessionOptions): - Promise; + create( + buffer: ArrayBufferLike, + byteOffset: number, + byteLength?: number, + options?: InferenceSession.SessionOptions, + ): Promise; /** * Create a new inference session and load model asynchronously from a Uint8Array. diff --git a/js/common/lib/onnx-model.ts b/js/common/lib/onnx-model.ts index 1cd3eedb6fcca..4000628d5909c 100644 --- a/js/common/lib/onnx-model.ts +++ b/js/common/lib/onnx-model.ts @@ -18,12 +18,12 @@ export type FileBlob = Blob; * * When it is an ArrayBuffer or SharedArrayBuffer, the whole buffer is assumed to be the file content. */ -export type FileData = Uint8Array|ArrayBufferLike; +export type FileData = Uint8Array | ArrayBufferLike; /** * Represents a file that can be loaded by the ONNX Runtime JavaScript API. */ -export type FileType = FileUrlOrPath|FileBlob|FileData; +export type FileType = FileUrlOrPath | FileBlob | FileData; /** * Represents an external data file. @@ -44,7 +44,7 @@ export interface ExternalDataFileDescription { * * When using a string, it should be a file URL or path that in the same directory as the model file. */ -export type ExternalDataFileType = ExternalDataFileDescription|FileUrlOrPath; +export type ExternalDataFileType = ExternalDataFileDescription | FileUrlOrPath; /** * Options for model loading. diff --git a/js/common/lib/onnx-value.ts b/js/common/lib/onnx-value.ts index 72369ce8b4209..9dd1cc52b14a1 100644 --- a/js/common/lib/onnx-value.ts +++ b/js/common/lib/onnx-value.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from './tensor.js'; +import { Tensor } from './tensor.js'; export type NonTensorType = never; @@ -10,7 +10,7 @@ export type NonTensorType = never; * * NOTE: currently not support non-tensor */ -export type OnnxValue = Tensor|NonTensorType; +export type OnnxValue = Tensor | NonTensorType; /** * Type OnnxValueDataLocation represents the location of the data of an OnnxValue. diff --git a/js/common/lib/tensor-conversion-impl.ts b/js/common/lib/tensor-conversion-impl.ts index b1de48a10c0e1..743d0e6b352c6 100644 --- a/js/common/lib/tensor-conversion-impl.ts +++ b/js/common/lib/tensor-conversion-impl.ts @@ -1,18 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; -import {Tensor} from './tensor.js'; +import { TensorToDataUrlOptions, TensorToImageDataOptions } from './tensor-conversion.js'; +import { Tensor } from './tensor.js'; /** * implementation of Tensor.toDataURL() */ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions): string => { - const canvas = typeof document !== 'undefined' ? 
document.createElement('canvas') : (new OffscreenCanvas(1, 1)); + const canvas = typeof document !== 'undefined' ? document.createElement('canvas') : new OffscreenCanvas(1, 1); canvas.width = tensor.dims[3]; canvas.height = tensor.dims[2]; - const pixels2DContext = - canvas.getContext('2d') as (CanvasRenderingContext2D | OffscreenCanvasRenderingContext2D | null); + const pixels2DContext = canvas.getContext('2d') as + | CanvasRenderingContext2D + | OffscreenCanvasRenderingContext2D + | null; if (pixels2DContext != null) { // Default values for height and width & format @@ -21,7 +23,8 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions if (options?.tensorLayout !== undefined && options.tensorLayout === 'NHWC') { width = tensor.dims[2]; height = tensor.dims[3]; - } else { // Default layout is NCWH + } else { + // Default layout is NCWH width = tensor.dims[3]; height = tensor.dims[2]; } @@ -34,7 +37,7 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions if (norm === undefined || norm.mean === undefined) { normMean = [255, 255, 255, 255]; } else { - if (typeof (norm.mean) === 'number') { + if (typeof norm.mean === 'number') { normMean = [norm.mean, norm.mean, norm.mean, norm.mean]; } else { normMean = [norm.mean[0], norm.mean[1], norm.mean[2], 0]; @@ -46,7 +49,7 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions if (norm === undefined || norm.bias === undefined) { normBias = [0, 0, 0, 0]; } else { - if (typeof (norm.bias) === 'number') { + if (typeof norm.bias === 'number') { normBias = [norm.bias, norm.bias, norm.bias, norm.bias]; } else { normBias = [norm.bias[0], norm.bias[1], norm.bias[2], 0]; @@ -58,7 +61,10 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions const stride = height * width; // Default pointer assignments - let rTensorPointer = 0, gTensorPointer = stride, bTensorPointer = stride * 2, aTensorPointer = -1; + let rTensorPointer = 0, + gTensorPointer = stride, + bTensorPointer = stride * 2, + aTensorPointer = -1; // Updating the pointer assignments based on the input image format if (inputformat === 'RGBA') { @@ -78,12 +84,10 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions for (let i = 0; i < height; i++) { for (let j = 0; j < width; j++) { - const R = ((tensor.data[rTensorPointer++] as number) - normBias[0]) * normMean[0]; // R value - const G = ((tensor.data[gTensorPointer++] as number) - normBias[1]) * normMean[1]; // G value - const B = ((tensor.data[bTensorPointer++] as number) - normBias[2]) * normMean[2]; // B value - const A = aTensorPointer === -1 ? - 255 : - ((tensor.data[aTensorPointer++] as number) - normBias[3]) * normMean[3]; // A value + const R = ((tensor.data[rTensorPointer++] as number) - normBias[0]) * normMean[0]; // R value + const G = ((tensor.data[gTensorPointer++] as number) - normBias[1]) * normMean[1]; // G value + const B = ((tensor.data[bTensorPointer++] as number) - normBias[2]) * normMean[2]; // B value + const A = aTensorPointer === -1 ? 
255 : ((tensor.data[aTensorPointer++] as number) - normBias[3]) * normMean[3]; // A value // eslint-disable-next-line @typescript-eslint/restrict-plus-operands pixels2DContext.fillStyle = 'rgba(' + R + ',' + G + ',' + B + ',' + A + ')'; pixels2DContext.fillRect(j, i, 1, 1); @@ -103,9 +107,10 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions * implementation of Tensor.toImageData() */ export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOptions): ImageData => { - const pixels2DContext = typeof document !== 'undefined' ? - document.createElement('canvas').getContext('2d') : - new OffscreenCanvas(1, 1).getContext('2d') as OffscreenCanvasRenderingContext2D; + const pixels2DContext = + typeof document !== 'undefined' + ? document.createElement('canvas').getContext('2d') + : (new OffscreenCanvas(1, 1).getContext('2d') as OffscreenCanvasRenderingContext2D); let image: ImageData; if (pixels2DContext != null) { // Default values for height and width & format @@ -116,7 +121,8 @@ export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOpt width = tensor.dims[2]; height = tensor.dims[1]; channels = tensor.dims[3]; - } else { // Default layout is NCWH + } else { + // Default layout is NCWH width = tensor.dims[3]; height = tensor.dims[2]; channels = tensor.dims[1]; @@ -129,7 +135,7 @@ export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOpt if (norm === undefined || norm.mean === undefined) { normMean = [255, 255, 255, 255]; } else { - if (typeof (norm.mean) === 'number') { + if (typeof norm.mean === 'number') { normMean = [norm.mean, norm.mean, norm.mean, norm.mean]; } else { normMean = [norm.mean[0], norm.mean[1], norm.mean[2], 255]; @@ -141,7 +147,7 @@ export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOpt if (norm === undefined || norm.bias === undefined) { normBias = [0, 0, 0, 0]; } else { - if (typeof (norm.bias) === 'number') { + if (typeof norm.bias === 'number') { normBias = [norm.bias, norm.bias, norm.bias, norm.bias]; } else { normBias = [norm.bias[0], norm.bias[1], norm.bias[2], 0]; @@ -153,16 +159,24 @@ export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOpt const stride = height * width; if (options !== undefined) { - if (options.format !== undefined && (channels === 4 && options.format !== 'RGBA') || - (channels === 3 && (options.format !== 'RGB' && options.format !== 'BGR'))) { - throw new Error('Tensor format doesn\'t match input tensor dims'); + if ( + (options.format !== undefined && channels === 4 && options.format !== 'RGBA') || + (channels === 3 && options.format !== 'RGB' && options.format !== 'BGR') + ) { + throw new Error("Tensor format doesn't match input tensor dims"); } } // Default pointer assignments const step = 4; - let rImagePointer = 0, gImagePointer = 1, bImagePointer = 2, aImagePointer = 3; - let rTensorPointer = 0, gTensorPointer = stride, bTensorPointer = stride * 2, aTensorPointer = -1; + let rImagePointer = 0, + gImagePointer = 1, + bImagePointer = 2, + aImagePointer = 3; + let rTensorPointer = 0, + gTensorPointer = stride, + bTensorPointer = stride * 2, + aTensorPointer = -1; // Updating the pointer assignments based on the input image format if (inputformat === 'RGBA') { @@ -182,16 +196,17 @@ export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOpt image = pixels2DContext.createImageData(width, height); - for (let i = 0; i < height * width; - rImagePointer += step, gImagePointer += step, 
bImagePointer += step, aImagePointer += step, i++) { - image.data[rImagePointer] = ((tensor.data[rTensorPointer++] as number) - normBias[0]) * normMean[0]; // R value - image.data[gImagePointer] = ((tensor.data[gTensorPointer++] as number) - normBias[1]) * normMean[1]; // G value - image.data[bImagePointer] = ((tensor.data[bTensorPointer++] as number) - normBias[2]) * normMean[2]; // B value - image.data[aImagePointer] = aTensorPointer === -1 ? - 255 : - ((tensor.data[aTensorPointer++] as number) - normBias[3]) * normMean[3]; // A value + for ( + let i = 0; + i < height * width; + rImagePointer += step, gImagePointer += step, bImagePointer += step, aImagePointer += step, i++ + ) { + image.data[rImagePointer] = ((tensor.data[rTensorPointer++] as number) - normBias[0]) * normMean[0]; // R value + image.data[gImagePointer] = ((tensor.data[gTensorPointer++] as number) - normBias[1]) * normMean[1]; // G value + image.data[bImagePointer] = ((tensor.data[bTensorPointer++] as number) - normBias[2]) * normMean[2]; // B value + image.data[aImagePointer] = + aTensorPointer === -1 ? 255 : ((tensor.data[aTensorPointer++] as number) - normBias[3]) * normMean[3]; // A value } - } else { throw new Error('Can not access image data'); } diff --git a/js/common/lib/tensor-conversion.ts b/js/common/lib/tensor-conversion.ts index 4542b3b4a773c..b6b3b911e7b2d 100644 --- a/js/common/lib/tensor-conversion.ts +++ b/js/common/lib/tensor-conversion.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {OptionsFormat, OptionsNormalizationParameters, OptionsTensorLayout} from './tensor-factory.js'; +import { OptionsFormat, OptionsNormalizationParameters, OptionsTensorLayout } from './tensor-factory.js'; export interface TensorToDataUrlOptions extends OptionsTensorLayout, OptionsFormat, OptionsNormalizationParameters {} diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts index 19c62cb54bfed..52e028a9fcd31 100644 --- a/js/common/lib/tensor-factory-impl.ts +++ b/js/common/lib/tensor-factory-impl.ts @@ -1,12 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {OptionsDimensions, OptionsFormat, OptionsNormalizationParameters, OptionsTensorFormat, OptionsTensorLayout, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions} from './tensor-factory.js'; -import {Tensor} from './tensor-impl.js'; -import {Tensor as TensorInterface} from './tensor.js'; - -interface BufferToTensorOptions extends OptionsDimensions, OptionsTensorLayout, OptionsNormalizationParameters, - OptionsFormat, OptionsTensorFormat {} +import { + OptionsDimensions, + OptionsFormat, + OptionsNormalizationParameters, + OptionsTensorFormat, + OptionsTensorLayout, + TensorFromGpuBufferOptions, + TensorFromImageBitmapOptions, + TensorFromImageDataOptions, + TensorFromImageElementOptions, + TensorFromTextureOptions, + TensorFromUrlOptions, +} from './tensor-factory.js'; +import { Tensor } from './tensor-impl.js'; +import { Tensor as TensorInterface } from './tensor.js'; + +interface BufferToTensorOptions + extends OptionsDimensions, + OptionsTensorLayout, + OptionsNormalizationParameters, + OptionsFormat, + OptionsTensorFormat {} /** * Create a new tensor object from image object @@ -15,7 +31,7 @@ interface BufferToTensorOptions extends OptionsDimensions, OptionsTensorLayout, * @param imageFormat - input image configuration - required configurations height, width, format * @param tensorFormat - output tensor configuration - Default is RGB format */ -export const bufferToTensor = (buffer: Uint8ClampedArray|undefined, options: BufferToTensorOptions): Tensor => { +export const bufferToTensor = (buffer: Uint8ClampedArray | undefined, options: BufferToTensorOptions): Tensor => { if (buffer === undefined) { throw new Error('Image buffer must be defined'); } @@ -26,19 +42,19 @@ export const bufferToTensor = (buffer: Uint8ClampedArray|undefined, options: Buf throw new Error('NHWC Tensor layout is not supported yet'); } - const {height, width} = options; + const { height, width } = options; - const norm = options.norm ?? {mean: 255, bias: 0}; + const norm = options.norm ?? { mean: 255, bias: 0 }; let normMean: [number, number, number, number]; let normBias: [number, number, number, number]; - if (typeof (norm.mean) === 'number') { + if (typeof norm.mean === 'number') { normMean = [norm.mean, norm.mean, norm.mean, norm.mean]; } else { normMean = [norm.mean![0], norm.mean![1], norm.mean![2], norm.mean![3] ?? 255]; } - if (typeof (norm.bias) === 'number') { + if (typeof norm.bias === 'number') { normBias = [norm.bias, norm.bias, norm.bias, norm.bias]; } else { normBias = [norm.bias![0], norm.bias![1], norm.bias![2], norm.bias![3] ?? 0]; @@ -48,13 +64,20 @@ export const bufferToTensor = (buffer: Uint8ClampedArray|undefined, options: Buf // default value is RGBA since imagedata and HTMLImageElement uses it const outputformat = - options.tensorFormat !== undefined ? (options.tensorFormat !== undefined ? options.tensorFormat : 'RGB') : 'RGB'; + options.tensorFormat !== undefined ? (options.tensorFormat !== undefined ? options.tensorFormat : 'RGB') : 'RGB'; const stride = height * width; const float32Data = outputformat === 'RGBA' ? 
new Float32Array(stride * 4) : new Float32Array(stride * 3); // Default pointer assignments - let step = 4, rImagePointer = 0, gImagePointer = 1, bImagePointer = 2, aImagePointer = 3; - let rTensorPointer = 0, gTensorPointer = stride, bTensorPointer = stride * 2, aTensorPointer = -1; + let step = 4, + rImagePointer = 0, + gImagePointer = 1, + bImagePointer = 2, + aImagePointer = 3; + let rTensorPointer = 0, + gTensorPointer = stride, + bTensorPointer = stride * 2, + aTensorPointer = -1; // Updating the pointer assignments based on the input image format if (inputformat === 'RGB') { @@ -78,8 +101,11 @@ export const bufferToTensor = (buffer: Uint8ClampedArray|undefined, options: Buf rTensorPointer = stride * 2; } - for (let i = 0; i < stride; - i++, rImagePointer += step, bImagePointer += step, gImagePointer += step, aImagePointer += step) { + for ( + let i = 0; + i < stride; + i++, rImagePointer += step, bImagePointer += step, gImagePointer += step, aImagePointer += step + ) { float32Data[rTensorPointer++] = (buffer[rImagePointer] + normBias[0]) / normMean[0]; float32Data[gTensorPointer++] = (buffer[gImagePointer] + normBias[1]) / normMean[1]; float32Data[bTensorPointer++] = (buffer[bImagePointer] + normBias[2]) / normMean[2]; @@ -89,25 +115,31 @@ export const bufferToTensor = (buffer: Uint8ClampedArray|undefined, options: Buf } // Float32Array -> ort.Tensor - const outputTensor = outputformat === 'RGBA' ? new Tensor('float32', float32Data, [1, 4, height, width]) : - new Tensor('float32', float32Data, [1, 3, height, width]); + const outputTensor = + outputformat === 'RGBA' + ? new Tensor('float32', float32Data, [1, 4, height, width]) + : new Tensor('float32', float32Data, [1, 3, height, width]); return outputTensor; }; /** * implementation of Tensor.fromImage(). */ -export const tensorFromImage = async( - image: ImageData|HTMLImageElement|ImageBitmap|string, - options?: TensorFromImageDataOptions|TensorFromImageElementOptions|TensorFromImageBitmapOptions| - TensorFromUrlOptions): Promise => { +export const tensorFromImage = async ( + image: ImageData | HTMLImageElement | ImageBitmap | string, + options?: + | TensorFromImageDataOptions + | TensorFromImageElementOptions + | TensorFromImageBitmapOptions + | TensorFromUrlOptions, +): Promise => { // checking the type of image object - const isHTMLImageEle = typeof (HTMLImageElement) !== 'undefined' && image instanceof HTMLImageElement; - const isImageDataEle = typeof (ImageData) !== 'undefined' && image instanceof ImageData; - const isImageBitmap = typeof (ImageBitmap) !== 'undefined' && image instanceof ImageBitmap; + const isHTMLImageEle = typeof HTMLImageElement !== 'undefined' && image instanceof HTMLImageElement; + const isImageDataEle = typeof ImageData !== 'undefined' && image instanceof ImageData; + const isImageBitmap = typeof ImageBitmap !== 'undefined' && image instanceof ImageBitmap; const isString = typeof image === 'string'; - let data: Uint8ClampedArray|undefined; + let data: Uint8ClampedArray | undefined; let bufferToTensorOptions: BufferToTensorOptions = options ?? 
{}; const createCanvas = () => { @@ -119,7 +151,7 @@ export const tensorFromImage = async( throw new Error('Canvas is not supported'); } }; - const createCanvasContext = (canvas: HTMLCanvasElement|OffscreenCanvas) => { + const createCanvasContext = (canvas: HTMLCanvasElement | OffscreenCanvas) => { if (canvas instanceof HTMLCanvasElement) { return canvas.getContext('2d'); } else if (canvas instanceof OffscreenCanvas) { @@ -258,25 +290,31 @@ export const tensorFromImage = async( * implementation of Tensor.fromTexture(). */ export const tensorFromTexture = ( - texture: TensorInterface.TextureType, options: TensorFromTextureOptions): Tensor => { - const {width, height, download, dispose} = options; + texture: TensorInterface.TextureType, + options: TensorFromTextureOptions, +): Tensor => { + const { width, height, download, dispose } = options; // Always assume RGBAF32. TODO: support different texture format const dims = [1, height, width, 4]; - return new Tensor({location: 'texture', type: 'float32', texture, dims, download, dispose}); + return new Tensor({ location: 'texture', type: 'float32', texture, dims, download, dispose }); }; /** * implementation of Tensor.fromGpuBuffer(). */ export const tensorFromGpuBuffer = ( - gpuBuffer: TensorInterface.GpuBufferType, options: TensorFromGpuBufferOptions): Tensor => { - const {dataType, dims, download, dispose} = options; - return new Tensor({location: 'gpu-buffer', type: dataType ?? 'float32', gpuBuffer, dims, download, dispose}); + gpuBuffer: TensorInterface.GpuBufferType, + options: TensorFromGpuBufferOptions, +): Tensor => { + const { dataType, dims, download, dispose } = options; + return new Tensor({ location: 'gpu-buffer', type: dataType ?? 'float32', gpuBuffer, dims, download, dispose }); }; /** * implementation of Tensor.fromPinnedBuffer(). */ export const tensorFromPinnedBuffer = ( - type: T, buffer: TensorInterface.DataTypeMap[T], dims?: readonly number[]): Tensor => - new Tensor({location: 'cpu-pinned', type, data: buffer, dims: dims ?? [buffer.length]}); + type: T, + buffer: TensorInterface.DataTypeMap[T], + dims?: readonly number[], +): Tensor => new Tensor({ location: 'cpu-pinned', type, data: buffer, dims: dims ?? [buffer.length] }); diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 431de4c3635c2..7938b4a4eb927 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor, TypedTensor} from './tensor.js'; +import { Tensor, TypedTensor } from './tensor.js'; -export type ImageFormat = 'RGB'|'RGBA'|'BGR'|'RBG'; -export type ImageTensorLayout = 'NHWC'|'NCHW'; +export type ImageFormat = 'RGB' | 'RGBA' | 'BGR' | 'RBG'; +export type ImageTensorLayout = 'NHWC' | 'NCHW'; // the following region contains type definitions for constructing tensor from a specific location. @@ -42,8 +42,8 @@ interface GpuResourceConstructorParameters { /** * represent the parameter for constructing a tensor from a pinned CPU buffer */ -export interface CpuPinnedConstructorParameters extends - CommonConstructorParameters { +export interface CpuPinnedConstructorParameters + extends CommonConstructorParameters { /** * Specify the location of the data to be 'cpu-pinned'. 
*/ @@ -57,8 +57,9 @@ export interface CpuPinnedConstructorParameters extends - CommonConstructorParameters, GpuResourceConstructorParameters { +export interface TextureConstructorParameters + extends CommonConstructorParameters, + GpuResourceConstructorParameters { /** * Specify the location of the data to be 'texture'. */ @@ -72,8 +73,9 @@ export interface TextureConstructorParameters extends - CommonConstructorParameters, GpuResourceConstructorParameters { +export interface GpuBufferConstructorParameters + extends CommonConstructorParameters, + GpuResourceConstructorParameters { /** * Specify the location of the data to be 'gpu-buffer'. */ @@ -112,7 +114,7 @@ export interface OptionsTensorDataType { /** * Describes the data type of the tensor. */ - dataType?: 'float32'|'uint8'; + dataType?: 'float32' | 'uint8'; } export interface OptionsTensorLayout { @@ -158,7 +160,7 @@ export interface OptionsNormalizationParameters { * - If it's an array of 3 or 4 numbers, apply element-wise. Number of elements need to match the number of channels * for the corresponding image format */ - bias?: number|[number, number, number]|[number, number, number, number]; + bias?: number | [number, number, number] | [number, number, number, number]; /** * The 'mean' value for image normalization. * - If omitted, use default value 255. @@ -174,25 +176,43 @@ export interface OptionsNormalizationParameters { // #region Options composition -export interface TensorFromImageDataOptions extends OptionResizedDimensions, OptionsTensorFormat, OptionsTensorLayout, - OptionsTensorDataType, OptionsNormalizationParameters {} - -export interface TensorFromImageElementOptions extends OptionResizedDimensions, OptionsTensorFormat, - OptionsTensorLayout, OptionsTensorDataType, - OptionsNormalizationParameters {} - -export interface TensorFromUrlOptions extends OptionsDimensions, OptionResizedDimensions, OptionsTensorFormat, - OptionsTensorLayout, OptionsTensorDataType, - OptionsNormalizationParameters {} - -export interface TensorFromImageBitmapOptions extends OptionResizedDimensions, OptionsTensorFormat, OptionsTensorLayout, - OptionsTensorDataType, OptionsNormalizationParameters {} - -export interface TensorFromTextureOptions extends - Required, OptionsFormat, GpuResourceConstructorParameters/* TODO: add more */ {} - -export interface TensorFromGpuBufferOptions extends - Pick, GpuResourceConstructorParameters { +export interface TensorFromImageDataOptions + extends OptionResizedDimensions, + OptionsTensorFormat, + OptionsTensorLayout, + OptionsTensorDataType, + OptionsNormalizationParameters {} + +export interface TensorFromImageElementOptions + extends OptionResizedDimensions, + OptionsTensorFormat, + OptionsTensorLayout, + OptionsTensorDataType, + OptionsNormalizationParameters {} + +export interface TensorFromUrlOptions + extends OptionsDimensions, + OptionResizedDimensions, + OptionsTensorFormat, + OptionsTensorLayout, + OptionsTensorDataType, + OptionsNormalizationParameters {} + +export interface TensorFromImageBitmapOptions + extends OptionResizedDimensions, + OptionsTensorFormat, + OptionsTensorLayout, + OptionsTensorDataType, + OptionsNormalizationParameters {} + +export interface TensorFromTextureOptions + extends Required, + OptionsFormat, + GpuResourceConstructorParameters /* TODO: add more */ {} + +export interface TensorFromGpuBufferOptions + extends Pick, + GpuResourceConstructorParameters { /** * Describes the data type of the tensor. 
*/ @@ -218,8 +238,10 @@ export interface TensorFactory { * - `dataType`: `'float32'` * @returns A promise that resolves to a tensor object */ - fromImage(imageData: ImageData, options?: TensorFromImageDataOptions): - Promise|TypedTensor<'uint8'>>; + fromImage( + imageData: ImageData, + options?: TensorFromImageDataOptions, + ): Promise | TypedTensor<'uint8'>>; /** * create a tensor from a HTMLImageElement object @@ -233,8 +255,10 @@ export interface TensorFactory { * - `dataType`: `'float32'` * @returns A promise that resolves to a tensor object */ - fromImage(imageElement: HTMLImageElement, options?: TensorFromImageElementOptions): - Promise|TypedTensor<'uint8'>>; + fromImage( + imageElement: HTMLImageElement, + options?: TensorFromImageElementOptions, + ): Promise | TypedTensor<'uint8'>>; /** * create a tensor from URL @@ -248,7 +272,7 @@ export interface TensorFactory { * - `dataType`: `'float32'` * @returns A promise that resolves to a tensor object */ - fromImage(urlSource: string, options?: TensorFromUrlOptions): Promise|TypedTensor<'uint8'>>; + fromImage(urlSource: string, options?: TensorFromUrlOptions): Promise | TypedTensor<'uint8'>>; /** * create a tensor from an ImageBitmap object @@ -262,8 +286,10 @@ export interface TensorFactory { * - `dataType`: `'float32'` * @returns A promise that resolves to a tensor object */ - fromImage(bitmap: ImageBitmap, options: TensorFromImageBitmapOptions): - Promise|TypedTensor<'uint8'>>; + fromImage( + bitmap: ImageBitmap, + options: TensorFromImageBitmapOptions, + ): Promise | TypedTensor<'uint8'>>; /** * create a tensor from a WebGL texture @@ -284,7 +310,9 @@ export interface TensorFactory { * @returns a tensor object */ fromTexture( - texture: Tensor.TextureType, options: TensorFromTextureOptions): TypedTensor<'float32'>; + texture: Tensor.TextureType, + options: TensorFromTextureOptions, + ): TypedTensor<'float32'>; /** * create a tensor from a WebGPU buffer @@ -304,7 +332,9 @@ export interface TensorFactory { * @returns a tensor object */ fromGpuBuffer( - buffer: Tensor.GpuBufferType, options: TensorFromGpuBufferOptions): TypedTensor; + buffer: Tensor.GpuBufferType, + options: TensorFromGpuBufferOptions, + ): TypedTensor; /** * create a tensor from a pre-allocated buffer. The buffer will be used as a pinned buffer. @@ -316,5 +346,8 @@ export interface TensorFactory { * @returns a tensor object */ fromPinnedBuffer>( - type: T, buffer: Tensor.DataTypeMap[T], dims?: readonly number[]): TypedTensor; + type: T, + buffer: Tensor.DataTypeMap[T], + dims?: readonly number[], + ): TypedTensor; } diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index b29cb8cbd6d35..8e68ba31348ca 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -1,11 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Tensor} from './tensor.js';
+import { Tensor } from './tensor.js';

-export type SupportedTypedArrayConstructors = Float32ArrayConstructor|Uint8ArrayConstructor|Int8ArrayConstructor|
-    Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|Uint8ArrayConstructor|
-    Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor;
+export type SupportedTypedArrayConstructors =
+  | Float32ArrayConstructor
+  | Uint8ArrayConstructor
+  | Int8ArrayConstructor
+  | Uint16ArrayConstructor
+  | Int16ArrayConstructor
+  | Int32ArrayConstructor
+  | BigInt64ArrayConstructor
+  | Uint8ArrayConstructor
+  | Float64ArrayConstructor
+  | Uint32ArrayConstructor
+  | BigUint64ArrayConstructor;
 export type SupportedTypedArray = InstanceType<SupportedTypedArrayConstructors>;

 // a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.

diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts
index 56682ef98e117..cb2e467fead8c 100644
--- a/js/common/lib/tensor-impl.ts
+++ b/js/common/lib/tensor-impl.ts
@@ -1,13 +1,34 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
-import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
-import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
-import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
-import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
-import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
-import {Tensor as TensorInterface} from './tensor.js';
+import { tensorToDataURL, tensorToImageData } from './tensor-conversion-impl.js';
+import { TensorToDataUrlOptions, TensorToImageDataOptions } from './tensor-conversion.js';
+import {
+  tensorFromGpuBuffer,
+  tensorFromImage,
+  tensorFromPinnedBuffer,
+  tensorFromTexture,
+} from './tensor-factory-impl.js';
+import {
+  CpuPinnedConstructorParameters,
+  GpuBufferConstructorParameters,
+  TensorFromGpuBufferOptions,
+  TensorFromImageBitmapOptions,
+  TensorFromImageDataOptions,
+  TensorFromImageElementOptions,
+  TensorFromTextureOptions,
+  TensorFromUrlOptions,
+  TextureConstructorParameters,
+} from './tensor-factory.js';
+import {
+  checkTypedArray,
+  NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP,
+  NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP,
+  SupportedTypedArray,
+  SupportedTypedArrayConstructors,
+} from './tensor-impl-type-mapping.js';
+import { calculateSize, tensorReshape } from './tensor-utils-impl.js';
+import { Tensor as TensorInterface } from './tensor.js';

 // type aliases for those exported from Tensor interface

@@ -29,12 +50,14 @@ export class Tensor implements TensorInterface {
   /**
    * Construct a new CPU tensor object from the given type, data and dims.
*/ constructor( - type: TensorType, data: TensorDataType|readonly string[]|readonly number[]|readonly boolean[], - dims?: readonly number[]); + type: TensorType, + data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[], + dims?: readonly number[], + ); /** * Construct a new CPU tensor object from the given data and dims. Type is inferred from data. */ - constructor(data: TensorDataType|readonly string[]|readonly boolean[], dims?: readonly number[]); + constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]); /** * Construct a new tensor object from the pinned CPU data with the given type and dims. * @@ -64,9 +87,17 @@ export class Tensor implements TensorInterface { * implementation. */ constructor( - arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| - TextureConstructorParameters|GpuBufferConstructorParameters, - arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { + arg0: + | TensorType + | TensorDataType + | readonly string[] + | readonly boolean[] + | CpuPinnedConstructorParameters + | TextureConstructorParameters + | GpuBufferConstructorParameters, + arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[], + arg2?: readonly number[], + ) { // perform one-time check for BigInt/Float16Array support checkTypedArray(); @@ -102,8 +133,15 @@ export class Tensor implements TensorInterface { break; } case 'gpu-buffer': { - if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'uint8' && type !== 'bool')) { + if ( + type !== 'float32' && + type !== 'float16' && + type !== 'int32' && + type !== 'int64' && + type !== 'uint32' && + type !== 'uint8' && + type !== 'bool' + ) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; @@ -119,7 +157,7 @@ export class Tensor implements TensorInterface { // constructing tensor of location 'cpu' // let data: TensorDataType; - let maybeDims: typeof arg1|typeof arg2; + let maybeDims: typeof arg1 | typeof arg2; // check whether arg0 is type or data if (typeof arg0 === 'string') { // @@ -130,7 +168,7 @@ export class Tensor implements TensorInterface { if (arg0 === 'string') { // string tensor if (!Array.isArray(arg1)) { - throw new TypeError('A string tensor\'s data must be a string array.'); + throw new TypeError("A string tensor's data must be a string array."); } // we don't check whether every element in the array is string; this is too slow. we assume it's correct and // error will be populated at inference @@ -149,7 +187,8 @@ export class Tensor implements TensorInterface { // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. throw new TypeError( - 'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.'); + 'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.', + ); } else if (arg0 === 'uint64' || arg0 === 'int64') { // use 'as any' here because: // 1. TypeScript's check on type of 'Array.isArray()' does not work with readonly arrays. 
@@ -199,8 +238,9 @@ export class Tensor implements TensorInterface { } } else { // get tensor type from TypedArray - const mappedType = - NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(arg0.constructor as SupportedTypedArrayConstructors); + const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get( + arg0.constructor as SupportedTypedArrayConstructors, + ); if (mappedType === undefined) { throw new TypeError(`Unsupported type for tensor data: ${arg0.constructor}.`); } @@ -214,7 +254,7 @@ export class Tensor implements TensorInterface { // assume 1-D tensor if dims omitted maybeDims = [data.length]; } else if (!Array.isArray(maybeDims)) { - throw new TypeError('A tensor\'s dims must be a number array'); + throw new TypeError("A tensor's dims must be a number array"); } dims = maybeDims as readonly number[]; @@ -237,24 +277,35 @@ export class Tensor implements TensorInterface { // #region factory static async fromImage( - image: ImageData|HTMLImageElement|ImageBitmap|string, - options?: TensorFromImageDataOptions|TensorFromImageElementOptions|TensorFromImageBitmapOptions| - TensorFromUrlOptions): Promise { + image: ImageData | HTMLImageElement | ImageBitmap | string, + options?: + | TensorFromImageDataOptions + | TensorFromImageElementOptions + | TensorFromImageBitmapOptions + | TensorFromUrlOptions, + ): Promise { return tensorFromImage(image, options); } static fromTexture( - texture: TensorTextureType, options: TensorFromTextureOptions): TensorInterface { + texture: TensorTextureType, + options: TensorFromTextureOptions, + ): TensorInterface { return tensorFromTexture(texture, options); } static fromGpuBuffer( - gpuBuffer: TensorGpuBufferType, options: TensorFromGpuBufferOptions): TensorInterface { + gpuBuffer: TensorGpuBufferType, + options: TensorFromGpuBufferOptions, + ): TensorInterface { return tensorFromGpuBuffer(gpuBuffer, options); } static fromPinnedBuffer( - type: T, buffer: TensorInterface.DataTypeMap[T], dims?: readonly number[]): Tensor { + type: T, + buffer: TensorInterface.DataTypeMap[T], + dims?: readonly number[], + ): Tensor { return tensorFromPinnedBuffer(type, buffer, dims); } @@ -319,8 +370,9 @@ export class Tensor implements TensorInterface { this.ensureValid(); if (!this.cpuData) { throw new Error( - 'The data is not on CPU. Use `getData()` to download GPU data to CPU, ' + - 'or use `texture` or `gpuBuffer` property to access the GPU data directly.'); + 'The data is not on CPU. Use `getData()` to download GPU data to CPU, ' + + 'or use `texture` or `gpuBuffer` property to access the GPU data directly.', + ); } return this.cpuData; } @@ -375,7 +427,6 @@ export class Tensor implements TensorInterface { } return data; - } finally { this.isDownloading = false; } diff --git a/js/common/lib/tensor-utils-impl.ts b/js/common/lib/tensor-utils-impl.ts index bd3080b724651..9c633cd95fac3 100644 --- a/js/common/lib/tensor-utils-impl.ts +++ b/js/common/lib/tensor-utils-impl.ts @@ -1,8 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TextureConstructorParameters} from './tensor-factory.js'; -import {Tensor} from './tensor-impl.js'; +import { + CpuPinnedConstructorParameters, + GpuBufferConstructorParameters, + TextureConstructorParameters, +} from './tensor-factory.js'; +import { Tensor } from './tensor-impl.js'; /** * calculate size from dims. 
diff --git a/js/common/lib/tensor-utils.ts b/js/common/lib/tensor-utils.ts index b24075aad2953..a732560adb6ae 100644 --- a/js/common/lib/tensor-utils.ts +++ b/js/common/lib/tensor-utils.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {ConversionUtils} from './tensor-conversion.js'; -import {Tensor, TypedTensor} from './tensor.js'; +import { ConversionUtils } from './tensor-conversion.js'; +import { Tensor, TypedTensor } from './tensor.js'; interface Properties { /** diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 20319ebb800c2..6b4165a222791 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorFactory} from './tensor-factory.js'; -import {Tensor as TensorImpl} from './tensor-impl.js'; -import {TypedTensorUtils} from './tensor-utils.js'; +import { TensorFactory } from './tensor-factory.js'; +import { Tensor as TensorImpl } from './tensor-impl.js'; +import { TypedTensorUtils } from './tensor-utils.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -74,7 +74,7 @@ export declare namespace Tensor { int64: BigInt64Array; string: string[]; bool: Uint8Array; - float16: Uint16Array; // Keep using Uint16Array until we have a concrete solution for float 16. + float16: Uint16Array; // Keep using Uint16Array until we have a concrete solution for float 16. float64: Float64Array; uint32: Uint32Array; uint64: BigUint64Array; @@ -93,7 +93,7 @@ export declare namespace Tensor { int64: bigint; string: string; bool: boolean; - float16: number; // Keep using Uint16Array until we have a concrete solution for float 16. + float16: number; // Keep using Uint16Array until we have a concrete solution for float 16. float64: number; uint32: number; uint64: bigint; @@ -130,17 +130,17 @@ export declare namespace Tensor { * * for more info see https://github.com/gpuweb/types/issues/127 */ - export type GpuBufferType = {size: number; mapState: 'unmapped' | 'pending' | 'mapped'}; + export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; + export type GpuBufferDataTypes = 'float32' | 'float16' | 'int32' | 'int64' | 'uint32' | 'uint8' | 'bool'; /** * represent where the tensor data is stored */ - export type DataLocation = 'none'|'cpu'|'cpu-pinned'|'texture'|'gpu-buffer'; + export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer'; /** * represent the data type of a tensor @@ -169,8 +169,11 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(type: 'string', data: Tensor.DataTypeMap['string']|readonly string[], - dims?: readonly number[]): TypedTensor<'string'>; + new ( + type: 'string', + data: Tensor.DataTypeMap['string'] | readonly string[], + dims?: readonly number[], + ): TypedTensor<'string'>; /** * Construct a new bool tensor object from the given type, data and dims. @@ -179,7 +182,11 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. 
*/ - new(type: 'bool', data: Tensor.DataTypeMap['bool']|readonly boolean[], dims?: readonly number[]): TypedTensor<'bool'>; + new ( + type: 'bool', + data: Tensor.DataTypeMap['bool'] | readonly boolean[], + dims?: readonly number[], + ): TypedTensor<'bool'>; /** * Construct a new 64-bit integer typed tensor object from the given type, data and dims. @@ -188,9 +195,11 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new( - type: T, data: Tensor.DataTypeMap[T]|readonly bigint[]|readonly number[], - dims?: readonly number[]): TypedTensor; + new ( + type: T, + data: Tensor.DataTypeMap[T] | readonly bigint[] | readonly number[], + dims?: readonly number[], + ): TypedTensor; /** * Construct a new numeric tensor object from the given type, data and dims. @@ -199,8 +208,11 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new>( - type: T, data: Tensor.DataTypeMap[T]|readonly number[], dims?: readonly number[]): TypedTensor; + new >( + type: T, + data: Tensor.DataTypeMap[T] | readonly number[], + dims?: readonly number[], + ): TypedTensor; // #endregion // #region CPU tensor - infer element types @@ -211,7 +223,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Float32Array, dims?: readonly number[]): TypedTensor<'float32'>; + new (data: Float32Array, dims?: readonly number[]): TypedTensor<'float32'>; /** * Construct a new int8 tensor object from the given data and dims. @@ -219,7 +231,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Int8Array, dims?: readonly number[]): TypedTensor<'int8'>; + new (data: Int8Array, dims?: readonly number[]): TypedTensor<'int8'>; /** * Construct a new uint8 tensor object from the given data and dims. @@ -227,7 +239,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>; + new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>; /** * Construct a new uint16 tensor object from the given data and dims. @@ -235,7 +247,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Uint16Array, dims?: readonly number[]): TypedTensor<'uint16'>; + new (data: Uint16Array, dims?: readonly number[]): TypedTensor<'uint16'>; /** * Construct a new int16 tensor object from the given data and dims. @@ -243,7 +255,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. 
*/ - new(data: Int16Array, dims?: readonly number[]): TypedTensor<'int16'>; + new (data: Int16Array, dims?: readonly number[]): TypedTensor<'int16'>; /** * Construct a new int32 tensor object from the given data and dims. @@ -251,7 +263,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Int32Array, dims?: readonly number[]): TypedTensor<'int32'>; + new (data: Int32Array, dims?: readonly number[]): TypedTensor<'int32'>; /** * Construct a new int64 tensor object from the given data and dims. @@ -259,7 +271,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: BigInt64Array, dims?: readonly number[]): TypedTensor<'int64'>; + new (data: BigInt64Array, dims?: readonly number[]): TypedTensor<'int64'>; /** * Construct a new string tensor object from the given data and dims. @@ -267,7 +279,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: readonly string[], dims?: readonly number[]): TypedTensor<'string'>; + new (data: readonly string[], dims?: readonly number[]): TypedTensor<'string'>; /** * Construct a new bool tensor object from the given data and dims. @@ -275,7 +287,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: readonly boolean[], dims?: readonly number[]): TypedTensor<'bool'>; + new (data: readonly boolean[], dims?: readonly number[]): TypedTensor<'bool'>; /** * Construct a new float64 tensor object from the given data and dims. @@ -283,7 +295,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Float64Array, dims?: readonly number[]): TypedTensor<'float64'>; + new (data: Float64Array, dims?: readonly number[]): TypedTensor<'float64'>; /** * Construct a new uint32 tensor object from the given data and dims. @@ -291,7 +303,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Uint32Array, dims?: readonly number[]): TypedTensor<'uint32'>; + new (data: Uint32Array, dims?: readonly number[]): TypedTensor<'uint32'>; /** * Construct a new uint64 tensor object from the given data and dims. @@ -299,7 +311,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: BigUint64Array, dims?: readonly number[]): TypedTensor<'uint64'>; + new (data: BigUint64Array, dims?: readonly number[]): TypedTensor<'uint64'>; // #endregion @@ -312,8 +324,11 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. 
*/ - new(type: Tensor.Type, data: Tensor.DataType|readonly number[]|readonly string[]|readonly bigint[]|readonly boolean[], - dims?: readonly number[]): Tensor; + new ( + type: Tensor.Type, + data: Tensor.DataType | readonly number[] | readonly string[] | readonly bigint[] | readonly boolean[], + dims?: readonly number[], + ): Tensor; /** * Construct a new tensor object from the given data and dims. @@ -321,7 +336,7 @@ export interface TensorConstructor extends TensorFactory { * @param data - Specify the CPU tensor data. * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. */ - new(data: Tensor.DataType, dims?: readonly number[]): Tensor; + new (data: Tensor.DataType, dims?: readonly number[]): Tensor; // #endregion } diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts index 44ad6cacb4bb4..25d178f15a29d 100644 --- a/js/common/lib/trace.ts +++ b/js/common/lib/trace.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env} from './env-impl.js'; +import { env } from './env-impl.js'; /** * @ignore diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index bae38b0dfda5a..21dbe5fe51bb9 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {resolveBackendAndExecutionProviders} from './backend-impl.js'; -import {SessionHandler, TrainingSessionHandler} from './backend.js'; -import {InferenceSession as InferenceSession} from './inference-session.js'; -import {OnnxValue} from './onnx-value.js'; -import {Tensor} from './tensor.js'; -import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; +import { resolveBackendAndExecutionProviders } from './backend-impl.js'; +import { SessionHandler, TrainingSessionHandler } from './backend.js'; +import { InferenceSession as InferenceSession } from './inference-session.js'; +import { OnnxValue } from './onnx-value.js'; +import { Tensor } from './tensor.js'; +import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = InferenceSession.FeedsType; @@ -14,8 +14,8 @@ type FetchesType = InferenceSession.FetchesType; type ReturnType = InferenceSession.ReturnType; type RunOptions = InferenceSession.RunOptions; -const noBackendErrMsg: string = 'Training backend could not be resolved. ' + - 'Make sure you\'re using the correct configuration & WebAssembly files.'; +const noBackendErrMsg: string = + 'Training backend could not be resolved. 
' + "Make sure you're using the correct configuration & WebAssembly files."; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) { @@ -49,18 +49,24 @@ export class TrainingSession implements TrainingSessionInterface { } } - static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): - Promise { - const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; - const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + static async create( + trainingOptions: TrainingSessionCreateOptions, + sessionOptions?: SessionOptions, + ): Promise { + const evalModel: string | Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || ''; const options: SessionOptions = sessionOptions || {}; // resolve backend, update session options with validated EPs, and create session handler const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); if (backend.createTrainingSessionHandler) { const handler = await backend.createTrainingSessionHandler( - trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, - optionsWithValidatedEPs); + trainingOptions.checkpointState, + trainingOptions.trainModel, + evalModel, + optimizerModel, + optionsWithValidatedEPs, + ); return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); } else { throw new Error(noBackendErrMsg); @@ -81,14 +87,19 @@ export class TrainingSession implements TrainingSessionInterface { * @returns */ typeNarrowingForRunStep( - inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions, - arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] { - const fetches: {[name: string]: OnnxValue|null} = {}; + inputNames: readonly string[], + outputNames: readonly string[], + feeds: FeedsType, + arg1?: FetchesType | RunOptions, + arg2?: RunOptions, + ): [SessionHandler.FetchesType, RunOptions] { + const fetches: { [name: string]: OnnxValue | null } = {}; let options: RunOptions = {}; // check inputs if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { throw new TypeError( - '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.'); + "'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.", + ); } let isFetchesEmpty = true; @@ -98,18 +109,18 @@ export class TrainingSession implements TrainingSessionInterface { throw new TypeError('Unexpected argument[1]: cannot be null.'); } if (arg1 instanceof Tensor) { - throw new TypeError('\'fetches\' cannot be a Tensor'); + throw new TypeError("'fetches' cannot be a Tensor"); } if (Array.isArray(arg1)) { if (arg1.length === 0) { - throw new TypeError('\'fetches\' cannot be an empty array.'); + throw new TypeError("'fetches' cannot be an empty array."); } isFetchesEmpty = false; // output names for (const name of arg1) { if (typeof name !== 'string') { - throw new TypeError('\'fetches\' must be a string array or an object.'); + throw new TypeError("'fetches' must be a string array or an object."); } if (outputNames.indexOf(name) === -1) { throw new RangeError(`'fetches' contains invalid output name: ${name}.`); @@ -120,7 +131,7 @@ export class TrainingSession implements 
TrainingSessionInterface { if (typeof arg2 === 'object' && arg2 !== null) { options = arg2; } else if (typeof arg2 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } } else { // decide whether arg1 is fetches or options @@ -142,14 +153,14 @@ export class TrainingSession implements TrainingSessionInterface { if (typeof arg2 === 'object' && arg2 !== null) { options = arg2; } else if (typeof arg2 !== 'undefined') { - throw new TypeError('\'options\' must be an object.'); + throw new TypeError("'options' must be an object."); } } else { options = arg1 as RunOptions; } } } else if (typeof arg1 !== 'undefined') { - throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.'); + throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'."); } // check if all inputs are in feed @@ -177,7 +188,7 @@ export class TrainingSession implements TrainingSessionInterface { * @returns */ convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { - const returnValue: {[name: string]: OnnxValue} = {}; + const returnValue: { [name: string]: OnnxValue } = {}; for (const key in results) { if (Object.hasOwnProperty.call(results, key)) { const result = results[key]; @@ -197,14 +208,19 @@ export class TrainingSession implements TrainingSessionInterface { runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; - async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { - const [fetches, options] = - this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2); + async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { + const [fetches, options] = this.typeNarrowingForRunStep( + this.trainingInputNames, + this.trainingOutputNames, + feeds, + arg1, + arg2, + ); const results = await this.handler.runTrainStep(feeds, fetches, options); return this.convertHandlerReturnTypeToMapOfTensors(results); } - async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise { + async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise { if (this.hasOptimizerModel) { await this.handler.runOptimizerStep(options || {}); } else { @@ -212,12 +228,17 @@ export class TrainingSession implements TrainingSessionInterface { } } - runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise; - runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise; - async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise; + runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise; + async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { if (this.hasEvalModel) { - const [fetches, options] = - this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2); + const [fetches, options] = this.typeNarrowingForRunStep( + this.evalInputNames, + this.evalOutputNames, + feeds, + arg1, + arg2, + ); const results = await this.handler.runEvalStep(feeds, fetches, options); return this.convertHandlerReturnTypeToMapOfTensors(results); } else { @@ -235,8 +256,9 @@ export class TrainingSession implements 
TrainingSessionInterface { // of parameters if (array.length !== 4 * paramsSize) { throw new Error( - 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + - 'the model. Please use getParametersSize method to check.'); + 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + + 'the model. Please use getParametersSize method to check.', + ); } return this.handler.loadParametersBuffer(array, trainableOnly); } diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index f9de77e3ac7d0..45dcafc46deb5 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession} from './inference-session.js'; -import {OnnxValue} from './onnx-value.js'; -import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; +import { InferenceSession } from './inference-session.js'; +import { OnnxValue } from './onnx-value.js'; +import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -11,7 +11,7 @@ export declare namespace TrainingSession { /** * Either URI file path (string) or Uint8Array containing model or checkpoint information. */ - type UriOrBuffer = string|Uint8Array; + type UriOrBuffer = string | Uint8Array; } /** @@ -36,8 +36,10 @@ export interface TrainingSession { * @param options - Optional. A set of options that controls the behavior of model training. * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. */ - runTrainStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): - Promise; + runTrainStep( + feeds: InferenceSession.FeedsType, + options?: InferenceSession.RunOptions, + ): Promise; /** * Run a single train step with the given inputs and options. @@ -50,8 +52,10 @@ export interface TrainingSession { values. */ runTrainStep( - feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions): Promise; + feeds: InferenceSession.FeedsType, + fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions, + ): Promise; /** * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model. @@ -68,8 +72,10 @@ export interface TrainingSession { * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. */ - runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): - Promise; + runEvalStep( + feeds: InferenceSession.FeedsType, + options?: InferenceSession.RunOptions, + ): Promise; /** * Run a single eval step with the given inputs and options using the eval model. @@ -82,8 +88,10 @@ export interface TrainingSession { values. 
*/ runEvalStep( - feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions): Promise; + feeds: InferenceSession.FeedsType, + fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions, + ): Promise; // #endregion @@ -186,8 +194,10 @@ export interface TrainingSessionFactory { * * @returns Promise that resolves to a TrainingSession object */ - create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: InferenceSession.SessionOptions): - Promise; + create( + trainingOptions: TrainingSessionCreateOptions, + sessionOptions?: InferenceSession.SessionOptions, + ): Promise; // #endregion } diff --git a/js/common/test/type-tests.ts b/js/common/test/type-tests.ts index afa53a694514d..70681bb420e5f 100644 --- a/js/common/test/type-tests.ts +++ b/js/common/test/type-tests.ts @@ -3,9 +3,9 @@ import globby from 'globby'; import assert from 'node:assert'; -import {readFileSync} from 'node:fs'; -import {dirname, join, normalize, relative} from 'node:path'; -import {fileURLToPath} from 'node:url'; +import { readFileSync } from 'node:fs'; +import { dirname, join, normalize, relative } from 'node:path'; +import { fileURLToPath } from 'node:url'; import npmlog from 'npmlog'; import typescript from 'typescript'; @@ -46,20 +46,19 @@ const TYPE_TESTS_DIR = join(dirname(fileURLToPath(import.meta.url)), './type-tes * @returns list of test files */ const prepareTestFileList = () => - // - globby.sync('**/*.ts', { - cwd: TYPE_TESTS_DIR, - absolute: true, - }); + // + globby.sync('**/*.ts', { + cwd: TYPE_TESTS_DIR, + absolute: true, + }); /** * Run typescript compiler on the given files. */ const compileTypeScriptFiles = (filepaths: string[]): readonly typescript.Diagnostic[] => { // TypeScript compiler options, base URL is reset to `TYPE_TESTS_DIR`. - const compilerOptions = - JSON.parse(readFileSync(new URL('./type-tests/tsconfig.json', import.meta.url), 'utf-8')).compilerOptions as - typescript.CompilerOptions; + const compilerOptions = JSON.parse(readFileSync(new URL('./type-tests/tsconfig.json', import.meta.url), 'utf-8')) + .compilerOptions as typescript.CompilerOptions; compilerOptions.baseUrl = TYPE_TESTS_DIR; // Run TypeScript compiler @@ -81,39 +80,40 @@ const prepareTestCases = () => { npmlog.info('PrepareTestCases', `Preparing test file lists... DONE, ${testFiles.length} file(s) in total.`); npmlog.info('PrepareTestCases', 'Running TypeScript Compiler...'); - const compileResult = compileTypeScriptFiles(testFiles).map( - diagnostic => ({ - fileName: normalize(diagnostic.file?.fileName ?? ''), - line: diagnostic.file?.getLineAndCharacterOfPosition(diagnostic.start!)?.line ?? -1, - code: diagnostic.code, - })); + const compileResult = compileTypeScriptFiles(testFiles).map((diagnostic) => ({ + fileName: normalize(diagnostic.file?.fileName ?? ''), + line: diagnostic.file?.getLineAndCharacterOfPosition(diagnostic.start!)?.line ?? -1, + code: diagnostic.code, + })); npmlog.info('PrepareTestCases', 'Running TypeScript Compiler... 
DONE.'); npmlog.info('PrepareTestCases', 'Parsing test source files for expected failures...'); - const testCases = testFiles.map(filepath => { + const testCases = testFiles.map((filepath) => { const normalizedFilePath = normalize(filepath); const normalizedRelativePath = normalize(relative(TYPE_TESTS_DIR, filepath)); - const fileAllLines = readFileSync(filepath, 'utf-8').split('\n').map(line => line.trim()); - const expectedFailures: Array<{line: number; code: number}> = []; + const fileAllLines = readFileSync(filepath, 'utf-8') + .split('\n') + .map((line) => line.trim()); + const expectedFailures: Array<{ line: number; code: number }> = []; fileAllLines.forEach((line, i) => { if (line.startsWith('// {type-tests}|fail|')) { const splitted = line.split('|'); assert(splitted.length === 4, `invalid expected failure comment: ${line}`); const lineOffset = Number.parseInt(splitted[2], 10); const code = Number.parseInt(splitted[3], 10); - expectedFailures.push({line: i + lineOffset, code}); + expectedFailures.push({ line: i + lineOffset, code }); } }); const actualFailures: typeof compileResult = []; - return {filepath: normalizedFilePath, relativePath: normalizedRelativePath, expectedFailures, actualFailures}; + return { filepath: normalizedFilePath, relativePath: normalizedRelativePath, expectedFailures, actualFailures }; }); npmlog.info('PrepareTestCases', 'Parsing test source files for expected failures... DONE.'); // now check if file names is matched - const filePathToTestCaseMap = new Map(testCases.map(testCase => [testCase.filepath, testCase])); + const filePathToTestCaseMap = new Map(testCases.map((testCase) => [testCase.filepath, testCase])); for (const error of compileResult) { // check file name exists assert(error.fileName, 'Each compile error should have a file name. Please check TypeScript compiler options.'); @@ -125,15 +125,15 @@ const prepareTestCases = () => { testCase.actualFailures.push(error); } - return testCases.map(testCase => { - const {relativePath, expectedFailures, actualFailures} = testCase; + return testCases.map((testCase) => { + const { relativePath, expectedFailures, actualFailures } = testCase; const testFunction = () => { if (expectedFailures.length === 0) { assert.equal(actualFailures.length, 0, `expected to pass but failed: ${JSON.stringify(actualFailures)}`); } else { - actualFailures.forEach(error => { - const {line, code} = error; - const foundIndex = expectedFailures.findIndex(f => f.line === line && f.code === code); + actualFailures.forEach((error) => { + const { line, code } = error; + const foundIndex = expectedFailures.findIndex((f) => f.line === line && f.code === code); assert.notEqual(foundIndex, -1, `unexpected failure: line=${line}, code=${code}`); expectedFailures.splice(foundIndex, 1); }); @@ -141,12 +141,12 @@ const prepareTestCases = () => { } }; - return {title: relativePath, testBody: testFunction}; + return { title: relativePath, testBody: testFunction }; }); }; describe('TypeScript type tests', () => { - for (const {title, testBody} of prepareTestCases()) { + for (const { title, testBody } of prepareTestCases()) { it(title, testBody); } }); diff --git a/js/common/test/type-tests/tensor/create-new-bool.ts b/js/common/test/type-tests/tensor/create-new-bool.ts index 8692af97bd07a..017fc1ca0d6f5 100644 --- a/js/common/test/type-tests/tensor/create-new-bool.ts +++ b/js/common/test/type-tests/tensor/create-new-bool.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Tensor, TypedTensor} from 'onnxruntime-common'; +import { Tensor, TypedTensor } from 'onnxruntime-common'; // construct from type, data (boolean array) and shape (number array) // diff --git a/js/common/test/type-tests/tensor/create-new-f32.ts b/js/common/test/type-tests/tensor/create-new-f32.ts index af24a3e8aaf3c..8e8b46deec0af 100644 --- a/js/common/test/type-tests/tensor/create-new-f32.ts +++ b/js/common/test/type-tests/tensor/create-new-f32.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor, TypedTensor} from 'onnxruntime-common'; +import { Tensor, TypedTensor } from 'onnxruntime-common'; // construct from type, data (number array) and shape (number array) // diff --git a/js/common/test/type-tests/tensor/create-new-string.ts b/js/common/test/type-tests/tensor/create-new-string.ts index d8c2870f7a879..71849cf9a4c12 100644 --- a/js/common/test/type-tests/tensor/create-new-string.ts +++ b/js/common/test/type-tests/tensor/create-new-string.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor, TypedTensor} from 'onnxruntime-common'; +import { Tensor, TypedTensor } from 'onnxruntime-common'; // construct from type, data (string array) and shape (number array) // diff --git a/js/common/test/unit-tests/common.ts b/js/common/test/unit-tests/common.ts index 49ebe872880a2..0a6e4e5dd6ebd 100644 --- a/js/common/test/unit-tests/common.ts +++ b/js/common/test/unit-tests/common.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import assert from 'assert/strict'; -import {Tensor} from 'onnxruntime-common'; +import { Tensor } from 'onnxruntime-common'; /** * A list of numerical types that are compatible with JavaScript 'number' value. @@ -26,10 +26,7 @@ export const NUMBER_COMPATIBLE_NUMERICAL_TYPES = [ /** * Big integer types */ -export const BIGINT_TYPES = [ - ['int64', BigInt64Array, true] as const, - ['uint64', BigUint64Array, true] as const, -]; +export const BIGINT_TYPES = [['int64', BigInt64Array, true] as const, ['uint64', BigUint64Array, true] as const]; /** * float16 type, data represented by Uint16Array @@ -46,7 +43,7 @@ export const ALL_NUMERICAL_TYPES = [...NUMBER_COMPATIBLE_NUMERICAL_TYPES, ...BIG /** * a helper function to assert that a value is an array of a certain type */ -export const assertIsArrayOf = (value: unknown, type: 'string'|'number'|'boolean'): void => { +export const assertIsArrayOf = (value: unknown, type: 'string' | 'number' | 'boolean'): void => { assert(Array.isArray(value), 'array should be an array'); for (let i = 0; i < value.length; i++) { assert.equal(typeof value[i], type, `array should be an array of ${type}s`); @@ -58,4 +55,4 @@ export const assertIsArrayOf = (value: unknown, type: 'string'|'number'|'boolean * * This allows to write test code to pass invalid parameters to Tensor constructor and check the behavior. */ -export const TensorAny = Tensor as unknown as {new (...args: unknown[]): Tensor}; +export const TensorAny = Tensor as unknown as { new (...args: unknown[]): Tensor }; diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index 891b457006ba8..def711684d7f5 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -2,9 +2,15 @@ // Licensed under the MIT License. 
import assert from 'assert/strict'; -import {Tensor} from 'onnxruntime-common'; +import { Tensor } from 'onnxruntime-common'; -import {ALL_NUMERICAL_TYPES, assertIsArrayOf, BIGINT_TYPES, NUMBER_COMPATIBLE_NUMERICAL_TYPES, TensorAny} from '../common.js'; +import { + ALL_NUMERICAL_TYPES, + assertIsArrayOf, + BIGINT_TYPES, + NUMBER_COMPATIBLE_NUMERICAL_TYPES, + TensorAny, +} from '../common.js'; describe('Tensor Constructor Tests - check types', () => { for (const [type, typedArrayConstructor, canBeInferredFromType] of ALL_NUMERICAL_TYPES) { @@ -16,8 +22,9 @@ describe('Tensor Constructor Tests - check types', () => { it(`[${type}] new Tensor(type, typedArray, dims): "tensor.data" should be instance of expected typed array`, () => { const tensor = new Tensor(type, new typedArrayConstructor(4), [2, 2]); assert( - tensor.data instanceof typedArrayConstructor, - `tensor.data should be an instance of '${typedArrayConstructor.name}'`); + tensor.data instanceof typedArrayConstructor, + `tensor.data should be an instance of '${typedArrayConstructor.name}'`, + ); }); if (canBeInferredFromType) { @@ -36,14 +43,14 @@ describe('Tensor Constructor Tests - check types', () => { }); } - for (const [type, ] of NUMBER_COMPATIBLE_NUMERICAL_TYPES) { + for (const [type] of NUMBER_COMPATIBLE_NUMERICAL_TYPES) { it(`[${type}] new Tensor(type, numbers, dims): tensor can be constructed from number array`, () => { const tensor = new Tensor(type, [1, 2, 3, 4], [2, 2]); assert.equal(tensor.type, type, `tensor.type should be '${type}'`); }); } - for (const [type, ] of BIGINT_TYPES) { + for (const [type] of BIGINT_TYPES) { it(`[${type}] new Tensor(type, numbers, dims): tensor can be constructed from number array`, () => { const tensor = new Tensor(type, [1, 2, 3, 4], [2, 2]); assert.equal(tensor.type, type, `tensor.type should be '${type}'`); @@ -57,12 +64,12 @@ describe('Tensor Constructor Tests - check types', () => { it('[string] new Tensor(\'string\', strings, dims): "tensor.type" should match type passed in', () => { const tensor = new Tensor('string', ['a', 'b', 'c', 'd'], [2, 2]); - assert.equal(tensor.type, 'string', 'tensor.type should be \'string\''); + assert.equal(tensor.type, 'string', "tensor.type should be 'string'"); }); it('[string] new Tensor(strings, dims): "tensor.data" should match inferred type', () => { const tensor = new Tensor(['a', 'b', 'c', 'd'], [2, 2]); - assert.equal(tensor.type, 'string', 'tensor.type should be \'string\''); + assert.equal(tensor.type, 'string', "tensor.type should be 'string'"); }); it('[string] new Tensor(\'string\', strings, dims): "tensor.data" should be a string array', () => { @@ -72,31 +79,33 @@ describe('Tensor Constructor Tests - check types', () => { it('[bool] new Tensor(\'bool\', booleans, dims): "tensor.type" should match type passed in', () => { const tensor = new Tensor('bool', [true, false, true, false], [2, 2]); - assert.equal(tensor.type, 'bool', 'tensor.type should be \'bool\''); + assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); }); - it('[bool] new Tensor(\'bool\', uint8Array, dims): tensor can be constructed from Uint8Array', () => { + it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => { const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]); - assert.equal(tensor.type, 'bool', 'tensor.type should be \'bool\''); + assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); }); it('[bool] new Tensor(booleans, dims): "tensor.data" should match inferred type', () => { 
const tensor = new Tensor([true, false, true, false], [2, 2]); - assert.equal(tensor.type, 'bool', 'tensor.type should be \'bool\''); + assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); }); it('[bool] new Tensor(\'bool\', booleans, dims): "tensor.data" should be a boolean array', () => { const tensor = new Tensor('bool', [true, false, true, false], [2, 2]); - assert(tensor.data instanceof Uint8Array, 'tensor.data should be an instance of \'Uint8Array\''); + assert(tensor.data instanceof Uint8Array, "tensor.data should be an instance of 'Uint8Array'"); }); - it('[float16] new Tensor(\'float16\', numbers, dims): ' + - 'expect to throw because it\'s not allowed to construct \'float16\' tensor from number array', - () => { - assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); - }); + it( + "[float16] new Tensor('float16', numbers, dims): " + + "expect to throw because it's not allowed to construct 'float16' tensor from number array", + () => { + assert.throws(() => new Tensor('float16', [1, 2, 3, 4], [2, 2]), TypeError); + }, + ); - it('[badtype] new Tensor(\'a\', numbers, dims): expect to throw because \'a\' is an invalid type', () => { + it("[badtype] new Tensor('a', numbers, dims): expect to throw because 'a' is an invalid type", () => { assert.throws(() => new TensorAny('a', [1, 2, 3, 4], [2, 2]), TypeError); }); }); diff --git a/js/common/webpack.config.js b/js/common/webpack.config.js index b9d1536f4b99c..03593e7850bca 100644 --- a/js/common/webpack.config.js +++ b/js/common/webpack.config.js @@ -4,16 +4,16 @@ 'use strict'; import webpack from 'webpack'; -import {resolve} from 'node:path'; -import {DEFAULT_ES_VERSION, addCopyrightBannerPlugin} from '../webpack.shared.mjs'; +import { resolve } from 'node:path'; +import { DEFAULT_ES_VERSION, addCopyrightBannerPlugin } from '../webpack.shared.mjs'; function buildConfig({ - suffix = '.js', // '.js', '.min.js', ... - format = 'umd', // 'umd', 'commonjs' - target = 'web', // 'web', 'node' - esVersion = DEFAULT_ES_VERSION, // 'es5', 'es6', ... - mode = 'production', // 'development', 'production' - devtool = 'source-map' // 'inline-source-map', 'source-map' + suffix = '.js', // '.js', '.min.js', ... + format = 'umd', // 'umd', 'commonjs' + target = 'web', // 'web', 'node' + esVersion = DEFAULT_ES_VERSION, // 'es5', 'es6', ... 
+ mode = 'production', // 'development', 'production' + devtool = 'source-map', // 'inline-source-map', 'source-map' }) { // output file name const filename = `ort-common${suffix}`; @@ -29,24 +29,28 @@ function buildConfig({ output: { path: resolve('./dist'), filename, - library: {name: exportName, type: format}, + library: { name: exportName, type: format }, }, resolve: { extensions: ['.ts', '.js'], - extensionAlias: {'.js': ['.ts', '.js']}, + extensionAlias: { '.js': ['.ts', '.js'] }, }, plugins: [ - new webpack.WatchIgnorePlugin({paths: [/\.js$/, /\.d\.ts$/]}), + new webpack.WatchIgnorePlugin({ paths: [/\.js$/, /\.d\.ts$/] }), addCopyrightBannerPlugin(mode, 'common', esVersion), ], module: { - rules: [{ - test: /\.ts$/, - use: [{ - loader: 'ts-loader', - options: {compilerOptions: {target: esVersion}}, - }] - }] + rules: [ + { + test: /\.ts$/, + use: [ + { + loader: 'ts-loader', + options: { compilerOptions: { target: esVersion } }, + }, + ], + }, + ], }, mode, devtool, @@ -55,9 +59,9 @@ function buildConfig({ export default (env, argv) => { return [ - buildConfig({suffix: '.es5.min.js', target: 'web', esVersion: 'es5'}), - buildConfig({suffix: '.min.js'}), - buildConfig({mode: 'development', devtool: 'inline-source-map'}), + buildConfig({ suffix: '.es5.min.js', target: 'web', esVersion: 'es5' }), + buildConfig({ suffix: '.min.js' }), + buildConfig({ mode: 'development', devtool: 'inline-source-map' }), buildConfig({ suffix: '.node.cjs', target: 'node', diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index 927953b4f1dd6..46f8b83b0c5c2 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler} from 'onnxruntime-common'; +import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common'; -import {Binding, binding} from './binding'; +import { Binding, binding } from './binding'; class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; - constructor(pathOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions) { + constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) { this.#inferenceSession = new binding.InferenceSession(); if (typeof pathOrBuffer === 'string') { this.#inferenceSession.loadModel(pathOrBuffer, options); @@ -33,8 +33,11 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { // TODO: implement profiling } - async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): - Promise { + async run( + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise { return new Promise((resolve, reject) => { setImmediate(() => { try { @@ -53,8 +56,10 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler( + pathOrBuffer: string | Uint8Array, + options?: InferenceSession.SessionOptions, + ): Promise { return new Promise((resolve, reject) => { setImmediate(() => { try { diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts index 54b5767139904..d6d592a1665b3 100644 --- a/js/node/lib/binding.ts +++ 
b/js/node/lib/binding.ts @@ -1,21 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession, OnnxValue} from 'onnxruntime-common'; +import { InferenceSession, OnnxValue } from 'onnxruntime-common'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = { [name: string]: OnnxValue; }; type FetchesType = { - [name: string]: OnnxValue|null; + [name: string]: OnnxValue | null; }; type ReturnType = { [name: string]: OnnxValue; }; type RunOptions = InferenceSession.RunOptions; - /** * Binding exports a simple synchronized inference session object wrap. */ @@ -33,7 +32,7 @@ export declare namespace Binding { } export interface InferenceSessionConstructor { - new(): InferenceSession; + new (): InferenceSession; } export interface SupportedBackend { @@ -44,9 +43,9 @@ export declare namespace Binding { // export native binding export const binding = - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - require(`../bin/napi-v3/${process.platform}/${process.arch}/onnxruntime_binding.node`) as { - // eslint-disable-next-line @typescript-eslint/naming-convention - InferenceSession: Binding.InferenceSessionConstructor; - listSupportedBackends: () => Binding.SupportedBackend[]; -}; + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + require(`../bin/napi-v3/${process.platform}/${process.arch}/onnxruntime_binding.node`) as { + // eslint-disable-next-line @typescript-eslint/naming-convention + InferenceSession: Binding.InferenceSessionConstructor; + listSupportedBackends: () => Binding.SupportedBackend[]; + }; diff --git a/js/node/lib/index.ts b/js/node/lib/index.ts index 69b1ef1d96af6..ab00219665c4b 100644 --- a/js/node/lib/index.ts +++ b/js/node/lib/index.ts @@ -2,14 +2,14 @@ // Licensed under the MIT License. export * from 'onnxruntime-common'; -export {listSupportedBackends} from './backend'; -import {registerBackend, env} from 'onnxruntime-common'; -import {version} from './version'; -import {onnxruntimeBackend, listSupportedBackends} from './backend'; +export { listSupportedBackends } from './backend'; +import { registerBackend, env } from 'onnxruntime-common'; +import { version } from './version'; +import { onnxruntimeBackend, listSupportedBackends } from './backend'; const backends = listSupportedBackends(); for (const backend of backends) { registerBackend(backend.name, onnxruntimeBackend, 100); } -Object.defineProperty(env.versions, 'node', {value: version, enumerable: true}); +Object.defineProperty(env.versions, 'node', { value: version, enumerable: true }); diff --git a/js/node/script/build.ts b/js/node/script/build.ts index 3f0f804ed368e..133d1a0d981a0 100644 --- a/js/node/script/build.ts +++ b/js/node/script/build.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {spawnSync} from 'child_process'; +import { spawnSync } from 'child_process'; import * as fs from 'fs-extra'; import minimist from 'minimist'; import * as os from 'os'; @@ -11,13 +11,13 @@ import * as path from 'path'; const buildArgs = minimist(process.argv.slice(2)); // --config=Debug|Release|RelWithDebInfo -const CONFIG: 'Debug'|'Release'|'RelWithDebInfo' = - buildArgs.config || (os.platform() === 'win32' ? 'RelWithDebInfo' : 'Release'); +const CONFIG: 'Debug' | 'Release' | 'RelWithDebInfo' = + buildArgs.config || (os.platform() === 'win32' ? 
'RelWithDebInfo' : 'Release');
 if (CONFIG !== 'Debug' && CONFIG !== 'Release' && CONFIG !== 'RelWithDebInfo') {
   throw new Error(`unrecognized config: ${CONFIG}`);
 }
 
 // --arch=x64|ia32|arm64|arm
-const ARCH: 'x64'|'ia32'|'arm64'|'arm' = buildArgs.arch || os.arch();
+const ARCH: 'x64' | 'ia32' | 'arm64' | 'arm' = buildArgs.arch || os.arch();
 if (ARCH !== 'x64' && ARCH !== 'ia32' && ARCH !== 'arm64' && ARCH !== 'arm') {
   throw new Error(`unrecognized architecture: ${ARCH}`);
 }
@@ -51,7 +51,7 @@ if (REBUILD) {
 
 const args = [
   'cmake-js',
-  (REBUILD ? 'reconfigure' : 'configure'),
+  REBUILD ? 'reconfigure' : 'configure',
   `--arch=${ARCH}`,
   '--CDnapi_build_version=6',
   `--CDCMAKE_BUILD_TYPE=${CONFIG}`,
@@ -92,12 +92,13 @@ if (os.platform() === 'darwin') {
 // In Windows, "npx cmake-js configure" uses a powershell script to detect the Visual Studio installation.
 // The script uses the environment variable LIB. If an invalid path is specified in LIB, the script will fail.
 // So we override the LIB environment variable to remove invalid paths.
-const envOverride = os.platform() === 'win32' && process.env.LIB ?
-    {...process.env, LIB: process.env.LIB.split(';').filter(fs.existsSync).join(';')} :
-    process.env;
+const envOverride =
+  os.platform() === 'win32' && process.env.LIB
+    ? { ...process.env, LIB: process.env.LIB.split(';').filter(fs.existsSync).join(';') }
+    : process.env;
 
 // launch cmake-js configure
-const procCmakejs = spawnSync('npx', args, {shell: true, stdio: 'inherit', cwd: ROOT_FOLDER, env: envOverride});
+const procCmakejs = spawnSync('npx', args, { shell: true, stdio: 'inherit', cwd: ROOT_FOLDER, env: envOverride });
 if (procCmakejs.status !== 0) {
   if (procCmakejs.error) {
     console.error(procCmakejs.error);
@@ -106,8 +107,11 @@ if (procCmakejs.status !== 0) {
 }
 
 // launch cmake to build
-const procCmake =
-    spawnSync('cmake', ['--build', '.', '--config', CONFIG], {shell: true, stdio: 'inherit', cwd: BUILD_FOLDER});
+const procCmake = spawnSync('cmake', ['--build', '.', '--config', CONFIG], {
+  shell: true,
+  stdio: 'inherit',
+  cwd: BUILD_FOLDER,
+});
 if (procCmake.status !== 0) {
   if (procCmake.error) {
     console.error(procCmake.error);
diff --git a/js/node/script/install.js b/js/node/script/install.js
index 5136fbccbfe35..b15bc03840599 100644
--- a/js/node/script/install.js
+++ b/js/node/script/install.js
@@ -21,7 +21,7 @@ const os = require('os');
 const fs = require('fs');
 const path = require('path');
 const tar = require('tar');
-const {Readable} = require('stream');
+const { Readable } = require('stream');
 
 // commandline flag:
 // --onnxruntime-node-install-cuda  Force install the CUDA EP binaries. Try to detect the CUDA version.
@@ -49,7 +49,7 @@ const ORT_VERSION = require('../package.json').version; const npm_config_local_prefix = process.env.npm_config_local_prefix; const npm_package_json = process.env.npm_package_json; const SKIP_LOCAL_INSTALL = - npm_config_local_prefix && npm_package_json && path.dirname(npm_package_json) === npm_config_local_prefix; + npm_config_local_prefix && npm_package_json && path.dirname(npm_package_json) === npm_config_local_prefix; const shouldInstall = FORCE_INSTALL || (!SKIP_LOCAL_INSTALL && IS_LINUX_X64 && BIN_FOLDER_EXISTS && !CUDA_DLL_EXISTS); if (NO_INSTALL || !shouldInstall) { @@ -59,12 +59,14 @@ if (NO_INSTALL || !shouldInstall) { // Step.2: Download the required binaries const artifactUrl = { 11: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-${ - ORT_VERSION}.tgz`, + ORT_VERSION + }.tgz`, 12: `https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/onnxruntime-linux-x64-gpu-cuda12-${ - ORT_VERSION}.tgz` + ORT_VERSION + }.tgz`, }[INSTALL_CUDA_FLAG || tryGetCudaVersion()]; console.log(`Downloading "${artifactUrl}"...`); -fetch(artifactUrl).then(res => { +fetch(artifactUrl).then((res) => { if (!res.ok) { throw new Error(`Failed to download the binaries: ${res.status} ${res.statusText}. @@ -81,7 +83,8 @@ Use "--onnxruntime-node-install-cuda=skip" to skip the installation. You will st ]); Readable.fromWeb(res.body) - .pipe(tar.t({ + .pipe( + tar.t({ strict: true, onentry: (entry) => { const filename = path.basename(entry.path); @@ -92,16 +95,16 @@ Use "--onnxruntime-node-install-cuda=skip" to skip the installation. You will st console.log(`Finished extracting "${filename}".`); }); } - } - })) - .on('error', (err) => { - throw new Error(`Failed to extract the binaries: ${err.message}. + }, + }), + ) + .on('error', (err) => { + throw new Error(`Failed to extract the binaries: ${err.message}. Use "--onnxruntime-node-install-cuda=skip" to skip the installation. You will still be able to use ONNX Runtime, but the CUDA EP will not be available.`); - }); + }); }); - function tryGetCudaVersion() { // Should only return 11 or 12. diff --git a/js/node/script/prepack.ts b/js/node/script/prepack.ts index 4c5941d8dae12..d7c0ff3959fc6 100644 --- a/js/node/script/prepack.ts +++ b/js/node/script/prepack.ts @@ -12,7 +12,7 @@ function updatePackageJson() { const packageSelf = fs.readJSONSync(selfPackageJsonPath); const version = packageCommon.version; packageSelf.dependencies['onnxruntime-common'] = `${version}`; - fs.writeJSONSync(selfPackageJsonPath, packageSelf, {spaces: 2}); + fs.writeJSONSync(selfPackageJsonPath, packageSelf, { spaces: 2 }); console.log('=== finished updating package.json.'); } diff --git a/js/node/src/common.h b/js/node/src/common.h index 9a2528fb8c2e4..b60d059bb673b 100644 --- a/js/node/src/common.h +++ b/js/node/src/common.h @@ -8,39 +8,42 @@ #include #include -inline void MakeStringInternal(std::ostringstream & /*ss*/) noexcept {} +inline void MakeStringInternal(std::ostringstream& /*ss*/) noexcept {} -template inline void MakeStringInternal(std::ostringstream &ss, const T &t) noexcept { ss << t; } +template +inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept { ss << t; } template -inline void MakeStringInternal(std::ostringstream &ss, const T &t, const Args &...args) noexcept { +inline void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... 
args) noexcept { ::MakeStringInternal(ss, t); ::MakeStringInternal(ss, args...); } -template std::string MakeString(const Args &...args) { +template +std::string MakeString(const Args&... args) { std::ostringstream ss; ::MakeStringInternal(ss, args...); return std::string(ss.str()); } -#define ORT_NAPI_THROW(ERROR, ENV, ...) \ - do { \ - throw Napi::ERROR::New((ENV), MakeString(__VA_ARGS__)); \ +#define ORT_NAPI_THROW(ERROR, ENV, ...) \ + do { \ + throw Napi::ERROR::New((ENV), MakeString(__VA_ARGS__)); \ } while (false) #define ORT_NAPI_THROW_ERROR(ENV, ...) ORT_NAPI_THROW(Error, ENV, __VA_ARGS__) #define ORT_NAPI_THROW_TYPEERROR(ENV, ...) ORT_NAPI_THROW(TypeError, ENV, __VA_ARGS__) #define ORT_NAPI_THROW_RANGEERROR(ENV, ...) ORT_NAPI_THROW(RangeError, ENV, __VA_ARGS__) -#define ORT_NAPI_THROW_IF(COND, ERROR, ENV, ...) \ - if (COND) { \ - ORT_NAPI_THROW(ERROR, ENV, __VA_ARGS__); \ +#define ORT_NAPI_THROW_IF(COND, ERROR, ENV, ...) \ + if (COND) { \ + ORT_NAPI_THROW(ERROR, ENV, __VA_ARGS__); \ } #define ORT_NAPI_THROW_ERROR_IF(COND, ENV, ...) ORT_NAPI_THROW_IF(COND, Error, ENV, __VA_ARGS__) #define ORT_NAPI_THROW_TYPEERROR_IF(COND, ENV, ...) ORT_NAPI_THROW_IF(COND, TypeError, ENV, __VA_ARGS__) #define ORT_NAPI_THROW_RANGEERROR_IF(COND, ENV, ...) ORT_NAPI_THROW_IF(COND, RangeError, ENV, __VA_ARGS__) -template Napi::Value CreateNapiArrayFrom(napi_env env, const std::vector &vec) { +template +Napi::Value CreateNapiArrayFrom(napi_env env, const std::vector& vec) { Napi::EscapableHandleScope scope(env); auto array = Napi::Array::New(env, vec.size()); for (uint32_t i = 0; i < vec.size(); i++) { diff --git a/js/node/src/directml_load_helper.cc b/js/node/src/directml_load_helper.cc index 7017f627fd3d7..6aafe4d5fa788 100644 --- a/js/node/src/directml_load_helper.cc +++ b/js/node/src/directml_load_helper.cc @@ -13,13 +13,13 @@ void LoadDirectMLDll(Napi::Env env) { GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, reinterpret_cast(&LoadDirectMLDll), &moduleHandle); - DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); + DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) { int ret = GetLastError(); if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) { pathLen *= 2; path.resize(pathLen); - getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); + getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); } else { ORT_NAPI_THROW_ERROR(env, "Failed getting path to load DirectML.dll, error code: ", ret); } diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index b85104cadc6ed..057066507621b 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -45,11 +45,10 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { return exports; } -InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo &info) - : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), - defaultRunOptions_(nullptr) {} +InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info) + : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {} -Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { 
+Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); Napi::HandleScope scope(env); @@ -69,7 +68,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { ParseSessionOptions(info[1].As(), sessionOptions); this->session_.reset(new Ort::Session(*env.GetInstanceData(), #ifdef _WIN32 - reinterpret_cast(value.Utf16Value().c_str()), + reinterpret_cast(value.Utf16Value().c_str()), #else value.Utf8Value().c_str(), #endif @@ -77,13 +76,13 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { } else if (argsLength == 4 && info[0].IsArrayBuffer() && info[1].IsNumber() && info[2].IsNumber() && info[3].IsObject()) { - void *buffer = info[0].As().Data(); + void* buffer = info[0].As().Data(); int64_t bytesOffset = info[1].As().Int64Value(); int64_t bytesLength = info[2].As().Int64Value(); ParseSessionOptions(info[3].As(), sessionOptions); this->session_.reset(new Ort::Session(*env.GetInstanceData(), - reinterpret_cast(buffer) + bytesOffset, bytesLength, + reinterpret_cast(buffer) + bytesOffset, bytesLength, sessionOptions)); } else { ORT_NAPI_THROW_TYPEERROR( @@ -119,16 +118,16 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { ? typeInfo.GetTensorTypeAndShapeInfo().GetElementType() : ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); } - } catch (Napi::Error const &e) { + } catch (Napi::Error const& e) { throw e; - } catch (std::exception const &e) { + } catch (std::exception const& e) { ORT_NAPI_THROW_ERROR(env, e.what()); } this->initialized_ = true; return env.Undefined(); } -Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) { +Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); @@ -137,7 +136,7 @@ Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) return scope.Escape(CreateNapiArrayFrom(env, inputNames_)); } -Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) { +Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); @@ -146,7 +145,7 @@ Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) return scope.Escape(CreateNapiArrayFrom(env, outputNames_)); } -Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { +Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); @@ -161,17 +160,17 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { auto feed = info[0].As(); auto fetch = info[1].As(); - std::vector inputNames_cstr; + std::vector inputNames_cstr; std::vector inputValues; - std::vector outputNames_cstr; + std::vector outputNames_cstr; std::vector outputValues; std::vector reuseOutput; size_t inputIndex = 0; size_t outputIndex = 0; - OrtMemoryInfo *memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release(); + OrtMemoryInfo* memory_info = 
Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release(); try { - for (auto &name : inputNames_) { + for (auto& name : inputNames_) { if (feed.Has(name)) { inputIndex++; inputNames_cstr.push_back(name.c_str()); @@ -179,7 +178,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { inputValues.push_back(NapiValueToOrtValue(env, value, memory_info)); } } - for (auto &name : outputNames_) { + for (auto& name : outputNames_) { if (fetch.Has(name)) { outputIndex++; outputNames_cstr.push_back(name.c_str()); @@ -207,14 +206,14 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { } return scope.Escape(result); - } catch (Napi::Error const &e) { + } catch (Napi::Error const& e) { throw e; - } catch (std::exception const &e) { + } catch (std::exception const& e) { ORT_NAPI_THROW_ERROR(env, e.what()); } } -Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo &info) { +Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); @@ -226,12 +225,12 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo &info) { return env.Undefined(); } -Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo &info) { +Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); Napi::EscapableHandleScope scope(env); Napi::Array result = Napi::Array::New(env); - auto createObject = [&env](const std::string &name, const bool bundled) -> Napi::Object { + auto createObject = [&env](const std::string& name, const bool bundled) -> Napi::Object { Napi::Object result = Napi::Object::New(env); result.Set("name", name); result.Set("bundled", bundled); diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index 1e789c4814cd6..effdd83e3aa02 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -10,16 +10,16 @@ // class InferenceSessionWrap is a N-API object wrapper for native InferenceSession. class InferenceSessionWrap : public Napi::ObjectWrap { -public: + public: static Napi::Object Init(Napi::Env env, Napi::Object exports); - InferenceSessionWrap(const Napi::CallbackInfo &info); + InferenceSessionWrap(const Napi::CallbackInfo& info); -private: + private: /** * [sync] list supported backend list * @returns array with objects { "name": "cpu", requirementsInstalled: true } */ - static Napi::Value ListSupportedBackends(const Napi::CallbackInfo &info); + static Napi::Value ListSupportedBackends(const Napi::CallbackInfo& info); /** * [sync] create the session. @@ -27,7 +27,7 @@ class InferenceSessionWrap : public Napi::ObjectWrap { * @returns nothing * @throw error if status code != 0 */ - Napi::Value LoadModel(const Napi::CallbackInfo &info); + Napi::Value LoadModel(const Napi::CallbackInfo& info); // following functions have to be called after model is loaded. @@ -37,14 +37,14 @@ class InferenceSessionWrap : public Napi::ObjectWrap { * @returns a string array. * @throw nothing */ - Napi::Value GetInputNames(const Napi::CallbackInfo &info); + Napi::Value GetInputNames(const Napi::CallbackInfo& info); /** * [sync] get output names. * @param nothing * @returns a string array. 
* @throw nothing */ - Napi::Value GetOutputNames(const Napi::CallbackInfo &info); + Napi::Value GetOutputNames(const Napi::CallbackInfo& info); /** * [sync] run the model. @@ -53,7 +53,7 @@ class InferenceSessionWrap : public Napi::ObjectWrap { * @returns an object that every output specified will present and value must be object * @throw error if status code != 0 */ - Napi::Value Run(const Napi::CallbackInfo &info); + Napi::Value Run(const Napi::CallbackInfo& info); /** * [sync] dispose the session. @@ -61,7 +61,7 @@ class InferenceSessionWrap : public Napi::ObjectWrap { * @returns nothing * @throw nothing */ - Napi::Value Dispose(const Napi::CallbackInfo &info); + Napi::Value Dispose(const Napi::CallbackInfo& info); // private members diff --git a/js/node/src/run_options_helper.cc b/js/node/src/run_options_helper.cc index 18f18be3df67d..352f828970c66 100644 --- a/js/node/src/run_options_helper.cc +++ b/js/node/src/run_options_helper.cc @@ -9,7 +9,7 @@ #include "common.h" #include "run_options_helper.h" -void ParseRunOptions(const Napi::Object options, Ort::RunOptions &runOptions) { +void ParseRunOptions(const Napi::Object options, Ort::RunOptions& runOptions) { // Log severity level if (options.Has("logSeverityLevel")) { auto logLevelValue = options.Get("logSeverityLevel"); diff --git a/js/node/src/run_options_helper.h b/js/node/src/run_options_helper.h index 2174973eaf9a3..104fae150bb0e 100644 --- a/js/node/src/run_options_helper.h +++ b/js/node/src/run_options_helper.h @@ -10,4 +10,4 @@ struct RunOptions; } // parse a Javascript run options object and fill the native RunOptions object. -void ParseRunOptions(const Napi::Object options, Ort::RunOptions &runOptions); +void ParseRunOptions(const Napi::Object options, Ort::RunOptions& runOptions); diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 46e08010b7835..0ed1ba08e6bf7 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -31,7 +31,7 @@ const std::unordered_map GRAPH_OPT_LEVEL_NA const std::unordered_map EXECUTION_MODE_NAME_TO_ID_MAP = {{"sequential", ORT_SEQUENTIAL}, {"parallel", ORT_PARALLEL}}; -void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions &sessionOptions) { +void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sessionOptions) { for (uint32_t i = 0; i < epList.Length(); i++) { Napi::Value epValue = epList[i]; std::string name; @@ -59,7 +59,7 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions &sess // TODO: handling CPU EP options #ifdef USE_CUDA } else if (name == "cuda") { - OrtCUDAProviderOptionsV2 *options; + OrtCUDAProviderOptionsV2* options; Ort::GetApi().CreateCUDAProviderOptions(&options); options->device_id = deviceId; sessionOptions.AppendExecutionProvider_CUDA_V2(*options); @@ -67,7 +67,7 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions &sess #endif #ifdef USE_TENSORRT } else if (name == "tensorrt") { - OrtTensorRTProviderOptionsV2 *options; + OrtTensorRTProviderOptionsV2* options; Ort::GetApi().CreateTensorRTProviderOptions(&options); options->device_id = deviceId; sessionOptions.AppendExecutionProvider_TensorRT_V2(*options); @@ -95,7 +95,7 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions &sess } } -void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions &sessionOptions) { +void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) { // Execution 
provider if (options.Has("executionProviders")) { auto epsValue = options.Get("executionProviders"); diff --git a/js/node/src/session_options_helper.h b/js/node/src/session_options_helper.h index 00725468342d8..c0a9ae0d683e9 100644 --- a/js/node/src/session_options_helper.h +++ b/js/node/src/session_options_helper.h @@ -10,4 +10,4 @@ struct SessionOptions; } // parse a Javascript session options object and fill the native SessionOptions object. -void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions &sessionOptions); +void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions); diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc index 1062d89f76c5f..54f1c5a09906e 100644 --- a/js/node/src/tensor_helper.cc +++ b/js/node/src/tensor_helper.cc @@ -31,82 +31,76 @@ constexpr size_t ONNX_TENSOR_ELEMENT_DATA_TYPE_COUNT = 17; // size of element in bytes for each data type. 0 indicates not supported. constexpr size_t DATA_TYPE_ELEMENT_SIZE_MAP[] = { - 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED not supported - 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT - 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 - 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 - 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 - 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 - 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 - 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING N/A - 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE - 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 - 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported - 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported - 0 // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported + 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED not supported + 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT + 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 + 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 + 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 + 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 + 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 + 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING N/A + 1, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + 2, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE + 4, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 + 8, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 + 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported + 0, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported + 0 // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported }; static_assert(sizeof(DATA_TYPE_ELEMENT_SIZE_MAP) == sizeof(size_t) * ONNX_TENSOR_ELEMENT_DATA_TYPE_COUNT, "definition not matching"); constexpr napi_typedarray_type DATA_TYPE_TYPEDARRAY_MAP[] = { - (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED not supported - napi_float32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT - napi_uint8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 - napi_int8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 - napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 - napi_int16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 - napi_int32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - napi_bigint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 - (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING not supported - napi_uint8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - napi_uint16_array, // 
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 uses Uint16Array - napi_float64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE - napi_uint32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 - (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported - (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported - (napi_typedarray_type)(-1) // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported + (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED not supported + napi_float32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT + napi_uint8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 + napi_int8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 + napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 + napi_int16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 + napi_int32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 + napi_bigint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 + (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING not supported + napi_uint8_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + napi_uint16_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 FLOAT16 uses Uint16Array + napi_float64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE + napi_uint32_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 + napi_biguint64_array, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 + (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 not supported + (napi_typedarray_type)(-1), // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 not supported + (napi_typedarray_type)(-1) // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 not supported }; static_assert(sizeof(DATA_TYPE_TYPEDARRAY_MAP) == sizeof(napi_typedarray_type) * ONNX_TENSOR_ELEMENT_DATA_TYPE_COUNT, "definition not matching"); -constexpr const char *DATA_TYPE_ID_TO_NAME_MAP[] = { - nullptr, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED - "float32", // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT - "uint8", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 - "int8", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 - "uint16", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 - "int16", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 - "int32", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 - "int64", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 - "string", // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - "bool", // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - "float16", // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 - "float64", // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE - "uint32", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 - "uint64", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 - nullptr, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 - nullptr, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 - nullptr // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 +constexpr const char* DATA_TYPE_ID_TO_NAME_MAP[] = { + nullptr, // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED + "float32", // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT + "uint8", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 + "int8", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 + "uint16", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 + "int16", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 + "int32", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 + "int64", // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 + "string", // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + "bool", // ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + "float16", // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + "float64", // ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE + "uint32", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 + "uint64", // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 + nullptr, 
// ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 + nullptr, // ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 + nullptr // ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 }; -static_assert(sizeof(DATA_TYPE_ID_TO_NAME_MAP) == sizeof(const char *) * ONNX_TENSOR_ELEMENT_DATA_TYPE_COUNT, +static_assert(sizeof(DATA_TYPE_ID_TO_NAME_MAP) == sizeof(const char*) * ONNX_TENSOR_ELEMENT_DATA_TYPE_COUNT, "definition not matching"); const std::unordered_map DATA_TYPE_NAME_TO_ID_MAP = { - {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, - {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, - {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, - {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"string", ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, - {"bool", ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, {"float16", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16}, - {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, - {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}}; + {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"string", ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, {"bool", ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, {"float16", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16}, {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}}; // currently only support tensor -Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo *memory_info) { +Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info) { ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), env, "Tensor must be an object."); // check 'dims' @@ -144,7 +138,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo * auto tensorDataArray = tensorDataValue.As(); auto tensorDataSize = tensorDataArray.Length(); std::vector stringData; - std::vector stringDataCStr; + std::vector stringDataCStr; stringData.reserve(tensorDataSize); stringDataCStr.reserve(tensorDataSize); for (uint32_t i = 0; i < tensorDataSize; i++) { @@ -180,7 +174,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo * "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); - char *buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); + char* buffer = reinterpret_cast(tensorDataTypedArray.ArrayBuffer().Data()); size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); size_t bufferByteLength = tensorDataTypedArray.ByteLength(); return Ort::Value::CreateTensor(memory_info, buffer + bufferByteOffset, bufferByteLength, @@ -188,7 +182,7 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo * } } -Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value &value) { +Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value) { Napi::EscapableHandleScope scope(env); auto returnValue = Napi::Object::New(env); diff --git a/js/node/src/tensor_helper.h b/js/node/src/tensor_helper.h index d5e8ef709f53e..56b399ccc24ee 100644 --- 
a/js/node/src/tensor_helper.h
+++ b/js/node/src/tensor_helper.h
@@ -9,7 +9,7 @@
 #include "onnxruntime_cxx_api.h"
 
 // convert a Javascript OnnxValue object to an OrtValue object
-Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo *memory_info);
+Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* memory_info);
 
 // convert an OrtValue object to a Javascript OnnxValue object
-Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value &value);
+Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value& value);
diff --git a/js/node/test/e2e/inference-session-run.ts b/js/node/test/e2e/inference-session-run.ts
index faac3ceee3be0..820dec0945a8e 100644
--- a/js/node/test/e2e/inference-session-run.ts
+++ b/js/node/test/e2e/inference-session-run.ts
@@ -1,10 +1,10 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {InferenceSession, Tensor} from 'onnxruntime-common';
+import { InferenceSession, Tensor } from 'onnxruntime-common';
 import * as path from 'path';
 
-import {assertTensorEqual, SQUEEZENET_INPUT0_DATA, SQUEEZENET_OUTPUT0_DATA, TEST_DATA_ROOT} from '../test-utils';
+import { assertTensorEqual, SQUEEZENET_INPUT0_DATA, SQUEEZENET_OUTPUT0_DATA, TEST_DATA_ROOT } from '../test-utils';
 
 describe('E2E Tests - InferenceSession.run()', async () => {
   let session: InferenceSession;
@@ -17,7 +17,7 @@ describe('E2E Tests - InferenceSession.run()', async () => {
   it('multiple run() calls', async () => {
     for (let i = 0; i < 1000; i++) {
-      const result = await session!.run({'data_0': input0}, ['softmaxout_1']);
+      const result = await session!.run({ data_0: input0 }, ['softmaxout_1']);
       assertTensorEqual(result.softmaxout_1, expectedOutput0);
     }
   }).timeout(process.arch === 'x64' ? '120s' : 0);
diff --git a/js/node/test/e2e/simple-e2e-tests.ts b/js/node/test/e2e/simple-e2e-tests.ts
index 70ac6ca1e0f94..6841dae316304 100644
--- a/js/node/test/e2e/simple-e2e-tests.ts
+++ b/js/node/test/e2e/simple-e2e-tests.ts
@@ -2,102 +2,111 @@
 // Licensed under the MIT License.
import assert from 'assert'; -import {InferenceSession, Tensor} from 'onnxruntime-common'; +import { InferenceSession, Tensor } from 'onnxruntime-common'; import * as path from 'path'; -import {assertDataEqual, TEST_DATA_ROOT} from '../test-utils'; +import { assertDataEqual, TEST_DATA_ROOT } from '../test-utils'; - -const MODEL_TEST_TYPES_CASES: - Array<{model: string; type: Tensor.Type; input0: Tensor.DataType; expectedOutput0: Tensor.DataType}> = [ - { - model: path.join(TEST_DATA_ROOT, 'test_types_bool.onnx'), - type: 'bool', - input0: Uint8Array.from([1, 0, 0, 1, 0]), - expectedOutput0: Uint8Array.from([1, 0, 0, 1, 0]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_double.onnx'), - type: 'float64', - input0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]), - expectedOutput0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_float.onnx'), - type: 'float32', - input0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]), - expectedOutput0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_int8.onnx'), - type: 'int8', - input0: Int8Array.from([1, -2, 3, 4, -5]), - expectedOutput0: Int8Array.from([1, -2, 3, 4, -5]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_int16.onnx'), - type: 'int16', - input0: Int16Array.from([1, -2, 3, 4, -5]), - expectedOutput0: Int16Array.from([1, -2, 3, 4, -5]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_int32.onnx'), - type: 'int32', - input0: Int32Array.from([1, -2, 3, 4, -5]), - expectedOutput0: Int32Array.from([1, -2, 3, 4, -5]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_int64.onnx'), - type: 'int64', - input0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]), - expectedOutput0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_string.onnx'), - type: 'string', - input0: ['a', 'b', 'c', 'd', 'e'], - expectedOutput0: ['a', 'b', 'c', 'd', 'e'] - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_uint8.onnx'), - type: 'uint8', - input0: Uint8Array.from([1, 2, 3, 4, 5]), - expectedOutput0: Uint8Array.from([1, 2, 3, 4, 5]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_uint16.onnx'), - type: 'uint16', - input0: Uint16Array.from([1, 2, 3, 4, 5]), - expectedOutput0: Uint16Array.from([1, 2, 3, 4, 5]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_uint32.onnx'), - type: 'uint32', - input0: Uint32Array.from([1, 2, 3, 4, 5]), - expectedOutput0: Uint32Array.from([1, 2, 3, 4, 5]) - }, - { - model: path.join(TEST_DATA_ROOT, 'test_types_uint64.onnx'), - type: 'uint64', - input0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]), - expectedOutput0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]) - }, - ]; +const MODEL_TEST_TYPES_CASES: Array<{ + model: string; + type: Tensor.Type; + input0: Tensor.DataType; + expectedOutput0: Tensor.DataType; +}> = [ + { + model: path.join(TEST_DATA_ROOT, 'test_types_bool.onnx'), + type: 'bool', + input0: Uint8Array.from([1, 0, 0, 1, 0]), + expectedOutput0: Uint8Array.from([1, 0, 0, 1, 0]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_double.onnx'), + type: 'float64', + input0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]), + expectedOutput0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_float.onnx'), + type: 'float32', + input0: Float32Array.from([1.0, 2.0, 
3.0, 4.0, 5.0]), + expectedOutput0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_int8.onnx'), + type: 'int8', + input0: Int8Array.from([1, -2, 3, 4, -5]), + expectedOutput0: Int8Array.from([1, -2, 3, 4, -5]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_int16.onnx'), + type: 'int16', + input0: Int16Array.from([1, -2, 3, 4, -5]), + expectedOutput0: Int16Array.from([1, -2, 3, 4, -5]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_int32.onnx'), + type: 'int32', + input0: Int32Array.from([1, -2, 3, 4, -5]), + expectedOutput0: Int32Array.from([1, -2, 3, 4, -5]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_int64.onnx'), + type: 'int64', + input0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]), + expectedOutput0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_string.onnx'), + type: 'string', + input0: ['a', 'b', 'c', 'd', 'e'], + expectedOutput0: ['a', 'b', 'c', 'd', 'e'], + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_uint8.onnx'), + type: 'uint8', + input0: Uint8Array.from([1, 2, 3, 4, 5]), + expectedOutput0: Uint8Array.from([1, 2, 3, 4, 5]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_uint16.onnx'), + type: 'uint16', + input0: Uint16Array.from([1, 2, 3, 4, 5]), + expectedOutput0: Uint16Array.from([1, 2, 3, 4, 5]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_uint32.onnx'), + type: 'uint32', + input0: Uint32Array.from([1, 2, 3, 4, 5]), + expectedOutput0: Uint32Array.from([1, 2, 3, 4, 5]), + }, + { + model: path.join(TEST_DATA_ROOT, 'test_types_uint64.onnx'), + type: 'uint64', + input0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]), + expectedOutput0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]), + }, +]; describe('E2E Tests - simple E2E tests', () => { - MODEL_TEST_TYPES_CASES.forEach(testCase => { + MODEL_TEST_TYPES_CASES.forEach((testCase) => { it(`${testCase.model}`, async () => { const session = await InferenceSession.create(testCase.model); - const output = await session.run({'input': new Tensor(testCase.type, testCase.input0, [1, 5])}); - assert(Object.prototype.hasOwnProperty.call(output, 'output'), '\'output\' should be in the result object.'); + const output = await session.run({ input: new Tensor(testCase.type, testCase.input0, [1, 5]) }); + assert(Object.prototype.hasOwnProperty.call(output, 'output'), "'output' should be in the result object."); assert(output.output instanceof Tensor, 'result[output] should be a Tensor object.'); assert.strictEqual(output.output.size, 5, `output size expected 5, got ${output.output.size}.`); assert.strictEqual( - output.output.type, testCase.type, `tensor type expected ${testCase.type}, got ${output.output.type}.`); + output.output.type, + testCase.type, + `tensor type expected ${testCase.type}, got ${output.output.type}.`, + ); assert.strictEqual( - Object.getPrototypeOf(output.output.data), Object.getPrototypeOf(testCase.expectedOutput0), - `tensor data expected ${Object.getPrototypeOf(testCase.expectedOutput0).constructor.name}, got ${ - Object.getPrototypeOf(output.output.data).constructor.name}`); + Object.getPrototypeOf(output.output.data), + Object.getPrototypeOf(testCase.expectedOutput0), + `tensor data expected ${Object.getPrototypeOf(testCase.expectedOutput0).constructor.name}, got ${ + 
Object.getPrototypeOf(output.output.data).constructor.name + }`, + ); assertDataEqual(testCase.type, output.output.data, testCase.expectedOutput0); }); }); diff --git a/js/node/test/ort-schema/protobuf/README.md b/js/node/test/ort-schema/protobuf/README.md index f5f52c602f1ad..35f61310db9aa 100644 --- a/js/node/test/ort-schema/protobuf/README.md +++ b/js/node/test/ort-schema/protobuf/README.md @@ -12,10 +12,10 @@ The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the versio - type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. - in the generated typescript declaration file 'onnx.d.ts', the following line: ```ts - import Long = require("long"); + import Long = require('long'); ``` need to be replaced to fix type import error: ```ts - import Long from "long"; + import Long from 'long'; ``` this replacement is done and code format is also applied to file 'onnx.d.ts'. diff --git a/js/node/test/ort-schema/protobuf/onnx.js b/js/node/test/ort-schema/protobuf/onnx.js index 681855132d4e8..24ccb627acff7 100644 --- a/js/node/test/ort-schema/protobuf/onnx.js +++ b/js/node/test/ort-schema/protobuf/onnx.js @@ -1,7658 +1,7391 @@ /*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ -"use strict"; +'use strict'; -var $protobuf = require("protobufjs/minimal"); +var $protobuf = require('protobufjs/minimal'); // Common aliases -var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; +var $Reader = $protobuf.Reader, + $Writer = $protobuf.Writer, + $util = $protobuf.util; // Exported root namespace -var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); +var $root = $protobuf.roots['default'] || ($protobuf.roots['default'] = {}); + +$root.onnx = (function () { + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. + * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = '_START_VERSION')] = 0; + values[(valuesById[1] = 'IR_VERSION_2017_10_10')] = 1; + values[(valuesById[2] = 'IR_VERSION_2017_10_30')] = 2; + values[(valuesById[3] = 'IR_VERSION_2017_11_3')] = 3; + values[(valuesById[4] = 'IR_VERSION_2019_1_22')] = 4; + values[(valuesById[5] = 'IR_VERSION_2019_3_18')] = 5; + values[(valuesById[6] = 'IR_VERSION_2019_9_19')] = 6; + values[(valuesById[7] = 'IR_VERSION_2020_5_8')] = 7; + values[(valuesById[8] = 'IR_VERSION_2021_7_30')] = 8; + values[(valuesById[9] = 'IR_VERSION')] = 9; + return values; + })(); + + onnx.AttributeProto = (function () { + /** + * Properties of an AttributeProto. 
+ * @memberof onnx + * @interface IAttributeProto + * @property {string|null} [name] AttributeProto name + * @property {string|null} [refAttrName] AttributeProto refAttrName + * @property {string|null} [docString] AttributeProto docString + * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type + * @property {number|null} [f] AttributeProto f + * @property {number|Long|null} [i] AttributeProto i + * @property {Uint8Array|null} [s] AttributeProto s + * @property {onnx.ITensorProto|null} [t] AttributeProto t + * @property {onnx.IGraphProto|null} [g] AttributeProto g + * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor + * @property {onnx.ITypeProto|null} [tp] AttributeProto tp + * @property {Array.|null} [floats] AttributeProto floats + * @property {Array.|null} [ints] AttributeProto ints + * @property {Array.|null} [strings] AttributeProto strings + * @property {Array.|null} [tensors] AttributeProto tensors + * @property {Array.|null} [graphs] AttributeProto graphs + * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors + * @property {Array.|null} [typeProtos] AttributeProto typeProtos + */ + + /** + * Constructs a new AttributeProto. + * @memberof onnx + * @classdesc Represents an AttributeProto. + * @implements IAttributeProto + * @constructor + * @param {onnx.IAttributeProto=} [properties] Properties to set + */ + function AttributeProto(properties) { + this.floats = []; + this.ints = []; + this.strings = []; + this.tensors = []; + this.graphs = []; + this.sparseTensors = []; + this.typeProtos = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * AttributeProto name. + * @member {string} name + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.name = ''; + + /** + * AttributeProto refAttrName. + * @member {string} refAttrName + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.refAttrName = ''; + + /** + * AttributeProto docString. + * @member {string} docString + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.docString = ''; + + /** + * AttributeProto type. + * @member {onnx.AttributeProto.AttributeType} type + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.type = 0; + + /** + * AttributeProto f. + * @member {number} f + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.f = 0; + + /** + * AttributeProto i. + * @member {number|Long} i + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * AttributeProto s. + * @member {Uint8Array} s + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.s = $util.newBuffer([]); + + /** + * AttributeProto t. + * @member {onnx.ITensorProto|null|undefined} t + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.t = null; + + /** + * AttributeProto g. + * @member {onnx.IGraphProto|null|undefined} g + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.g = null; + + /** + * AttributeProto sparseTensor. + * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensor = null; + + /** + * AttributeProto tp. 
+ * @member {onnx.ITypeProto|null|undefined} tp + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tp = null; + + /** + * AttributeProto floats. + * @member {Array.} floats + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.floats = $util.emptyArray; + + /** + * AttributeProto ints. + * @member {Array.} ints + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.ints = $util.emptyArray; + + /** + * AttributeProto strings. + * @member {Array.} strings + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.strings = $util.emptyArray; + + /** + * AttributeProto tensors. + * @member {Array.} tensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tensors = $util.emptyArray; + + /** + * AttributeProto graphs. + * @member {Array.} graphs + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.graphs = $util.emptyArray; + + /** + * AttributeProto sparseTensors. + * @member {Array.} sparseTensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensors = $util.emptyArray; + + /** + * AttributeProto typeProtos. + * @member {Array.} typeProtos + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.typeProtos = $util.emptyArray; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @function create + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto=} [properties] Properties to set + * @returns {onnx.AttributeProto} AttributeProto instance + */ + AttributeProto.create = function create(properties) { + return new AttributeProto(properties); + }; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, 'f')) + writer.uint32(/* id 2, wireType 5 =*/ 21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, 'i')) + writer.uint32(/* id 3, wireType 0 =*/ 24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, 's')) + writer.uint32(/* id 4, wireType 2 =*/ 34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, 't')) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/ 42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, 'g')) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/ 50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/ 58).fork(); + for (var i = 0; i < message.floats.length; ++i) writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/ 66).fork(); + for (var i = 0; i < message.ints.length; ++i) writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/ 74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/ 82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/ 90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 13, wireType 2 =*/ 106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, 'tp')) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/ 114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode( + message.typeProtos[i], + writer.uint32(/* id 15, wireType 2 =*/ 122).fork(), + ).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, 'type')) + writer.uint32(/* id 20, wireType 0 =*/ 160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, 'refAttrName')) + writer.uint32(/* id 21, wireType 2 =*/ 170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, 'sparseTensor')) + $root.onnx.SparseTensorProto.encode( + message.sparseTensor, + writer.uint32(/* id 22, wireType 2 =*/ 178).fork(), + ).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode( + 
message.sparseTensors[i], + writer.uint32(/* id 23, wireType 2 =*/ 186).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.floats.push(reader.float()); + } else message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.ints.push(reader.int64()); + } else message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if 
(!(message.typeProtos && message.typeProtos.length)) message.typeProtos = []; + message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an AttributeProto message. + * @function verify + * @memberof onnx.AttributeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + AttributeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.refAttrName != null && message.hasOwnProperty('refAttrName')) + if (!$util.isString(message.refAttrName)) return 'refAttrName: string expected'; + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.type != null && message.hasOwnProperty('type')) + switch (message.type) { + default: + return 'type: enum value expected'; + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 11: + case 13: + case 6: + case 7: + case 8: + case 9: + case 10: + case 12: + case 14: + break; + } + if (message.f != null && message.hasOwnProperty('f')) + if (typeof message.f !== 'number') return 'f: number expected'; + if (message.i != null && message.hasOwnProperty('i')) + if ( + !$util.isInteger(message.i) && + !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high)) + ) + return 'i: integer|Long expected'; + if (message.s != null && message.hasOwnProperty('s')) + if (!((message.s && typeof message.s.length === 'number') || $util.isString(message.s))) + return 's: buffer expected'; + if (message.t != null && message.hasOwnProperty('t')) { + var error = $root.onnx.TensorProto.verify(message.t); + if (error) return 't.' + error; + } + if (message.g != null && message.hasOwnProperty('g')) { + var error = $root.onnx.GraphProto.verify(message.g); + if (error) return 'g.' + error; + } + if (message.sparseTensor != null && message.hasOwnProperty('sparseTensor')) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); + if (error) return 'sparseTensor.' + error; + } + if (message.tp != null && message.hasOwnProperty('tp')) { + var error = $root.onnx.TypeProto.verify(message.tp); + if (error) return 'tp.' 
+ error; + } + if (message.floats != null && message.hasOwnProperty('floats')) { + if (!Array.isArray(message.floats)) return 'floats: array expected'; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== 'number') return 'floats: number[] expected'; + } + if (message.ints != null && message.hasOwnProperty('ints')) { + if (!Array.isArray(message.ints)) return 'ints: array expected'; + for (var i = 0; i < message.ints.length; ++i) + if ( + !$util.isInteger(message.ints[i]) && + !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high)) + ) + return 'ints: integer|Long[] expected'; + } + if (message.strings != null && message.hasOwnProperty('strings')) { + if (!Array.isArray(message.strings)) return 'strings: array expected'; + for (var i = 0; i < message.strings.length; ++i) + if ( + !( + (message.strings[i] && typeof message.strings[i].length === 'number') || + $util.isString(message.strings[i]) + ) + ) + return 'strings: buffer[] expected'; + } + if (message.tensors != null && message.hasOwnProperty('tensors')) { + if (!Array.isArray(message.tensors)) return 'tensors: array expected'; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) return 'tensors.' + error; + } + } + if (message.graphs != null && message.hasOwnProperty('graphs')) { + if (!Array.isArray(message.graphs)) return 'graphs: array expected'; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) return 'graphs.' + error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty('sparseTensors')) { + if (!Array.isArray(message.sparseTensors)) return 'sparseTensors: array expected'; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) return 'sparseTensors.' + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty('typeProtos')) { + if (!Array.isArray(message.typeProtos)) return 'typeProtos: array expected'; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) return 'typeProtos.' + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.AttributeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.AttributeProto} AttributeProto + */ + AttributeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.AttributeProto) return object; + var message = new $root.onnx.AttributeProto(); + if (object.name != null) message.name = String(object.name); + if (object.refAttrName != null) message.refAttrName = String(object.refAttrName); + if (object.docString != null) message.docString = String(object.docString); + switch (object.type) { + default: + if (typeof object.type === 'number') { + message.type = object.type; + break; + } + break; + case 'UNDEFINED': + case 0: + message.type = 0; + break; + case 'FLOAT': + case 1: + message.type = 1; + break; + case 'INT': + case 2: + message.type = 2; + break; + case 'STRING': + case 3: + message.type = 3; + break; + case 'TENSOR': + case 4: + message.type = 4; + break; + case 'GRAPH': + case 5: + message.type = 5; + break; + case 'SPARSE_TENSOR': + case 11: + message.type = 11; + break; + case 'TYPE_PROTO': + case 13: + message.type = 13; + break; + case 'FLOATS': + case 6: + message.type = 6; + break; + case 'INTS': + case 7: + message.type = 7; + break; + case 'STRINGS': + case 8: + message.type = 8; + break; + case 'TENSORS': + case 9: + message.type = 9; + break; + case 'GRAPHS': + case 10: + message.type = 10; + break; + case 'SPARSE_TENSORS': + case 12: + message.type = 12; + break; + case 'TYPE_PROTOS': + case 14: + message.type = 14; + break; + } + if (object.f != null) message.f = Number(object.f); + if (object.i != null) + if ($util.Long) (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === 'string') message.i = parseInt(object.i, 10); + else if (typeof object.i === 'number') message.i = object.i; + else if (typeof object.i === 'object') + message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if (object.s != null) + if (typeof object.s === 'string') + $util.base64.decode(object.s, (message.s = $util.newBuffer($util.base64.length(object.s))), 0); + else if (object.s.length >= 0) message.s = object.s; + if (object.t != null) { + if (typeof object.t !== 'object') throw TypeError('.onnx.AttributeProto.t: object expected'); + message.t = $root.onnx.TensorProto.fromObject(object.t); + } + if (object.g != null) { + if (typeof object.g !== 'object') throw TypeError('.onnx.AttributeProto.g: object expected'); + message.g = $root.onnx.GraphProto.fromObject(object.g); + } + if (object.sparseTensor != null) { + if (typeof object.sparseTensor !== 'object') + throw TypeError('.onnx.AttributeProto.sparseTensor: object expected'); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + } + if (object.tp != null) { + if (typeof object.tp !== 'object') throw TypeError('.onnx.AttributeProto.tp: object expected'); + message.tp = $root.onnx.TypeProto.fromObject(object.tp); + } + if (object.floats) { + if (!Array.isArray(object.floats)) throw TypeError('.onnx.AttributeProto.floats: array expected'); + message.floats = []; + for (var i = 0; i < object.floats.length; ++i) message.floats[i] = Number(object.floats[i]); + } + if (object.ints) { + if (!Array.isArray(object.ints)) throw TypeError('.onnx.AttributeProto.ints: array expected'); + message.ints = []; + for (var i = 0; i < object.ints.length; ++i) + if ($util.Long) (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false; + else if 
(typeof object.ints[i] === 'string') message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === 'number') message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === 'object') + message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + } + if (object.strings) { + if (!Array.isArray(object.strings)) throw TypeError('.onnx.AttributeProto.strings: array expected'); + message.strings = []; + for (var i = 0; i < object.strings.length; ++i) + if (typeof object.strings[i] === 'string') + $util.base64.decode( + object.strings[i], + (message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i]))), + 0, + ); + else if (object.strings[i].length >= 0) message.strings[i] = object.strings[i]; + } + if (object.tensors) { + if (!Array.isArray(object.tensors)) throw TypeError('.onnx.AttributeProto.tensors: array expected'); + message.tensors = []; + for (var i = 0; i < object.tensors.length; ++i) { + if (typeof object.tensors[i] !== 'object') throw TypeError('.onnx.AttributeProto.tensors: object expected'); + message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + } + } + if (object.graphs) { + if (!Array.isArray(object.graphs)) throw TypeError('.onnx.AttributeProto.graphs: array expected'); + message.graphs = []; + for (var i = 0; i < object.graphs.length; ++i) { + if (typeof object.graphs[i] !== 'object') throw TypeError('.onnx.AttributeProto.graphs: object expected'); + message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); + } + } + if (object.sparseTensors) { + if (!Array.isArray(object.sparseTensors)) throw TypeError('.onnx.AttributeProto.sparseTensors: array expected'); + message.sparseTensors = []; + for (var i = 0; i < object.sparseTensors.length; ++i) { + if (typeof object.sparseTensors[i] !== 'object') + throw TypeError('.onnx.AttributeProto.sparseTensors: object expected'); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + } + } + if (object.typeProtos) { + if (!Array.isArray(object.typeProtos)) throw TypeError('.onnx.AttributeProto.typeProtos: array expected'); + message.typeProtos = []; + for (var i = 0; i < object.typeProtos.length; ++i) { + if (typeof object.typeProtos[i] !== 'object') + throw TypeError('.onnx.AttributeProto.typeProtos: object expected'); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + } + } + return message; + }; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.AttributeProto + * @static + * @param {onnx.AttributeProto} message AttributeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + AttributeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.floats = []; + object.ints = []; + object.strings = []; + object.tensors = []; + object.graphs = []; + object.typeProtos = []; + object.sparseTensors = []; + } + if (options.defaults) { + object.name = ''; + object.f = 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.i = options.longs === String ? 
'0' : 0; + if (options.bytes === String) object.s = ''; + else { + object.s = []; + if (options.bytes !== Array) object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ''; + object.tp = null; + object.type = options.enums === String ? 'UNDEFINED' : 0; + object.refAttrName = ''; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.f != null && message.hasOwnProperty('f')) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty('i')) + if (typeof message.i === 'number') object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = + options.longs === String + ? $util.Long.prototype.toString.call(message.i) + : options.longs === Number + ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() + : message.i; + if (message.s != null && message.hasOwnProperty('s')) + object.s = + options.bytes === String + ? $util.base64.encode(message.s, 0, message.s.length) + : options.bytes === Array + ? Array.prototype.slice.call(message.s) + : message.s; + if (message.t != null && message.hasOwnProperty('t')) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty('g')) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = + options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === 'number') + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.ints[j]) + : options.longs === Number + ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() + : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = + options.bytes === String + ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) + : options.bytes === Array + ? Array.prototype.slice.call(message.strings[j]) + : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty('tp')) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty('type')) + object.type = + options.enums === String + ? 
$root.onnx.AttributeProto.AttributeType[message.type] === undefined + ? message.type + : $root.onnx.AttributeProto.AttributeType[message.type] + : message.type; + if (message.refAttrName != null && message.hasOwnProperty('refAttrName')) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty('sparseTensor')) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.AttributeProto'; + }; + + /** + * AttributeType enum. + * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = 'UNDEFINED')] = 0; + values[(valuesById[1] = 'FLOAT')] = 1; + values[(valuesById[2] = 'INT')] = 2; + values[(valuesById[3] = 'STRING')] = 3; + values[(valuesById[4] = 'TENSOR')] = 4; + values[(valuesById[5] = 'GRAPH')] = 5; + values[(valuesById[11] = 'SPARSE_TENSOR')] = 11; + values[(valuesById[13] = 'TYPE_PROTO')] = 13; + values[(valuesById[6] = 'FLOATS')] = 6; + values[(valuesById[7] = 'INTS')] = 7; + values[(valuesById[8] = 'STRINGS')] = 8; + values[(valuesById[9] = 'TENSORS')] = 9; + values[(valuesById[10] = 'GRAPHS')] = 10; + values[(valuesById[12] = 'SPARSE_TENSORS')] = 12; + values[(valuesById[14] = 'TYPE_PROTOS')] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function () { + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. 
+ * @memberof onnx + * @classdesc Represents a ValueInfoProto. + * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ''; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ''; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, 'type')) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. + * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.type != null && message.hasOwnProperty('type')) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) return 'type.' + error; + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== 'object') throw TypeError('.onnx.ValueInfoProto.type: object expected'); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.name = ''; + object.type = null; + object.docString = ''; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.type != null && message.hasOwnProperty('type')) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.ValueInfoProto'; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function () { + /** + * Properties of a NodeProto. + * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ''; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ''; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ''; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. 
+ * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ''; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, 'opType')) + writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode( + message.attribute[i], + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), + ).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.docString); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 7, wireType 2 =*/ 58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. + * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.input != null && message.hasOwnProperty('input')) { + if (!Array.isArray(message.input)) return 'input: array expected'; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) return 'input: string[] expected'; + } + if (message.output != null && message.hasOwnProperty('output')) { + if (!Array.isArray(message.output)) return 'output: array expected'; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) return 'output: string[] expected'; + } + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.opType != null && message.hasOwnProperty('opType')) + if (!$util.isString(message.opType)) return 'opType: string expected'; + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + if (message.attribute != null && message.hasOwnProperty('attribute')) { + if (!Array.isArray(message.attribute)) return 'attribute: array expected'; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) return 'attribute.' + error; + } + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) throw TypeError('.onnx.NodeProto.input: array expected'); + message.input = []; + for (var i = 0; i < object.input.length; ++i) message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) throw TypeError('.onnx.NodeProto.output: array expected'); + message.output = []; + for (var i = 0; i < object.output.length; ++i) message.output[i] = String(object.output[i]); + } + if (object.name != null) message.name = String(object.name); + if (object.opType != null) message.opType = String(object.opType); + if (object.domain != null) message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) throw TypeError('.onnx.NodeProto.attribute: array expected'); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== 'object') throw TypeError('.onnx.NodeProto.attribute: object expected'); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ''; + object.opType = ''; + object.docString = ''; + object.domain = ''; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.opType != null && message.hasOwnProperty('opType')) object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. 
+ * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.NodeProto'; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function () { + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. + * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, 'initialization')) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, 'algorithm')) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.initializationBinding[i], + writer.uint32(/* id 3, wireType 2 =*/ 26).fork(), + ).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.updateBinding[i], + writer.uint32(/* id 4, wireType 2 =*/ 34).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.initialization != null && message.hasOwnProperty('initialization')) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) return 'initialization.' + error; + } + if (message.algorithm != null && message.hasOwnProperty('algorithm')) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) return 'algorithm.' + error; + } + if (message.initializationBinding != null && message.hasOwnProperty('initializationBinding')) { + if (!Array.isArray(message.initializationBinding)) return 'initializationBinding: array expected'; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) return 'initializationBinding.' + error; + } + } + if (message.updateBinding != null && message.hasOwnProperty('updateBinding')) { + if (!Array.isArray(message.updateBinding)) return 'updateBinding: array expected'; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) return 'updateBinding.' + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== 'object') + throw TypeError('.onnx.TrainingInfoProto.initialization: object expected'); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== 'object') throw TypeError('.onnx.TrainingInfoProto.algorithm: object expected'); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError('.onnx.TrainingInfoProto.initializationBinding: array expected'); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== 'object') + throw TypeError('.onnx.TrainingInfoProto.initializationBinding: object expected'); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject( + object.initializationBinding[i], + ); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + throw TypeError('.onnx.TrainingInfoProto.updateBinding: array expected'); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== 'object') + throw TypeError('.onnx.TrainingInfoProto.updateBinding: object expected'); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty('initialization')) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty('algorithm')) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject( + message.initializationBinding[j], + options, + ); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. 
+ * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TrainingInfoProto'; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function () { + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. + * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ''; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ''; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ''; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ''; + + /** + * ModelProto graph. 
+ * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. + * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, 'irVersion')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, 'producerName')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, 'producerVersion')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.domain); + if (message.modelVersion != null && Object.hasOwnProperty.call(message, 'modelVersion')) + writer.uint32(/* id 5, wireType 0 =*/ 40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, 'graph')) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/ 58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode( + message.opsetImport[i], + writer.uint32(/* id 8, wireType 2 =*/ 66).fork(), + ).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.metadataProps[i], + writer.uint32(/* id 14, wireType 2 =*/ 114).fork(), + ).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode( + message.trainingInfo[i], + writer.uint32(/* id 20, wireType 2 =*/ 162).fork(), + ).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode( + 
message.functions[i], + writer.uint32(/* id 25, wireType 2 =*/ 202).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. 
+ * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.irVersion != null && message.hasOwnProperty('irVersion')) + if ( + !$util.isInteger(message.irVersion) && + !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high)) + ) + return 'irVersion: integer|Long expected'; + if (message.opsetImport != null && message.hasOwnProperty('opsetImport')) { + if (!Array.isArray(message.opsetImport)) return 'opsetImport: array expected'; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) return 'opsetImport.' + error; + } + } + if (message.producerName != null && message.hasOwnProperty('producerName')) + if (!$util.isString(message.producerName)) return 'producerName: string expected'; + if (message.producerVersion != null && message.hasOwnProperty('producerVersion')) + if (!$util.isString(message.producerVersion)) return 'producerVersion: string expected'; + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + if (message.modelVersion != null && message.hasOwnProperty('modelVersion')) + if ( + !$util.isInteger(message.modelVersion) && + !( + message.modelVersion && + $util.isInteger(message.modelVersion.low) && + $util.isInteger(message.modelVersion.high) + ) + ) + return 'modelVersion: integer|Long expected'; + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.graph != null && message.hasOwnProperty('graph')) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) return 'graph.' + error; + } + if (message.metadataProps != null && message.hasOwnProperty('metadataProps')) { + if (!Array.isArray(message.metadataProps)) return 'metadataProps: array expected'; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) return 'metadataProps.' + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty('trainingInfo')) { + if (!Array.isArray(message.trainingInfo)) return 'trainingInfo: array expected'; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) return 'trainingInfo.' + error; + } + } + if (message.functions != null && message.hasOwnProperty('functions')) { + if (!Array.isArray(message.functions)) return 'functions: array expected'; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) return 'functions.' + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === 'string') message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === 'number') message.irVersion = object.irVersion; + else if (typeof object.irVersion === 'object') + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) throw TypeError('.onnx.ModelProto.opsetImport: array expected'); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== 'object') + throw TypeError('.onnx.ModelProto.opsetImport: object expected'); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) message.producerName = String(object.producerName); + if (object.producerVersion != null) message.producerVersion = String(object.producerVersion); + if (object.domain != null) message.domain = String(object.domain); + if (object.modelVersion != null) + if ($util.Long) (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === 'string') message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === 'number') message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === 'object') + message.modelVersion = new $util.LongBits( + object.modelVersion.low >>> 0, + object.modelVersion.high >>> 0, + ).toNumber(); + if (object.docString != null) message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== 'object') throw TypeError('.onnx.ModelProto.graph: object expected'); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) throw TypeError('.onnx.ModelProto.metadataProps: array expected'); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== 'object') + throw TypeError('.onnx.ModelProto.metadataProps: object expected'); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) throw TypeError('.onnx.ModelProto.trainingInfo: array expected'); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== 'object') + throw TypeError('.onnx.ModelProto.trainingInfo: object expected'); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) throw TypeError('.onnx.ModelProto.functions: array expected'); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== 'object') throw TypeError('.onnx.ModelProto.functions: object expected'); + message.functions[i] = 
$root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.irVersion = options.longs === String ? '0' : 0; + object.producerName = ''; + object.producerVersion = ''; + object.domain = ''; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.modelVersion = options.longs === String ? '0' : 0; + object.docString = ''; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty('irVersion')) + if (typeof message.irVersion === 'number') + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = + options.longs === String + ? $util.Long.prototype.toString.call(message.irVersion) + : options.longs === Number + ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() + : message.irVersion; + if (message.producerName != null && message.hasOwnProperty('producerName')) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty('producerVersion')) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty('modelVersion')) + if (typeof message.modelVersion === 'number') + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = + options.longs === String + ? $util.Long.prototype.toString.call(message.modelVersion) + : options.longs === Number + ? 
new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() + : message.modelVersion; + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty('graph')) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.ModelProto'; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function () { + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ''; + + /** + * StringStringEntryProto value. + * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ''; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. 
+ * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, 'key')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, 'value')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a StringStringEntryProto message. + * @function verify + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + StringStringEntryProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.key != null && message.hasOwnProperty('key')) + if (!$util.isString(message.key)) return 'key: string expected'; + if (message.value != null && message.hasOwnProperty('value')) + if (!$util.isString(message.value)) return 'value: string expected'; + return null; + }; + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + */ + StringStringEntryProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.StringStringEntryProto) return object; + var message = new $root.onnx.StringStringEntryProto(); + if (object.key != null) message.key = String(object.key); + if (object.value != null) message.value = String(object.value); + return message; + }; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.StringStringEntryProto} message StringStringEntryProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + StringStringEntryProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.key = ''; + object.value = ''; + } + if (message.key != null && message.hasOwnProperty('key')) object.key = message.key; + if (message.value != null && message.hasOwnProperty('value')) object.value = message.value; + return object; + }; + + /** + * Converts this StringStringEntryProto to JSON. 
+ * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.StringStringEntryProto'; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function () { + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ''; + + /** + * TensorAnnotation quantParameterTensorNames. + * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, 'tensorName')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.quantParameterTensorNames[i], + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.tensorName != null && message.hasOwnProperty('tensorName')) + if (!$util.isString(message.tensorName)) return 'tensorName: string expected'; + if (message.quantParameterTensorNames != null && message.hasOwnProperty('quantParameterTensorNames')) { + if (!Array.isArray(message.quantParameterTensorNames)) return 'quantParameterTensorNames: array expected'; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) return 'quantParameterTensorNames.' + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError('.onnx.TensorAnnotation.quantParameterTensorNames: array expected'); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== 'object') + throw TypeError('.onnx.TensorAnnotation.quantParameterTensorNames: object expected'); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject( + object.quantParameterTensorNames[i], + ); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) object.quantParameterTensorNames = []; + if (options.defaults) object.tensorName = ''; + if (message.tensorName != null && message.hasOwnProperty('tensorName')) object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject( + message.quantParameterTensorNames[j], + options, + ); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorAnnotation'; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function () { + /** + * Properties of a GraphProto. 
+ * @memberof onnx + * @interface IGraphProto + * @property {Array.|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.|null} [initializer] GraphProto initializer + * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.|null} [input] GraphProto input + * @property {Array.|null} [output] GraphProto output + * @property {Array.|null} [valueInfo] GraphProto valueInfo + * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ''; + + /** + * GraphProto initializer. + * @member {Array.} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString. + * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ''; + + /** + * GraphProto input. + * @member {Array.} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; -$root.onnx = (function() { + /** + * GraphProto quantizationAnnotation. + * @member {Array.} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode( + message.initializer[i], + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), + ).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 10, wireType 2 =*/ 82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode( + message.input[i], + writer.uint32(/* id 11, wireType 2 =*/ 90).fork(), + ).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode( + message.output[i], + writer.uint32(/* id 12, wireType 2 =*/ 98).fork(), + ).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode( + message.valueInfo[i], + writer.uint32(/* id 13, wireType 2 =*/ 106).fork(), + ).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode( + message.quantizationAnnotation[i], + writer.uint32(/* id 14, wireType 2 =*/ 114).fork(), + ).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode( + message.sparseInitializer[i], + writer.uint32(/* id 15, wireType 2 =*/ 122).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.node != null && message.hasOwnProperty('node')) { + if (!Array.isArray(message.node)) return 'node: array expected'; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) return 'node.' + error; + } + } + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.initializer != null && message.hasOwnProperty('initializer')) { + if (!Array.isArray(message.initializer)) return 'initializer: array expected'; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) return 'initializer.' 
+ error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty('sparseInitializer')) { + if (!Array.isArray(message.sparseInitializer)) return 'sparseInitializer: array expected'; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) return 'sparseInitializer.' + error; + } + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.input != null && message.hasOwnProperty('input')) { + if (!Array.isArray(message.input)) return 'input: array expected'; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) return 'input.' + error; + } + } + if (message.output != null && message.hasOwnProperty('output')) { + if (!Array.isArray(message.output)) return 'output: array expected'; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) return 'output.' + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty('valueInfo')) { + if (!Array.isArray(message.valueInfo)) return 'valueInfo: array expected'; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) return 'valueInfo.' + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty('quantizationAnnotation')) { + if (!Array.isArray(message.quantizationAnnotation)) return 'quantizationAnnotation: array expected'; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) return 'quantizationAnnotation.' + error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) throw TypeError('.onnx.GraphProto.node: array expected'); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== 'object') throw TypeError('.onnx.GraphProto.node: object expected'); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) throw TypeError('.onnx.GraphProto.initializer: array expected'); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== 'object') + throw TypeError('.onnx.GraphProto.initializer: object expected'); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError('.onnx.GraphProto.sparseInitializer: array expected'); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== 'object') + throw TypeError('.onnx.GraphProto.sparseInitializer: object expected'); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) throw TypeError('.onnx.GraphProto.input: array expected'); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== 'object') throw TypeError('.onnx.GraphProto.input: object expected'); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) throw TypeError('.onnx.GraphProto.output: array expected'); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== 'object') throw TypeError('.onnx.GraphProto.output: object expected'); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) throw TypeError('.onnx.GraphProto.valueInfo: array expected'); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== 'object') throw TypeError('.onnx.GraphProto.valueInfo: object expected'); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError('.onnx.GraphProto.quantizationAnnotation: array expected'); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== 'object') + throw TypeError('.onnx.GraphProto.quantizationAnnotation: object expected'); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto message. 
Also converts values to other types if specified. + * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ''; + object.docString = ''; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject( + message.quantizationAnnotation[j], + options, + ); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.GraphProto'; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function () { + /** + * Properties of a TensorProto. 
+ * @memberof onnx + * @interface ITensorProto + * @property {Array.|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.|null} [floatData] TensorProto floatData + * @property {Array.|null} [int32Data] TensorProto int32Data + * @property {Array.|null} [stringData] TensorProto stringData + * @property {Array.|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.|null} [doubleData] TensorProto doubleData + * @property {Array.|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment. + * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ''; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ''; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. 
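+ * Indicates whether the tensor's data is stored inline in this message (DEFAULT) or in an external file referenced through externalData (EXTERNAL).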
+ * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data. + * @member {Array.} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(); + for (var i = 0; i < message.dims.length; ++i) writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, 'dataType')) + writer.uint32(/* id 2, wireType 0 =*/ 16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, 'segment')) + $root.onnx.TensorProto.Segment.encode( + message.segment, + writer.uint32(/* id 3, wireType 2 =*/ 26).fork(), + ).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/ 34).fork(); + for (var i = 0; i < message.floatData.length; ++i) writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if (message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/ 50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/ 58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 8, wireType 2 =*/ 66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, 'rawData')) + writer.uint32(/* id 9, wireType 2 =*/ 74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/ 82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/ 90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != 
null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 12, wireType 2 =*/ 98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.externalData[i], + writer.uint32(/* id 13, wireType 2 =*/ 106).fork(), + ).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, 'dataLocation')) + writer.uint32(/* id 14, wireType 0 =*/ 112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.dims.push(reader.int64()); + } else message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.floatData.push(reader.float()); + } else message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.int32Data.push(reader.int32()); + } else message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.int64Data.push(reader.int64()); + } else message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } + 
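+ // Each loop iteration reads a varint tag: `tag >>> 3` is the field number and `tag & 7` the wire type,
+ // e.g. rawData (field 9, wire type 2) is written above as (9 << 3) | 2 === 74. Packed repeated numeric
+ // fields arrive as one length-delimited payload (wire type 2) and are read until `end2`; unpacked
+ // encodings read a single value per tag.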
case 13: { + if (!(message.externalData && message.externalData.length)) message.externalData = []; + message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.doubleData.push(reader.double()); + } else message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.uint64Data.push(reader.uint64()); + } else message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. + * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.dims != null && message.hasOwnProperty('dims')) { + if (!Array.isArray(message.dims)) return 'dims: array expected'; + for (var i = 0; i < message.dims.length; ++i) + if ( + !$util.isInteger(message.dims[i]) && + !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high)) + ) + return 'dims: integer|Long[] expected'; + } + if (message.dataType != null && message.hasOwnProperty('dataType')) + if (!$util.isInteger(message.dataType)) return 'dataType: integer expected'; + if (message.segment != null && message.hasOwnProperty('segment')) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) return 'segment.' 
+ error; + } + if (message.floatData != null && message.hasOwnProperty('floatData')) { + if (!Array.isArray(message.floatData)) return 'floatData: array expected'; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== 'number') return 'floatData: number[] expected'; + } + if (message.int32Data != null && message.hasOwnProperty('int32Data')) { + if (!Array.isArray(message.int32Data)) return 'int32Data: array expected'; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) return 'int32Data: integer[] expected'; + } + if (message.stringData != null && message.hasOwnProperty('stringData')) { + if (!Array.isArray(message.stringData)) return 'stringData: array expected'; + for (var i = 0; i < message.stringData.length; ++i) + if ( + !( + (message.stringData[i] && typeof message.stringData[i].length === 'number') || + $util.isString(message.stringData[i]) + ) + ) + return 'stringData: buffer[] expected'; + } + if (message.int64Data != null && message.hasOwnProperty('int64Data')) { + if (!Array.isArray(message.int64Data)) return 'int64Data: array expected'; + for (var i = 0; i < message.int64Data.length; ++i) + if ( + !$util.isInteger(message.int64Data[i]) && + !( + message.int64Data[i] && + $util.isInteger(message.int64Data[i].low) && + $util.isInteger(message.int64Data[i].high) + ) + ) + return 'int64Data: integer|Long[] expected'; + } + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.rawData != null && message.hasOwnProperty('rawData')) + if (!((message.rawData && typeof message.rawData.length === 'number') || $util.isString(message.rawData))) + return 'rawData: buffer expected'; + if (message.externalData != null && message.hasOwnProperty('externalData')) { + if (!Array.isArray(message.externalData)) return 'externalData: array expected'; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) return 'externalData.' + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty('dataLocation')) + switch (message.dataLocation) { + default: + return 'dataLocation: enum value expected'; + case 0: + case 1: + break; + } + if (message.doubleData != null && message.hasOwnProperty('doubleData')) { + if (!Array.isArray(message.doubleData)) return 'doubleData: array expected'; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== 'number') return 'doubleData: number[] expected'; + } + if (message.uint64Data != null && message.hasOwnProperty('uint64Data')) { + if (!Array.isArray(message.uint64Data)) return 'uint64Data: array expected'; + for (var i = 0; i < message.uint64Data.length; ++i) + if ( + !$util.isInteger(message.uint64Data[i]) && + !( + message.uint64Data[i] && + $util.isInteger(message.uint64Data[i].low) && + $util.isInteger(message.uint64Data[i].high) + ) + ) + return 'uint64Data: integer|Long[] expected'; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. 
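+ * Long-valued fields (dims, int64Data, uint64Data) may be given as numbers, decimal strings, or
+ * {low, high} objects; bytes fields (stringData, rawData) may be given as base64 strings or buffers.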
+ * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto} TensorProto + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) throw TypeError('.onnx.TensorProto.dims: array expected'); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === 'string') message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === 'number') message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === 'object') + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== 'object') throw TypeError('.onnx.TensorProto.segment: object expected'); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) throw TypeError('.onnx.TensorProto.floatData: array expected'); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) throw TypeError('.onnx.TensorProto.int32Data: array expected'); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) throw TypeError('.onnx.TensorProto.stringData: array expected'); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof object.stringData[i] === 'string') + $util.base64.decode( + object.stringData[i], + (message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i]))), + 0, + ); + else if (object.stringData[i].length >= 0) message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) throw TypeError('.onnx.TensorProto.int64Data: array expected'); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === 'string') message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === 'number') message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === 'object') + message.int64Data[i] = new $util.LongBits( + object.int64Data[i].low >>> 0, + object.int64Data[i].high >>> 0, + ).toNumber(); + } + if (object.name != null) message.name = String(object.name); + if (object.docString != null) message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === 'string') + $util.base64.decode( + object.rawData, + (message.rawData = $util.newBuffer($util.base64.length(object.rawData))), + 0, + ); + else if (object.rawData.length >= 0) message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) throw TypeError('.onnx.TensorProto.externalData: array expected'); + message.externalData = []; + for (var i = 0; i < 
object.externalData.length; ++i) { + if (typeof object.externalData[i] !== 'object') + throw TypeError('.onnx.TensorProto.externalData: object expected'); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: + if (typeof object.dataLocation === 'number') { + message.dataLocation = object.dataLocation; + break; + } + break; + case 'DEFAULT': + case 0: + message.dataLocation = 0; + break; + case 'EXTERNAL': + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) throw TypeError('.onnx.TensorProto.doubleData: array expected'); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) throw TypeError('.onnx.TensorProto.uint64Data: array expected'); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === 'string') message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === 'number') message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === 'object') + message.uint64Data[i] = new $util.LongBits( + object.uint64Data[i].low >>> 0, + object.uint64Data[i].high >>> 0, + ).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ''; + if (options.bytes === String) object.rawData = ''; + else { + object.rawData = []; + if (options.bytes !== Array) object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ''; + object.dataLocation = options.enums === String ? 'DEFAULT' : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === 'number') + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.dims[j]) + : options.longs === Number + ? 
new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() + : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty('dataType')) object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty('segment')) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = + options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) + object.stringData[j] = + options.bytes === String + ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) + : options.bytes === Array + ? Array.prototype.slice.call(message.stringData[j]) + : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === 'number') + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.int64Data[j]) + : options.longs === Number + ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() + : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.rawData != null && message.hasOwnProperty('rawData')) + object.rawData = + options.bytes === String + ? $util.base64.encode(message.rawData, 0, message.rawData.length) + : options.bytes === Array + ? Array.prototype.slice.call(message.rawData) + : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = + options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === 'number') + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.uint64Data[j]) + : options.longs === Number + ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) + : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty('dataLocation')) + object.dataLocation = + options.enums === String + ? 
$root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined + ? message.dataLocation + : $root.onnx.TensorProto.DataLocation[message.dataLocation] + : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. + * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; /** - * Namespace onnx. - * @exports onnx - * @namespace + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url */ - var onnx = {}; + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorProto'; + }; /** - * Version enum. - * @name onnx.Version + * DataType enum. + * @name onnx.TensorProto.DataType * @enum {number} - * @property {number} _START_VERSION=0 _START_VERSION value - * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value - * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value - * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value - * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value - * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value - * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value - * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value - * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value - * @property {number} IR_VERSION=9 IR_VERSION value - */ - onnx.Version = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "_START_VERSION"] = 0; - values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; - values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; - values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; - values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; - values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; - values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; - values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; - values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; - values[valuesById[9] = "IR_VERSION"] = 9; - return values; + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function () { + var valuesById = {}, + 
values = Object.create(valuesById); + values[(valuesById[0] = 'UNDEFINED')] = 0; + values[(valuesById[1] = 'FLOAT')] = 1; + values[(valuesById[2] = 'UINT8')] = 2; + values[(valuesById[3] = 'INT8')] = 3; + values[(valuesById[4] = 'UINT16')] = 4; + values[(valuesById[5] = 'INT16')] = 5; + values[(valuesById[6] = 'INT32')] = 6; + values[(valuesById[7] = 'INT64')] = 7; + values[(valuesById[8] = 'STRING')] = 8; + values[(valuesById[9] = 'BOOL')] = 9; + values[(valuesById[10] = 'FLOAT16')] = 10; + values[(valuesById[11] = 'DOUBLE')] = 11; + values[(valuesById[12] = 'UINT32')] = 12; + values[(valuesById[13] = 'UINT64')] = 13; + values[(valuesById[14] = 'COMPLEX64')] = 14; + values[(valuesById[15] = 'COMPLEX128')] = 15; + values[(valuesById[16] = 'BFLOAT16')] = 16; + values[(valuesById[17] = 'FLOAT8E4M3FN')] = 17; + values[(valuesById[18] = 'FLOAT8E4M3FNUZ')] = 18; + values[(valuesById[19] = 'FLOAT8E5M2')] = 19; + values[(valuesById[20] = 'FLOAT8E5M2FNUZ')] = 20; + return values; })(); - onnx.AttributeProto = (function() { - - /** - * Properties of an AttributeProto. - * @memberof onnx - * @interface IAttributeProto - * @property {string|null} [name] AttributeProto name - * @property {string|null} [refAttrName] AttributeProto refAttrName - * @property {string|null} [docString] AttributeProto docString - * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type - * @property {number|null} [f] AttributeProto f - * @property {number|Long|null} [i] AttributeProto i - * @property {Uint8Array|null} [s] AttributeProto s - * @property {onnx.ITensorProto|null} [t] AttributeProto t - * @property {onnx.IGraphProto|null} [g] AttributeProto g - * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor - * @property {onnx.ITypeProto|null} [tp] AttributeProto tp - * @property {Array.|null} [floats] AttributeProto floats - * @property {Array.|null} [ints] AttributeProto ints - * @property {Array.|null} [strings] AttributeProto strings - * @property {Array.|null} [tensors] AttributeProto tensors - * @property {Array.|null} [graphs] AttributeProto graphs - * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors - * @property {Array.|null} [typeProtos] AttributeProto typeProtos - */ - - /** - * Constructs a new AttributeProto. - * @memberof onnx - * @classdesc Represents an AttributeProto. - * @implements IAttributeProto - * @constructor - * @param {onnx.IAttributeProto=} [properties] Properties to set - */ - function AttributeProto(properties) { - this.floats = []; - this.ints = []; - this.strings = []; - this.tensors = []; - this.graphs = []; - this.sparseTensors = []; - this.typeProtos = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } - - /** - * AttributeProto name. - * @member {string} name - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.name = ""; - - /** - * AttributeProto refAttrName. - * @member {string} refAttrName - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.refAttrName = ""; - - /** - * AttributeProto docString. - * @member {string} docString - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.docString = ""; - - /** - * AttributeProto type. - * @member {onnx.AttributeProto.AttributeType} type - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.type = 0; - - /** - * AttributeProto f. 
- * @member {number} f - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.f = 0; - - /** - * AttributeProto i. - * @member {number|Long} i - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * AttributeProto s. - * @member {Uint8Array} s - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.s = $util.newBuffer([]); - - /** - * AttributeProto t. - * @member {onnx.ITensorProto|null|undefined} t - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.t = null; - - /** - * AttributeProto g. - * @member {onnx.IGraphProto|null|undefined} g - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.g = null; - - /** - * AttributeProto sparseTensor. - * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.sparseTensor = null; - - /** - * AttributeProto tp. - * @member {onnx.ITypeProto|null|undefined} tp - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.tp = null; - - /** - * AttributeProto floats. - * @member {Array.} floats - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.floats = $util.emptyArray; - - /** - * AttributeProto ints. - * @member {Array.} ints - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.ints = $util.emptyArray; - - /** - * AttributeProto strings. - * @member {Array.} strings - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.strings = $util.emptyArray; - - /** - * AttributeProto tensors. - * @member {Array.} tensors - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.tensors = $util.emptyArray; - - /** - * AttributeProto graphs. - * @member {Array.} graphs - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.graphs = $util.emptyArray; - - /** - * AttributeProto sparseTensors. - * @member {Array.} sparseTensors - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.sparseTensors = $util.emptyArray; - - /** - * AttributeProto typeProtos. - * @member {Array.} typeProtos - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.typeProtos = $util.emptyArray; - - /** - * Creates a new AttributeProto instance using the specified properties. - * @function create - * @memberof onnx.AttributeProto - * @static - * @param {onnx.IAttributeProto=} [properties] Properties to set - * @returns {onnx.AttributeProto} AttributeProto instance - */ - AttributeProto.create = function create(properties) { - return new AttributeProto(properties); - }; - - /** - * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.AttributeProto - * @static - * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - AttributeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); - if (message.f != null && Object.hasOwnProperty.call(message, "f")) - writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); - if (message.i != null && Object.hasOwnProperty.call(message, "i")) - writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); - if (message.s != null && Object.hasOwnProperty.call(message, "s")) - writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); - if (message.t != null && Object.hasOwnProperty.call(message, "t")) - $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.g != null && Object.hasOwnProperty.call(message, "g")) - $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); - if (message.floats != null && message.floats.length) { - writer.uint32(/* id 7, wireType 2 =*/58).fork(); - for (var i = 0; i < message.floats.length; ++i) - writer.float(message.floats[i]); - writer.ldelim(); - } - if (message.ints != null && message.ints.length) { - writer.uint32(/* id 8, wireType 2 =*/66).fork(); - for (var i = 0; i < message.ints.length; ++i) - writer.int64(message.ints[i]); - writer.ldelim(); - } - if (message.strings != null && message.strings.length) - for (var i = 0; i < message.strings.length; ++i) - writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); - if (message.tensors != null && message.tensors.length) - for (var i = 0; i < message.tensors.length; ++i) - $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); - if (message.graphs != null && message.graphs.length) - for (var i = 0; i < message.graphs.length; ++i) - $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); - if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) - $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); - if (message.typeProtos != null && message.typeProtos.length) - for (var i = 0; i < message.typeProtos.length; ++i) - $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); - if (message.type != null && Object.hasOwnProperty.call(message, "type")) - writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); - if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) - writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); - if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) - $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); - if (message.sparseTensors != null && message.sparseTensors.length) - for (var i = 0; i < message.sparseTensors.length; ++i) - $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], 
writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.AttributeProto - * @static - * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes an AttributeProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.AttributeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.AttributeProto} AttributeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - AttributeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.name = reader.string(); - break; - } - case 21: { - message.refAttrName = reader.string(); - break; - } - case 13: { - message.docString = reader.string(); - break; - } - case 20: { - message.type = reader.int32(); - break; - } - case 2: { - message.f = reader.float(); - break; - } - case 3: { - message.i = reader.int64(); - break; - } - case 4: { - message.s = reader.bytes(); - break; - } - case 5: { - message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); - break; - } - case 6: { - message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 22: { - message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); - break; - } - case 14: { - message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - case 7: { - if (!(message.floats && message.floats.length)) - message.floats = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.floats.push(reader.float()); - } else - message.floats.push(reader.float()); - break; - } - case 8: { - if (!(message.ints && message.ints.length)) - message.ints = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.ints.push(reader.int64()); - } else - message.ints.push(reader.int64()); - break; - } - case 9: { - if (!(message.strings && message.strings.length)) - message.strings = []; - message.strings.push(reader.bytes()); - break; - } - case 10: { - if (!(message.tensors && message.tensors.length)) - message.tensors = []; - message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); - break; - } - case 11: { - if (!(message.graphs && message.graphs.length)) - message.graphs = []; - message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); - break; - } - case 23: { - if (!(message.sparseTensors && message.sparseTensors.length)) - message.sparseTensors = []; - message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); - break; - } - case 15: { - if 
(!(message.typeProtos && message.typeProtos.length)) - message.typeProtos = []; - message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes an AttributeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.AttributeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.AttributeProto} AttributeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - AttributeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies an AttributeProto message. - * @function verify - * @memberof onnx.AttributeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - AttributeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) - if (!$util.isString(message.refAttrName)) - return "refAttrName: string expected"; - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.type != null && message.hasOwnProperty("type")) - switch (message.type) { - default: - return "type: enum value expected"; - case 0: - case 1: - case 2: - case 3: - case 4: - case 5: - case 11: - case 13: - case 6: - case 7: - case 8: - case 9: - case 10: - case 12: - case 14: - break; - } - if (message.f != null && message.hasOwnProperty("f")) - if (typeof message.f !== "number") - return "f: number expected"; - if (message.i != null && message.hasOwnProperty("i")) - if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high))) - return "i: integer|Long expected"; - if (message.s != null && message.hasOwnProperty("s")) - if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s))) - return "s: buffer expected"; - if (message.t != null && message.hasOwnProperty("t")) { - var error = $root.onnx.TensorProto.verify(message.t); - if (error) - return "t." + error; - } - if (message.g != null && message.hasOwnProperty("g")) { - var error = $root.onnx.GraphProto.verify(message.g); - if (error) - return "g." + error; - } - if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); - if (error) - return "sparseTensor." + error; - } - if (message.tp != null && message.hasOwnProperty("tp")) { - var error = $root.onnx.TypeProto.verify(message.tp); - if (error) - return "tp." 
+ error; - } - if (message.floats != null && message.hasOwnProperty("floats")) { - if (!Array.isArray(message.floats)) - return "floats: array expected"; - for (var i = 0; i < message.floats.length; ++i) - if (typeof message.floats[i] !== "number") - return "floats: number[] expected"; - } - if (message.ints != null && message.hasOwnProperty("ints")) { - if (!Array.isArray(message.ints)) - return "ints: array expected"; - for (var i = 0; i < message.ints.length; ++i) - if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) - return "ints: integer|Long[] expected"; - } - if (message.strings != null && message.hasOwnProperty("strings")) { - if (!Array.isArray(message.strings)) - return "strings: array expected"; - for (var i = 0; i < message.strings.length; ++i) - if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) - return "strings: buffer[] expected"; - } - if (message.tensors != null && message.hasOwnProperty("tensors")) { - if (!Array.isArray(message.tensors)) - return "tensors: array expected"; - for (var i = 0; i < message.tensors.length; ++i) { - var error = $root.onnx.TensorProto.verify(message.tensors[i]); - if (error) - return "tensors." + error; - } - } - if (message.graphs != null && message.hasOwnProperty("graphs")) { - if (!Array.isArray(message.graphs)) - return "graphs: array expected"; - for (var i = 0; i < message.graphs.length; ++i) { - var error = $root.onnx.GraphProto.verify(message.graphs[i]); - if (error) - return "graphs." + error; - } - } - if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { - if (!Array.isArray(message.sparseTensors)) - return "sparseTensors: array expected"; - for (var i = 0; i < message.sparseTensors.length; ++i) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); - if (error) - return "sparseTensors." + error; - } - } - if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { - if (!Array.isArray(message.typeProtos)) - return "typeProtos: array expected"; - for (var i = 0; i < message.typeProtos.length; ++i) { - var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); - if (error) - return "typeProtos." + error; - } + TensorProto.Segment = (function () { + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. + * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * Creates a new Segment instance using the specified properties. 
+ * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, 'begin')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, 'end')) + writer.uint32(/* id 2, wireType 0 =*/ 16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; } - return null; - }; - - /** - * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.AttributeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.AttributeProto} AttributeProto - */ - AttributeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.AttributeProto) - return object; - var message = new $root.onnx.AttributeProto(); - if (object.name != null) - message.name = String(object.name); - if (object.refAttrName != null) - message.refAttrName = String(object.refAttrName); - if (object.docString != null) - message.docString = String(object.docString); - switch (object.type) { default: - if (typeof object.type === "number") { - message.type = object.type; - break; - } - break; - case "UNDEFINED": - case 0: - message.type = 0; - break; - case "FLOAT": - case 1: - message.type = 1; - break; - case "INT": - case 2: - message.type = 2; - break; - case "STRING": - case 3: - message.type = 3; - break; - case "TENSOR": - case 4: - message.type = 4; - break; - case "GRAPH": - case 5: - message.type = 5; - break; - case "SPARSE_TENSOR": - case 11: - message.type = 11; - break; - case "TYPE_PROTO": - case 13: - message.type = 13; - break; - case "FLOATS": - case 6: - message.type = 6; - break; - case "INTS": - case 7: - message.type = 7; - break; - case "STRINGS": - case 8: - message.type = 8; - break; - case "TENSORS": - case 9: - message.type = 9; - break; - case "GRAPHS": - case 10: - message.type = 10; - break; - case "SPARSE_TENSORS": - case 12: - message.type = 12; - break; - case "TYPE_PROTOS": - case 14: - message.type = 14; - break; - } - if (object.f != null) - message.f = Number(object.f); - if (object.i != null) - if ($util.Long) - (message.i = $util.Long.fromValue(object.i)).unsigned = false; - else if (typeof object.i === "string") - message.i = parseInt(object.i, 10); - else if (typeof object.i === "number") - message.i = object.i; - else if (typeof object.i === "object") - message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); - if (object.s != null) - if (typeof object.s === "string") - $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0); - else if (object.s.length >= 0) - message.s = object.s; - if (object.t != null) { - if (typeof object.t !== "object") - throw TypeError(".onnx.AttributeProto.t: object expected"); - message.t = $root.onnx.TensorProto.fromObject(object.t); - } - if (object.g != null) { - if (typeof object.g !== "object") - throw TypeError(".onnx.AttributeProto.g: object expected"); - message.g = $root.onnx.GraphProto.fromObject(object.g); - } - if (object.sparseTensor != null) { - if (typeof object.sparseTensor !== "object") - throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); - message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); - } - if (object.tp != null) { - if (typeof object.tp !== "object") - throw TypeError(".onnx.AttributeProto.tp: object expected"); - message.tp = $root.onnx.TypeProto.fromObject(object.tp); - } - if (object.floats) { - if (!Array.isArray(object.floats)) - throw TypeError(".onnx.AttributeProto.floats: array expected"); - message.floats = []; - for (var i = 0; i < object.floats.length; ++i) - message.floats[i] = Number(object.floats[i]); - } - if (object.ints) { - if (!Array.isArray(object.ints)) - throw TypeError(".onnx.AttributeProto.ints: array expected"); - message.ints = []; - for (var i = 0; i < object.ints.length; ++i) - if ($util.Long) - (message.ints[i] = 
$util.Long.fromValue(object.ints[i])).unsigned = false; - else if (typeof object.ints[i] === "string") - message.ints[i] = parseInt(object.ints[i], 10); - else if (typeof object.ints[i] === "number") - message.ints[i] = object.ints[i]; - else if (typeof object.ints[i] === "object") - message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); - } - if (object.strings) { - if (!Array.isArray(object.strings)) - throw TypeError(".onnx.AttributeProto.strings: array expected"); - message.strings = []; - for (var i = 0; i < object.strings.length; ++i) - if (typeof object.strings[i] === "string") - $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0); - else if (object.strings[i].length >= 0) - message.strings[i] = object.strings[i]; - } - if (object.tensors) { - if (!Array.isArray(object.tensors)) - throw TypeError(".onnx.AttributeProto.tensors: array expected"); - message.tensors = []; - for (var i = 0; i < object.tensors.length; ++i) { - if (typeof object.tensors[i] !== "object") - throw TypeError(".onnx.AttributeProto.tensors: object expected"); - message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); - } - } - if (object.graphs) { - if (!Array.isArray(object.graphs)) - throw TypeError(".onnx.AttributeProto.graphs: array expected"); - message.graphs = []; - for (var i = 0; i < object.graphs.length; ++i) { - if (typeof object.graphs[i] !== "object") - throw TypeError(".onnx.AttributeProto.graphs: object expected"); - message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); - } - } - if (object.sparseTensors) { - if (!Array.isArray(object.sparseTensors)) - throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); - message.sparseTensors = []; - for (var i = 0; i < object.sparseTensors.length; ++i) { - if (typeof object.sparseTensors[i] !== "object") - throw TypeError(".onnx.AttributeProto.sparseTensors: object expected"); - message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); - } - } - if (object.typeProtos) { - if (!Array.isArray(object.typeProtos)) - throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); - message.typeProtos = []; - for (var i = 0; i < object.typeProtos.length; ++i) { - if (typeof object.typeProtos[i] !== "object") - throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); - message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); - } - } - return message; - }; - - /** - * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.AttributeProto - * @static - * @param {onnx.AttributeProto} message AttributeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - AttributeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.floats = []; - object.ints = []; - object.strings = []; - object.tensors = []; - object.graphs = []; - object.typeProtos = []; - object.sparseTensors = []; - } - if (options.defaults) { - object.name = ""; - object.f = 0; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.i = options.longs === String ? 
"0" : 0; - if (options.bytes === String) - object.s = ""; - else { - object.s = []; - if (options.bytes !== Array) - object.s = $util.newBuffer(object.s); - } - object.t = null; - object.g = null; - object.docString = ""; - object.tp = null; - object.type = options.enums === String ? "UNDEFINED" : 0; - object.refAttrName = ""; - object.sparseTensor = null; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.f != null && message.hasOwnProperty("f")) - object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; - if (message.i != null && message.hasOwnProperty("i")) - if (typeof message.i === "number") - object.i = options.longs === String ? String(message.i) : message.i; - else - object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; - if (message.s != null && message.hasOwnProperty("s")) - object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; - if (message.t != null && message.hasOwnProperty("t")) - object.t = $root.onnx.TensorProto.toObject(message.t, options); - if (message.g != null && message.hasOwnProperty("g")) - object.g = $root.onnx.GraphProto.toObject(message.g, options); - if (message.floats && message.floats.length) { - object.floats = []; - for (var j = 0; j < message.floats.length; ++j) - object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; - } - if (message.ints && message.ints.length) { - object.ints = []; - for (var j = 0; j < message.ints.length; ++j) - if (typeof message.ints[j] === "number") - object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; - else - object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; - } - if (message.strings && message.strings.length) { - object.strings = []; - for (var j = 0; j < message.strings.length; ++j) - object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; - } - if (message.tensors && message.tensors.length) { - object.tensors = []; - for (var j = 0; j < message.tensors.length; ++j) - object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); - } - if (message.graphs && message.graphs.length) { - object.graphs = []; - for (var j = 0; j < message.graphs.length; ++j) - object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.tp != null && message.hasOwnProperty("tp")) - object.tp = $root.onnx.TypeProto.toObject(message.tp, options); - if (message.typeProtos && message.typeProtos.length) { - object.typeProtos = []; - for (var j = 0; j < message.typeProtos.length; ++j) - object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); - } - if (message.type != null && message.hasOwnProperty("type")) - object.type = options.enums === String ? 
$root.onnx.AttributeProto.AttributeType[message.type] === undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; - if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) - object.refAttrName = message.refAttrName; - if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) - object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); - if (message.sparseTensors && message.sparseTensors.length) { - object.sparseTensors = []; - for (var j = 0; j < message.sparseTensors.length; ++j) - object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); - } - return object; - }; - - /** - * Converts this AttributeProto to JSON. - * @function toJSON - * @memberof onnx.AttributeProto - * @instance - * @returns {Object.} JSON object - */ - AttributeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for AttributeProto - * @function getTypeUrl - * @memberof onnx.AttributeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.AttributeProto"; - }; - - /** - * AttributeType enum. - * @name onnx.AttributeProto.AttributeType - * @enum {number} - * @property {number} UNDEFINED=0 UNDEFINED value - * @property {number} FLOAT=1 FLOAT value - * @property {number} INT=2 INT value - * @property {number} STRING=3 STRING value - * @property {number} TENSOR=4 TENSOR value - * @property {number} GRAPH=5 GRAPH value - * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value - * @property {number} TYPE_PROTO=13 TYPE_PROTO value - * @property {number} FLOATS=6 FLOATS value - * @property {number} INTS=7 INTS value - * @property {number} STRINGS=8 STRINGS value - * @property {number} TENSORS=9 TENSORS value - * @property {number} GRAPHS=10 GRAPHS value - * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value - * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value - */ - AttributeProto.AttributeType = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "UNDEFINED"] = 0; - values[valuesById[1] = "FLOAT"] = 1; - values[valuesById[2] = "INT"] = 2; - values[valuesById[3] = "STRING"] = 3; - values[valuesById[4] = "TENSOR"] = 4; - values[valuesById[5] = "GRAPH"] = 5; - values[valuesById[11] = "SPARSE_TENSOR"] = 11; - values[valuesById[13] = "TYPE_PROTO"] = 13; - values[valuesById[6] = "FLOATS"] = 6; - values[valuesById[7] = "INTS"] = 7; - values[valuesById[8] = "STRINGS"] = 8; - values[valuesById[9] = "TENSORS"] = 9; - values[valuesById[10] = "GRAPHS"] = 10; - values[valuesById[12] = "SPARSE_TENSORS"] = 12; - values[valuesById[14] = "TYPE_PROTOS"] = 14; - return values; - })(); - - return AttributeProto; + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. 
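A minimal usage sketch, not part of the diff, of the AttributeType enum shown above. It assumes `onnx` stands for the generated `$root.onnx` namespace exported by this module: the enum object maps both name to number and number to name, and `toObject()` chooses between the two forms via the `enums` conversion option.

// Illustrative only; `onnx` here is the generated $root.onnx namespace (an assumption).
const { AttributeType } = onnx.AttributeProto;
AttributeType.FLOAT;                                          // 1
AttributeType[1];                                             // 'FLOAT'
const attr = onnx.AttributeProto.fromObject({ name: 'alpha', type: 'FLOAT', f: 0.5 });
onnx.AttributeProto.toObject(attr, { enums: String }).type;   // 'FLOAT'
onnx.AttributeProto.toObject(attr).type;                      // 1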
+ * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. + * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.begin != null && message.hasOwnProperty('begin')) + if ( + !$util.isInteger(message.begin) && + !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high)) + ) + return 'begin: integer|Long expected'; + if (message.end != null && message.hasOwnProperty('end')) + if ( + !$util.isInteger(message.end) && + !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high)) + ) + return 'end: integer|Long expected'; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === 'string') message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === 'number') message.begin = object.begin; + else if (typeof object.begin === 'object') + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === 'string') message.end = parseInt(object.end, 10); + else if (typeof object.end === 'number') message.end = object.end; + else if (typeof object.end === 'object') + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.begin = options.longs === String ? '0' : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? 
long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.end = options.longs === String ? '0' : 0; + } + if (message.begin != null && message.hasOwnProperty('begin')) + if (typeof message.begin === 'number') + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = + options.longs === String + ? $util.Long.prototype.toString.call(message.begin) + : options.longs === Number + ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() + : message.begin; + if (message.end != null && message.hasOwnProperty('end')) + if (typeof message.end === 'number') + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = + options.longs === String + ? $util.Long.prototype.toString.call(message.end) + : options.longs === Number + ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() + : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. + * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorProto.Segment'; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = 'DEFAULT')] = 0; + values[(valuesById[1] = 'EXTERNAL')] = 1; + return values; })(); - onnx.ValueInfoProto = (function() { - - /** - * Properties of a ValueInfoProto. - * @memberof onnx - * @interface IValueInfoProto - * @property {string|null} [name] ValueInfoProto name - * @property {onnx.ITypeProto|null} [type] ValueInfoProto type - * @property {string|null} [docString] ValueInfoProto docString - */ - - /** - * Constructs a new ValueInfoProto. - * @memberof onnx - * @classdesc Represents a ValueInfoProto. - * @implements IValueInfoProto - * @constructor - * @param {onnx.IValueInfoProto=} [properties] Properties to set - */ - function ValueInfoProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + return TensorProto; + })(); + + onnx.SparseTensorProto = (function () { + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. 
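The Segment conversion code above handles int64 fields the same way throughout this file; a small sketch (illustrative, assuming `onnx` is the generated root namespace and that long.js may or may not be loaded) of how the `longs` option controls what `toObject()` returns.

// begin/end accept numbers, decimal strings, or { low, high } pairs on the way in.
const seg = onnx.TensorProto.Segment.fromObject({ begin: '10', end: 20 });
onnx.TensorProto.Segment.toObject(seg, { longs: String }); // { begin: '10', end: '20' }
onnx.TensorProto.Segment.toObject(seg, { longs: Number }); // { begin: 10, end: 20 }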
+ * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. + * @member {Array.} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, 'values')) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, 'indices')) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/ 26).fork(); + for (var i = 0; i < message.dims.length; ++i) writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. 
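Before the decoder below, a quick round-trip sketch for SparseTensorProto (illustrative, assuming `onnx` is the generated root namespace): `encode()` returns a protobuf Writer, so `.finish()` yields the serialized bytes, and `decode()` accepts either a Reader or a Uint8Array.

// dims is written packed on field 3 by the encoder above.
const sparse = onnx.SparseTensorProto.fromObject({ dims: [3, 4] });
const bytes = onnx.SparseTensorProto.encode(sparse).finish(); // Uint8Array
const roundTripped = onnx.SparseTensorProto.decode(bytes);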
+ * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.dims.push(reader.int64()); + } else message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * ValueInfoProto name. - * @member {string} name - * @memberof onnx.ValueInfoProto - * @instance - */ - ValueInfoProto.prototype.name = ""; - - /** - * ValueInfoProto type. - * @member {onnx.ITypeProto|null|undefined} type - * @memberof onnx.ValueInfoProto - * @instance - */ - ValueInfoProto.prototype.type = null; - - /** - * ValueInfoProto docString. - * @member {string} docString - * @memberof onnx.ValueInfoProto - * @instance - */ - ValueInfoProto.prototype.docString = ""; - - /** - * Creates a new ValueInfoProto instance using the specified properties. - * @function create - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.IValueInfoProto=} [properties] Properties to set - * @returns {onnx.ValueInfoProto} ValueInfoProto instance - */ - ValueInfoProto.create = function create(properties) { - return new ValueInfoProto(properties); - }; - - /** - * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. - * @function encode - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ValueInfoProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); - if (message.type != null && Object.hasOwnProperty.call(message, "type")) - $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); - return writer; - }; - - /** - * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a ValueInfoProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.ValueInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.ValueInfoProto} ValueInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ValueInfoProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.name = reader.string(); - break; - } - case 2: { - message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - case 3: { - message.docString = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.ValueInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.ValueInfoProto} ValueInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a ValueInfoProto message. - * @function verify - * @memberof onnx.ValueInfoProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - ValueInfoProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.type != null && message.hasOwnProperty("type")) { - var error = $root.onnx.TypeProto.verify(message.type); - if (error) - return "type." + error; - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - return null; - }; - - /** - * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.ValueInfoProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.ValueInfoProto} ValueInfoProto - */ - ValueInfoProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.ValueInfoProto) - return object; - var message = new $root.onnx.ValueInfoProto(); - if (object.name != null) - message.name = String(object.name); - if (object.type != null) { - if (typeof object.type !== "object") - throw TypeError(".onnx.ValueInfoProto.type: object expected"); - message.type = $root.onnx.TypeProto.fromObject(object.type); - } - if (object.docString != null) - message.docString = String(object.docString); - return message; - }; - - /** - * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.ValueInfoProto} message ValueInfoProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - ValueInfoProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.name = ""; - object.type = null; - object.docString = ""; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.type != null && message.hasOwnProperty("type")) - object.type = $root.onnx.TypeProto.toObject(message.type, options); - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - return object; - }; - - /** - * Converts this ValueInfoProto to JSON. - * @function toJSON - * @memberof onnx.ValueInfoProto - * @instance - * @returns {Object.} JSON object - */ - ValueInfoProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for ValueInfoProto - * @function getTypeUrl - * @memberof onnx.ValueInfoProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.ValueInfoProto"; - }; + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return ValueInfoProto; - })(); + /** + * Verifies a SparseTensorProto message. 
+ * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.values != null && message.hasOwnProperty('values')) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) return 'values.' + error; + } + if (message.indices != null && message.hasOwnProperty('indices')) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) return 'indices.' + error; + } + if (message.dims != null && message.hasOwnProperty('dims')) { + if (!Array.isArray(message.dims)) return 'dims: array expected'; + for (var i = 0; i < message.dims.length; ++i) + if ( + !$util.isInteger(message.dims[i]) && + !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high)) + ) + return 'dims: integer|Long[] expected'; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== 'object') throw TypeError('.onnx.SparseTensorProto.values: object expected'); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== 'object') throw TypeError('.onnx.SparseTensorProto.indices: object expected'); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) throw TypeError('.onnx.SparseTensorProto.dims: array expected'); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === 'string') message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === 'number') message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === 'object') + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. 
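The verify/fromObject pair above follows the usual generated pattern; a short sketch of how callers typically use it, assuming `onnx` is the generated root namespace.

// verify() returns null for a valid plain object, or a reason string otherwise; failures
// inside nested messages come back prefixed with the field path ('values.' + error).
onnx.SparseTensorProto.verify({ dims: [1, 2] }); // null
onnx.SparseTensorProto.verify({ dims: 'oops' }); // 'dims: array expected'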
+ * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty('values')) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty('indices')) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === 'number') + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.dims[j]) + : options.longs === Number + ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() + : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. + * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.SparseTensorProto'; + }; - onnx.NodeProto = (function() { - - /** - * Properties of a NodeProto. - * @memberof onnx - * @interface INodeProto - * @property {Array.|null} [input] NodeProto input - * @property {Array.|null} [output] NodeProto output - * @property {string|null} [name] NodeProto name - * @property {string|null} [opType] NodeProto opType - * @property {string|null} [domain] NodeProto domain - * @property {Array.|null} [attribute] NodeProto attribute - * @property {string|null} [docString] NodeProto docString - */ - - /** - * Constructs a new NodeProto. - * @memberof onnx - * @classdesc Represents a NodeProto. - * @implements INodeProto - * @constructor - * @param {onnx.INodeProto=} [properties] Properties to set - */ - function NodeProto(properties) { - this.input = []; - this.output = []; - this.attribute = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function () { + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. 
+ * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode( + message.dim[i], + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * NodeProto input. - * @member {Array.} input - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.input = $util.emptyArray; - - /** - * NodeProto output. 
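A small sketch of serializing a shape with the encoder and decoder shown above (illustrative, assuming `onnx` is the generated root namespace; per the JSDoc, `encode()` also accepts a plain object with `dimValue`/`dimParam` entries).

const bytes = onnx.TensorShapeProto.encode({ dim: [{ dimValue: 2 }, { dimParam: 'N' }] }).finish();
const shape = onnx.TensorShapeProto.decode(bytes); // dim: array of Dimension messages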
- * @member {Array.} output - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.output = $util.emptyArray; - - /** - * NodeProto name. - * @member {string} name - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.name = ""; - - /** - * NodeProto opType. - * @member {string} opType - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.opType = ""; - - /** - * NodeProto domain. - * @member {string} domain - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.domain = ""; - - /** - * NodeProto attribute. - * @member {Array.} attribute - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.attribute = $util.emptyArray; - - /** - * NodeProto docString. - * @member {string} docString - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.docString = ""; - - /** - * Creates a new NodeProto instance using the specified properties. - * @function create - * @memberof onnx.NodeProto - * @static - * @param {onnx.INodeProto=} [properties] Properties to set - * @returns {onnx.NodeProto} NodeProto instance - */ - NodeProto.create = function create(properties) { - return new NodeProto(properties); - }; - - /** - * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. - * @function encode - * @memberof onnx.NodeProto - * @static - * @param {onnx.INodeProto} message NodeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - NodeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.input != null && message.input.length) - for (var i = 0; i < message.input.length; ++i) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); - if (message.output != null && message.output.length) - for (var i = 0; i < message.output.length; ++i) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); - if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) - writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); - if (message.attribute != null && message.attribute.length) - for (var i = 0; i < message.attribute.length; ++i) - $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); - return writer; - }; - - /** - * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.NodeProto - * @static - * @param {onnx.INodeProto} message NodeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - NodeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a NodeProto message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.NodeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.NodeProto} NodeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - NodeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.input && message.input.length)) - message.input = []; - message.input.push(reader.string()); - break; - } - case 2: { - if (!(message.output && message.output.length)) - message.output = []; - message.output.push(reader.string()); - break; - } - case 3: { - message.name = reader.string(); - break; - } - case 4: { - message.opType = reader.string(); - break; - } - case 7: { - message.domain = reader.string(); - break; - } - case 5: { - if (!(message.attribute && message.attribute.length)) - message.attribute = []; - message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); - break; - } - case 6: { - message.docString = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a NodeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.NodeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.NodeProto} NodeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - NodeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a NodeProto message. 
- * @function verify - * @memberof onnx.NodeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - NodeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.input != null && message.hasOwnProperty("input")) { - if (!Array.isArray(message.input)) - return "input: array expected"; - for (var i = 0; i < message.input.length; ++i) - if (!$util.isString(message.input[i])) - return "input: string[] expected"; - } - if (message.output != null && message.hasOwnProperty("output")) { - if (!Array.isArray(message.output)) - return "output: array expected"; - for (var i = 0; i < message.output.length; ++i) - if (!$util.isString(message.output[i])) - return "output: string[] expected"; - } - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.opType != null && message.hasOwnProperty("opType")) - if (!$util.isString(message.opType)) - return "opType: string expected"; - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - if (message.attribute != null && message.hasOwnProperty("attribute")) { - if (!Array.isArray(message.attribute)) - return "attribute: array expected"; - for (var i = 0; i < message.attribute.length; ++i) { - var error = $root.onnx.AttributeProto.verify(message.attribute[i]); - if (error) - return "attribute." + error; - } - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - return null; - }; - - /** - * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.NodeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.NodeProto} NodeProto - */ - NodeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.NodeProto) - return object; - var message = new $root.onnx.NodeProto(); - if (object.input) { - if (!Array.isArray(object.input)) - throw TypeError(".onnx.NodeProto.input: array expected"); - message.input = []; - for (var i = 0; i < object.input.length; ++i) - message.input[i] = String(object.input[i]); - } - if (object.output) { - if (!Array.isArray(object.output)) - throw TypeError(".onnx.NodeProto.output: array expected"); - message.output = []; - for (var i = 0; i < object.output.length; ++i) - message.output[i] = String(object.output[i]); - } - if (object.name != null) - message.name = String(object.name); - if (object.opType != null) - message.opType = String(object.opType); - if (object.domain != null) - message.domain = String(object.domain); - if (object.attribute) { - if (!Array.isArray(object.attribute)) - throw TypeError(".onnx.NodeProto.attribute: array expected"); - message.attribute = []; - for (var i = 0; i < object.attribute.length; ++i) { - if (typeof object.attribute[i] !== "object") - throw TypeError(".onnx.NodeProto.attribute: object expected"); - message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); - } - } - if (object.docString != null) - message.docString = String(object.docString); - return message; - }; - - /** - * Creates a plain object from a NodeProto message. Also converts values to other types if specified. 
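NodeProto.fromObject above converts nested attributes recursively; a sketch of the typical call, assuming `onnx` is the generated root namespace (operator and tensor names below are hypothetical, for illustration only).

// input/output are plain string arrays; each attribute goes through AttributeProto.fromObject.
const node = onnx.NodeProto.fromObject({
  opType: 'Gemm',
  input: ['A', 'B'],
  output: ['Y'],
  attribute: [{ name: 'alpha', type: 'FLOAT', f: 1.0 }],
});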
- * @function toObject - * @memberof onnx.NodeProto - * @static - * @param {onnx.NodeProto} message NodeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - NodeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.input = []; - object.output = []; - object.attribute = []; - } - if (options.defaults) { - object.name = ""; - object.opType = ""; - object.docString = ""; - object.domain = ""; - } - if (message.input && message.input.length) { - object.input = []; - for (var j = 0; j < message.input.length; ++j) - object.input[j] = message.input[j]; - } - if (message.output && message.output.length) { - object.output = []; - for (var j = 0; j < message.output.length; ++j) - object.output[j] = message.output[j]; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.opType != null && message.hasOwnProperty("opType")) - object.opType = message.opType; - if (message.attribute && message.attribute.length) { - object.attribute = []; - for (var j = 0; j < message.attribute.length; ++j) - object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - return object; - }; - - /** - * Converts this NodeProto to JSON. - * @function toJSON - * @memberof onnx.NodeProto - * @instance - * @returns {Object.} JSON object - */ - NodeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for NodeProto - * @function getTypeUrl - * @memberof onnx.NodeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.NodeProto"; - }; + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return NodeProto; - })(); + /** + * Verifies a TensorShapeProto message. 
+ * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.dim != null && message.hasOwnProperty('dim')) { + if (!Array.isArray(message.dim)) return 'dim: array expected'; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) return 'dim.' + error; + } + } + return null; + }; - onnx.TrainingInfoProto = (function() { - - /** - * Properties of a TrainingInfoProto. - * @memberof onnx - * @interface ITrainingInfoProto - * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization - * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm - * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding - * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding - */ - - /** - * Constructs a new TrainingInfoProto. - * @memberof onnx - * @classdesc Represents a TrainingInfoProto. - * @implements ITrainingInfoProto - * @constructor - * @param {onnx.ITrainingInfoProto=} [properties] Properties to set - */ - function TrainingInfoProto(properties) { - this.initializationBinding = []; - this.updateBinding = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) throw TypeError('.onnx.TensorShapeProto.dim: array expected'); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== 'object') throw TypeError('.onnx.TensorShapeProto.dim: object expected'); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); } + } + return message; + }; - /** - * TrainingInfoProto initialization. - * @member {onnx.IGraphProto|null|undefined} initialization - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.initialization = null; - - /** - * TrainingInfoProto algorithm. - * @member {onnx.IGraphProto|null|undefined} algorithm - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.algorithm = null; - - /** - * TrainingInfoProto initializationBinding. - * @member {Array.} initializationBinding - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; - - /** - * TrainingInfoProto updateBinding. - * @member {Array.} updateBinding - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.updateBinding = $util.emptyArray; - - /** - * Creates a new TrainingInfoProto instance using the specified properties. 
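A sketch of the TensorShapeProto verify/fromObject pair above (illustrative, assuming `onnx` is the generated root namespace; the dimension values are hypothetical).

onnx.TensorShapeProto.verify({ dim: 'not-an-array' }); // 'dim: array expected'
const shape = onnx.TensorShapeProto.fromObject({ dim: [{ dimParam: 'batch' }, { dimValue: 224 }] });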
- * @function create - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.ITrainingInfoProto=} [properties] Properties to set - * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance - */ - TrainingInfoProto.create = function create(properties) { - return new TrainingInfoProto(properties); - }; - - /** - * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. - * @function encode - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TrainingInfoProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) - $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) - $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - if (message.initializationBinding != null && message.initializationBinding.length) - for (var i = 0; i < message.initializationBinding.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); - if (message.updateBinding != null && message.updateBinding.length) - for (var i = 0; i < message.updateBinding.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TrainingInfoProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TrainingInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TrainingInfoProto} TrainingInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TrainingInfoProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 2: { - message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 3: { - if (!(message.initializationBinding && message.initializationBinding.length)) - message.initializationBinding = []; - message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - case 4: { - if (!(message.updateBinding && message.updateBinding.length)) - message.updateBinding = []; - message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TrainingInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TrainingInfoProto} TrainingInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TrainingInfoProto message. - * @function verify - * @memberof onnx.TrainingInfoProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TrainingInfoProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.initialization != null && message.hasOwnProperty("initialization")) { - var error = $root.onnx.GraphProto.verify(message.initialization); - if (error) - return "initialization." + error; - } - if (message.algorithm != null && message.hasOwnProperty("algorithm")) { - var error = $root.onnx.GraphProto.verify(message.algorithm); - if (error) - return "algorithm." + error; - } - if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { - if (!Array.isArray(message.initializationBinding)) - return "initializationBinding: array expected"; - for (var i = 0; i < message.initializationBinding.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); - if (error) - return "initializationBinding." + error; - } - } - if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { - if (!Array.isArray(message.updateBinding)) - return "updateBinding: array expected"; - for (var i = 0; i < message.updateBinding.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); - if (error) - return "updateBinding." + error; - } - } - return null; - }; - - /** - * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.TrainingInfoProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TrainingInfoProto} TrainingInfoProto - */ - TrainingInfoProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TrainingInfoProto) - return object; - var message = new $root.onnx.TrainingInfoProto(); - if (object.initialization != null) { - if (typeof object.initialization !== "object") - throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); - message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); - } - if (object.algorithm != null) { - if (typeof object.algorithm !== "object") - throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); - message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); - } - if (object.initializationBinding) { - if (!Array.isArray(object.initializationBinding)) - throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); - message.initializationBinding = []; - for (var i = 0; i < object.initializationBinding.length; ++i) { - if (typeof object.initializationBinding[i] !== "object") - throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); - message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); - } - } - if (object.updateBinding) { - if (!Array.isArray(object.updateBinding)) - throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); - message.updateBinding = []; - for (var i = 0; i < object.updateBinding.length; ++i) { - if (typeof object.updateBinding[i] !== "object") - throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); - message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.TrainingInfoProto} message TrainingInfoProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TrainingInfoProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.initializationBinding = []; - object.updateBinding = []; - } - if (options.defaults) { - object.initialization = null; - object.algorithm = null; - } - if (message.initialization != null && message.hasOwnProperty("initialization")) - object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); - if (message.algorithm != null && message.hasOwnProperty("algorithm")) - object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); - if (message.initializationBinding && message.initializationBinding.length) { - object.initializationBinding = []; - for (var j = 0; j < message.initializationBinding.length; ++j) - object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); - } - if (message.updateBinding && message.updateBinding.length) { - object.updateBinding = []; - for (var j = 0; j < message.updateBinding.length; ++j) - object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); - } - return object; - }; - - /** - * Converts this TrainingInfoProto to JSON. 
- * @function toJSON - * @memberof onnx.TrainingInfoProto - * @instance - * @returns {Object.} JSON object - */ - TrainingInfoProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TrainingInfoProto - * @function getTypeUrl - * @memberof onnx.TrainingInfoProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorShapeProto'; + }; + + TensorShapeProto.Dimension = (function () { + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. + * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. 
+ * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ''; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, 'value', { + get: $util.oneOfGetter(($oneOfFields = ['dimValue', 'dimParam'])), + set: $util.oneOfSetter($oneOfFields), + }); + + /** + * Creates a new Dimension instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, 'dimValue')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, 'dimParam')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, 'denotation')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; } - return typeUrlPrefix + "/onnx.TrainingInfoProto"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. + * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty('dimValue')) { + properties.value = 1; + if ( + !$util.isInteger(message.dimValue) && + !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high)) + ) + return 'dimValue: integer|Long expected'; + } + if (message.dimParam != null && message.hasOwnProperty('dimParam')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + if (!$util.isString(message.dimParam)) return 'dimParam: string expected'; + } + if (message.denotation != null && message.hasOwnProperty('denotation')) + if (!$util.isString(message.denotation)) return 'denotation: string expected'; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === 'string') message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === 'number') message.dimValue = object.dimValue; + else if (typeof object.dimValue === 'object') + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) message.dimParam = String(object.dimParam); + if (object.denotation != null) message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.denotation = ''; + if (message.dimValue != null && message.hasOwnProperty('dimValue')) { + if (typeof message.dimValue === 'number') + object.dimValue = options.longs === String ? String(message.dimValue) : message.dimValue; + else + object.dimValue = + options.longs === String + ? $util.Long.prototype.toString.call(message.dimValue) + : options.longs === Number + ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() + : message.dimValue; + if (options.oneofs) object.value = 'dimValue'; + } + if (message.dimParam != null && message.hasOwnProperty('dimParam')) { + object.dimParam = message.dimParam; + if (options.oneofs) object.value = 'dimParam'; + } + if (message.denotation != null && message.hasOwnProperty('denotation')) object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorShapeProto.Dimension'; + }; - return TrainingInfoProto; + return Dimension; })(); - onnx.ModelProto = (function() { - - /** - * Properties of a ModelProto. - * @memberof onnx - * @interface IModelProto - * @property {number|Long|null} [irVersion] ModelProto irVersion - * @property {Array.|null} [opsetImport] ModelProto opsetImport - * @property {string|null} [producerName] ModelProto producerName - * @property {string|null} [producerVersion] ModelProto producerVersion - * @property {string|null} [domain] ModelProto domain - * @property {number|Long|null} [modelVersion] ModelProto modelVersion - * @property {string|null} [docString] ModelProto docString - * @property {onnx.IGraphProto|null} [graph] ModelProto graph - * @property {Array.|null} [metadataProps] ModelProto metadataProps - * @property {Array.|null} [trainingInfo] ModelProto trainingInfo - * @property {Array.|null} [functions] ModelProto functions - */ - - /** - * Constructs a new ModelProto. - * @memberof onnx - * @classdesc Represents a ModelProto. - * @implements IModelProto - * @constructor - * @param {onnx.IModelProto=} [properties] Properties to set - */ - function ModelProto(properties) { - this.opsetImport = []; - this.metadataProps = []; - this.trainingInfo = []; - this.functions = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + return TensorShapeProto; + })(); + + onnx.TypeProto = (function () { + /** + * Properties of a TypeProto. 
+ * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. + * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. + * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ''; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. + * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, 'value', { + get: $util.oneOfGetter( + ($oneOfFields = ['tensorType', 'sequenceType', 'mapType', 'optionalType', 'sparseTensorType']), + ), + set: $util.oneOfSetter($oneOfFields), + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, 'tensorType')) + $root.onnx.TypeProto.Tensor.encode( + message.tensorType, + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, 'sequenceType')) + $root.onnx.TypeProto.Sequence.encode( + message.sequenceType, + writer.uint32(/* id 4, wireType 2 =*/ 34).fork(), + ).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, 'mapType')) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/ 42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, 'denotation')) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.denotation); + if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, 'sparseTensorType')) + $root.onnx.TypeProto.SparseTensor.encode( + message.sparseTensorType, + writer.uint32(/* id 8, wireType 2 =*/ 66).fork(), + ).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, 'optionalType')) + $root.onnx.TypeProto.Optional.encode( + message.optionalType, + writer.uint32(/* id 9, wireType 2 =*/ 74).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * ModelProto irVersion. - * @member {number|Long} irVersion - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * ModelProto opsetImport. - * @member {Array.} opsetImport - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.opsetImport = $util.emptyArray; - - /** - * ModelProto producerName. - * @member {string} producerName - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.producerName = ""; - - /** - * ModelProto producerVersion. - * @member {string} producerVersion - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.producerVersion = ""; - - /** - * ModelProto domain. - * @member {string} domain - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.domain = ""; - - /** - * ModelProto modelVersion. - * @member {number|Long} modelVersion - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * ModelProto docString. - * @member {string} docString - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.docString = ""; - - /** - * ModelProto graph. - * @member {onnx.IGraphProto|null|undefined} graph - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.graph = null; - - /** - * ModelProto metadataProps. - * @member {Array.} metadataProps - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.metadataProps = $util.emptyArray; - - /** - * ModelProto trainingInfo. - * @member {Array.} trainingInfo - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.trainingInfo = $util.emptyArray; - - /** - * ModelProto functions. - * @member {Array.} functions - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.functions = $util.emptyArray; - - /** - * Creates a new ModelProto instance using the specified properties. - * @function create - * @memberof onnx.ModelProto - * @static - * @param {onnx.IModelProto=} [properties] Properties to set - * @returns {onnx.ModelProto} ModelProto instance - */ - ModelProto.create = function create(properties) { - return new ModelProto(properties); - }; - - /** - * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.ModelProto - * @static - * @param {onnx.IModelProto} message ModelProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ModelProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) - writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); - if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); - if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); - if (message.modelVersion != null && Object.hasOwnProperty.call(message, "modelVersion")) - writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); - if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) - $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); - if (message.opsetImport != null && message.opsetImport.length) - for (var i = 0; i < message.opsetImport.length; ++i) - $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); - if (message.metadataProps != null && message.metadataProps.length) - for (var i = 0; i < message.metadataProps.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); - if (message.trainingInfo != null && message.trainingInfo.length) - for (var i = 0; i < message.trainingInfo.length; ++i) - $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); - if (message.functions != null && message.functions.length) - for (var i = 0; i < message.functions.length; ++i) - $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* id 25, wireType 2 =*/202).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.ModelProto - * @static - * @param {onnx.IModelProto} message ModelProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ModelProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a ModelProto message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.ModelProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.ModelProto} ModelProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ModelProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.irVersion = reader.int64(); - break; - } - case 8: { - if (!(message.opsetImport && message.opsetImport.length)) - message.opsetImport = []; - message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); - break; - } - case 2: { - message.producerName = reader.string(); - break; - } - case 3: { - message.producerVersion = reader.string(); - break; - } - case 4: { - message.domain = reader.string(); - break; - } - case 5: { - message.modelVersion = reader.int64(); - break; - } - case 6: { - message.docString = reader.string(); - break; - } - case 7: { - message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 14: { - if (!(message.metadataProps && message.metadataProps.length)) - message.metadataProps = []; - message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - case 20: { - if (!(message.trainingInfo && message.trainingInfo.length)) - message.trainingInfo = []; - message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); - break; - } - case 25: { - if (!(message.functions && message.functions.length)) - message.functions = []; - message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a ModelProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.ModelProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.ModelProto} ModelProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ModelProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a ModelProto message. 
- * @function verify - * @memberof onnx.ModelProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - ModelProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.irVersion != null && message.hasOwnProperty("irVersion")) - if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) - return "irVersion: integer|Long expected"; - if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { - if (!Array.isArray(message.opsetImport)) - return "opsetImport: array expected"; - for (var i = 0; i < message.opsetImport.length; ++i) { - var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); - if (error) - return "opsetImport." + error; - } - } - if (message.producerName != null && message.hasOwnProperty("producerName")) - if (!$util.isString(message.producerName)) - return "producerName: string expected"; - if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) - if (!$util.isString(message.producerVersion)) - return "producerVersion: string expected"; - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) - if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) - return "modelVersion: integer|Long expected"; - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.graph != null && message.hasOwnProperty("graph")) { - var error = $root.onnx.GraphProto.verify(message.graph); - if (error) - return "graph." + error; - } - if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { - if (!Array.isArray(message.metadataProps)) - return "metadataProps: array expected"; - for (var i = 0; i < message.metadataProps.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); - if (error) - return "metadataProps." + error; - } - } - if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { - if (!Array.isArray(message.trainingInfo)) - return "trainingInfo: array expected"; - for (var i = 0; i < message.trainingInfo.length; ++i) { - var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); - if (error) - return "trainingInfo." + error; - } - } - if (message.functions != null && message.hasOwnProperty("functions")) { - if (!Array.isArray(message.functions)) - return "functions: array expected"; - for (var i = 0; i < message.functions.length; ++i) { - var error = $root.onnx.FunctionProto.verify(message.functions[i]); - if (error) - return "functions." + error; - } - } - return null; - }; - - /** - * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.ModelProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.ModelProto} ModelProto - */ - ModelProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.ModelProto) - return object; - var message = new $root.onnx.ModelProto(); - if (object.irVersion != null) - if ($util.Long) - (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; - else if (typeof object.irVersion === "string") - message.irVersion = parseInt(object.irVersion, 10); - else if (typeof object.irVersion === "number") - message.irVersion = object.irVersion; - else if (typeof object.irVersion === "object") - message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); - if (object.opsetImport) { - if (!Array.isArray(object.opsetImport)) - throw TypeError(".onnx.ModelProto.opsetImport: array expected"); - message.opsetImport = []; - for (var i = 0; i < object.opsetImport.length; ++i) { - if (typeof object.opsetImport[i] !== "object") - throw TypeError(".onnx.ModelProto.opsetImport: object expected"); - message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); - } - } - if (object.producerName != null) - message.producerName = String(object.producerName); - if (object.producerVersion != null) - message.producerVersion = String(object.producerVersion); - if (object.domain != null) - message.domain = String(object.domain); - if (object.modelVersion != null) - if ($util.Long) - (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; - else if (typeof object.modelVersion === "string") - message.modelVersion = parseInt(object.modelVersion, 10); - else if (typeof object.modelVersion === "number") - message.modelVersion = object.modelVersion; - else if (typeof object.modelVersion === "object") - message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); - if (object.docString != null) - message.docString = String(object.docString); - if (object.graph != null) { - if (typeof object.graph !== "object") - throw TypeError(".onnx.ModelProto.graph: object expected"); - message.graph = $root.onnx.GraphProto.fromObject(object.graph); - } - if (object.metadataProps) { - if (!Array.isArray(object.metadataProps)) - throw TypeError(".onnx.ModelProto.metadataProps: array expected"); - message.metadataProps = []; - for (var i = 0; i < object.metadataProps.length; ++i) { - if (typeof object.metadataProps[i] !== "object") - throw TypeError(".onnx.ModelProto.metadataProps: object expected"); - message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); - } - } - if (object.trainingInfo) { - if (!Array.isArray(object.trainingInfo)) - throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); - message.trainingInfo = []; - for (var i = 0; i < object.trainingInfo.length; ++i) { - if (typeof object.trainingInfo[i] !== "object") - throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); - message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); - } - } - if (object.functions) { - if (!Array.isArray(object.functions)) - throw TypeError(".onnx.ModelProto.functions: array expected"); - message.functions = []; - for (var i = 0; i < object.functions.length; ++i) { - if (typeof object.functions[i] !== "object") - throw TypeError(".onnx.ModelProto.functions: object expected"); - message.functions[i] 
= $root.onnx.FunctionProto.fromObject(object.functions[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a ModelProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.ModelProto - * @static - * @param {onnx.ModelProto} message ModelProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - ModelProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.opsetImport = []; - object.metadataProps = []; - object.trainingInfo = []; - object.functions = []; - } - if (options.defaults) { - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.irVersion = options.longs === String ? "0" : 0; - object.producerName = ""; - object.producerVersion = ""; - object.domain = ""; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.modelVersion = options.longs === String ? "0" : 0; - object.docString = ""; - object.graph = null; - } - if (message.irVersion != null && message.hasOwnProperty("irVersion")) - if (typeof message.irVersion === "number") - object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; - else - object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; - if (message.producerName != null && message.hasOwnProperty("producerName")) - object.producerName = message.producerName; - if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) - object.producerVersion = message.producerVersion; - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) - if (typeof message.modelVersion === "number") - object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; - else - object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? 
new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.graph != null && message.hasOwnProperty("graph")) - object.graph = $root.onnx.GraphProto.toObject(message.graph, options); - if (message.opsetImport && message.opsetImport.length) { - object.opsetImport = []; - for (var j = 0; j < message.opsetImport.length; ++j) - object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); - } - if (message.metadataProps && message.metadataProps.length) { - object.metadataProps = []; - for (var j = 0; j < message.metadataProps.length; ++j) - object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); - } - if (message.trainingInfo && message.trainingInfo.length) { - object.trainingInfo = []; - for (var j = 0; j < message.trainingInfo.length; ++j) - object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); - } - if (message.functions && message.functions.length) { - object.functions = []; - for (var j = 0; j < message.functions.length; ++j) - object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. + * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty('tensorType')) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) return 'tensorType.' + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty('sequenceType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) return 'sequenceType.' + error; + } + } + if (message.mapType != null && message.hasOwnProperty('mapType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) return 'mapType.' + error; + } + } + if (message.optionalType != null && message.hasOwnProperty('optionalType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) return 'optionalType.' 
+ error; + } + } + if (message.sparseTensorType != null && message.hasOwnProperty('sparseTensorType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); + if (error) return 'sparseTensorType.' + error; + } + } + if (message.denotation != null && message.hasOwnProperty('denotation')) + if (!$util.isString(message.denotation)) return 'denotation: string expected'; + return null; + }; + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto} TypeProto + */ + TypeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto) return object; + var message = new $root.onnx.TypeProto(); + if (object.tensorType != null) { + if (typeof object.tensorType !== 'object') throw TypeError('.onnx.TypeProto.tensorType: object expected'); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + } + if (object.sequenceType != null) { + if (typeof object.sequenceType !== 'object') throw TypeError('.onnx.TypeProto.sequenceType: object expected'); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + } + if (object.mapType != null) { + if (typeof object.mapType !== 'object') throw TypeError('.onnx.TypeProto.mapType: object expected'); + message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); + } + if (object.optionalType != null) { + if (typeof object.optionalType !== 'object') throw TypeError('.onnx.TypeProto.optionalType: object expected'); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + } + if (object.sparseTensorType != null) { + if (typeof object.sparseTensorType !== 'object') + throw TypeError('.onnx.TypeProto.sparseTensorType: object expected'); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); + } + if (object.denotation != null) message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.denotation = ''; + if (message.tensorType != null && message.hasOwnProperty('tensorType')) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) object.value = 'tensorType'; + } + if (message.sequenceType != null && message.hasOwnProperty('sequenceType')) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) object.value = 'sequenceType'; + } + if (message.mapType != null && message.hasOwnProperty('mapType')) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) object.value = 'mapType'; + } + if (message.denotation != null && message.hasOwnProperty('denotation')) object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty('sparseTensorType')) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) object.value = 'sparseTensorType'; + } + if (message.optionalType != null && message.hasOwnProperty('optionalType')) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) object.value = 'optionalType'; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. + * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto'; + }; + + TypeProto.Tensor = (function () { + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. 
+ * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, 'shape')) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; } - return object; - }; - - /** - * Converts this ModelProto to JSON. - * @function toJSON - * @memberof onnx.ModelProto - * @instance - * @returns {Object.} JSON object - */ - ModelProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for ModelProto - * @function getTypeUrl - * @memberof onnx.ModelProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. + * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) + if (!$util.isInteger(message.elemType)) return 'elemType: integer expected'; + if (message.shape != null && message.hasOwnProperty('shape')) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) return 'shape.' + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== 'object') throw TypeError('.onnx.TypeProto.Tensor.shape: object expected'); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty('elemType')) object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty('shape')) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Tensor'; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function () { + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. + * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; } - return typeUrlPrefix + "/onnx.ModelProto"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) return 'elemType.' + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== 'object') + throw TypeError('.onnx.TypeProto.Sequence.elemType: object expected'); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.elemType = null; + if (message.elemType != null && message.hasOwnProperty('elemType')) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Sequence'; + }; - return ModelProto; + return Sequence; })(); - onnx.StringStringEntryProto = (function() { - - /** - * Properties of a StringStringEntryProto. - * @memberof onnx - * @interface IStringStringEntryProto - * @property {string|null} [key] StringStringEntryProto key - * @property {string|null} [value] StringStringEntryProto value - */ - - /** - * Constructs a new StringStringEntryProto. - * @memberof onnx - * @classdesc Represents a StringStringEntryProto. - * @implements IStringStringEntryProto - * @constructor - * @param {onnx.IStringStringEntryProto=} [properties] Properties to set - */ - function StringStringEntryProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } - - /** - * StringStringEntryProto key. - * @member {string} key - * @memberof onnx.StringStringEntryProto - * @instance - */ - StringStringEntryProto.prototype.key = ""; - - /** - * StringStringEntryProto value. - * @member {string} value - * @memberof onnx.StringStringEntryProto - * @instance - */ - StringStringEntryProto.prototype.value = ""; - - /** - * Creates a new StringStringEntryProto instance using the specified properties. - * @function create - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.IStringStringEntryProto=} [properties] Properties to set - * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance - */ - StringStringEntryProto.create = function create(properties) { - return new StringStringEntryProto(properties); - }; - - /** - * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - StringStringEntryProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.key != null && Object.hasOwnProperty.call(message, "key")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); - if (message.value != null && Object.hasOwnProperty.call(message, "value")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); - return writer; - }; - - /** - * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a StringStringEntryProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.StringStringEntryProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.StringStringEntryProto} StringStringEntryProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - StringStringEntryProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.key = reader.string(); - break; - } - case 2: { - message.value = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } + TypeProto.Map = (function () { + /** + * Properties of a Map. + * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. + * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. 
+ * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, 'keyType')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, 'valueType')) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; } - return message; - }; - - /** - * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.StringStringEntryProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.StringStringEntryProto} StringStringEntryProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a StringStringEntryProto message. 
- * @function verify - * @memberof onnx.StringStringEntryProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - StringStringEntryProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.key != null && message.hasOwnProperty("key")) - if (!$util.isString(message.key)) - return "key: string expected"; - if (message.value != null && message.hasOwnProperty("value")) - if (!$util.isString(message.value)) - return "value: string expected"; - return null; - }; - - /** - * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.StringStringEntryProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.StringStringEntryProto} StringStringEntryProto - */ - StringStringEntryProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.StringStringEntryProto) - return object; - var message = new $root.onnx.StringStringEntryProto(); - if (object.key != null) - message.key = String(object.key); - if (object.value != null) - message.value = String(object.value); - return message; - }; - - /** - * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.StringStringEntryProto} message StringStringEntryProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - StringStringEntryProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.key = ""; - object.value = ""; - } - if (message.key != null && message.hasOwnProperty("key")) - object.key = message.key; - if (message.value != null && message.hasOwnProperty("value")) - object.value = message.value; - return object; - }; - - /** - * Converts this StringStringEntryProto to JSON. - * @function toJSON - * @memberof onnx.StringStringEntryProto - * @instance - * @returns {Object.} JSON object - */ - StringStringEntryProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for StringStringEntryProto - * @function getTypeUrl - * @memberof onnx.StringStringEntryProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.StringStringEntryProto"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.keyType != null && message.hasOwnProperty('keyType')) + if (!$util.isInteger(message.keyType)) return 'keyType: integer expected'; + if (message.valueType != null && message.hasOwnProperty('valueType')) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) return 'valueType.' + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== 'object') throw TypeError('.onnx.TypeProto.Map.valueType: object expected'); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty('keyType')) object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty('valueType')) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Map'; + }; - return StringStringEntryProto; + return Map; })(); - onnx.TensorAnnotation = (function() { - - /** - * Properties of a TensorAnnotation. - * @memberof onnx - * @interface ITensorAnnotation - * @property {string|null} [tensorName] TensorAnnotation tensorName - * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames - */ - - /** - * Constructs a new TensorAnnotation. - * @memberof onnx - * @classdesc Represents a TensorAnnotation. - * @implements ITensorAnnotation - * @constructor - * @param {onnx.ITensorAnnotation=} [properties] Properties to set - */ - function TensorAnnotation(properties) { - this.quantParameterTensorNames = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + TypeProto.Optional = (function () { + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) return 'elemType.' + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== 'object') + throw TypeError('.onnx.TypeProto.Optional.elemType: object expected'); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.elemType = null; + if (message.elemType != null && message.hasOwnProperty('elemType')) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Optional'; + }; - /** - * TensorAnnotation tensorName. - * @member {string} tensorName - * @memberof onnx.TensorAnnotation - * @instance - */ - TensorAnnotation.prototype.tensorName = ""; - - /** - * TensorAnnotation quantParameterTensorNames. - * @member {Array.} quantParameterTensorNames - * @memberof onnx.TensorAnnotation - * @instance - */ - TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; - - /** - * Creates a new TensorAnnotation instance using the specified properties. - * @function create - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.ITensorAnnotation=} [properties] Properties to set - * @returns {onnx.TensorAnnotation} TensorAnnotation instance - */ - TensorAnnotation.create = function create(properties) { - return new TensorAnnotation(properties); - }; - - /** - * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. - * @function encode - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorAnnotation.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); - if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) - for (var i = 0; i < message.quantParameterTensorNames.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TensorAnnotation message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorAnnotation - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorAnnotation} TensorAnnotation - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorAnnotation.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.tensorName = reader.string(); - break; - } - case 2: { - if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) - message.quantParameterTensorNames = []; - message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorAnnotation - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorAnnotation} TensorAnnotation - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TensorAnnotation message. - * @function verify - * @memberof onnx.TensorAnnotation - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TensorAnnotation.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.tensorName != null && message.hasOwnProperty("tensorName")) - if (!$util.isString(message.tensorName)) - return "tensorName: string expected"; - if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { - if (!Array.isArray(message.quantParameterTensorNames)) - return "quantParameterTensorNames: array expected"; - for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); - if (error) - return "quantParameterTensorNames." + error; - } - } - return null; - }; - - /** - * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.TensorAnnotation - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorAnnotation} TensorAnnotation - */ - TensorAnnotation.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorAnnotation) - return object; - var message = new $root.onnx.TensorAnnotation(); - if (object.tensorName != null) - message.tensorName = String(object.tensorName); - if (object.quantParameterTensorNames) { - if (!Array.isArray(object.quantParameterTensorNames)) - throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); - message.quantParameterTensorNames = []; - for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { - if (typeof object.quantParameterTensorNames[i] !== "object") - throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); - message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.TensorAnnotation} message TensorAnnotation - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TensorAnnotation.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) - object.quantParameterTensorNames = []; - if (options.defaults) - object.tensorName = ""; - if (message.tensorName != null && message.hasOwnProperty("tensorName")) - object.tensorName = message.tensorName; - if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { - object.quantParameterTensorNames = []; - for (var j = 0; j < message.quantParameterTensorNames.length; ++j) - object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); - } - return object; - }; - - /** - * Converts this TensorAnnotation to JSON. - * @function toJSON - * @memberof onnx.TensorAnnotation - * @instance - * @returns {Object.} JSON object - */ - TensorAnnotation.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TensorAnnotation - * @function getTypeUrl - * @memberof onnx.TensorAnnotation - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; + return Optional; + })(); + + TypeProto.SparseTensor = (function () { + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. 
+ * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, 'shape')) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; } - return typeUrlPrefix + "/onnx.TensorAnnotation"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. + * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) + if (!$util.isInteger(message.elemType)) return 'elemType: integer expected'; + if (message.shape != null && message.hasOwnProperty('shape')) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) return 'shape.' + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== 'object') throw TypeError('.onnx.TypeProto.SparseTensor.shape: object expected'); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty('elemType')) object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty('shape')) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.SparseTensor'; + }; - return TensorAnnotation; + return SparseTensor; })(); - onnx.GraphProto = (function() { - - /** - * Properties of a GraphProto. - * @memberof onnx - * @interface IGraphProto - * @property {Array.|null} [node] GraphProto node - * @property {string|null} [name] GraphProto name - * @property {Array.|null} [initializer] GraphProto initializer - * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer - * @property {string|null} [docString] GraphProto docString - * @property {Array.|null} [input] GraphProto input - * @property {Array.|null} [output] GraphProto output - * @property {Array.|null} [valueInfo] GraphProto valueInfo - * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation - */ - - /** - * Constructs a new GraphProto. - * @memberof onnx - * @classdesc Represents a GraphProto. - * @implements IGraphProto - * @constructor - * @param {onnx.IGraphProto=} [properties] Properties to set - */ - function GraphProto(properties) { - this.node = []; - this.initializer = []; - this.sparseInitializer = []; - this.input = []; - this.output = []; - this.valueInfo = []; - this.quantizationAnnotation = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + return TypeProto; + })(); - /** - * GraphProto node. - * @member {Array.} node - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.node = $util.emptyArray; - - /** - * GraphProto name. - * @member {string} name - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.name = ""; - - /** - * GraphProto initializer. - * @member {Array.} initializer - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.initializer = $util.emptyArray; - - /** - * GraphProto sparseInitializer. - * @member {Array.} sparseInitializer - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.sparseInitializer = $util.emptyArray; - - /** - * GraphProto docString. - * @member {string} docString - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.docString = ""; - - /** - * GraphProto input. - * @member {Array.} input - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.input = $util.emptyArray; - - /** - * GraphProto output. - * @member {Array.} output - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.output = $util.emptyArray; - - /** - * GraphProto valueInfo. - * @member {Array.} valueInfo - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.valueInfo = $util.emptyArray; - - /** - * GraphProto quantizationAnnotation. 
- * @member {Array.} quantizationAnnotation - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.quantizationAnnotation = $util.emptyArray; - - /** - * Creates a new GraphProto instance using the specified properties. - * @function create - * @memberof onnx.GraphProto - * @static - * @param {onnx.IGraphProto=} [properties] Properties to set - * @returns {onnx.GraphProto} GraphProto instance - */ - GraphProto.create = function create(properties) { - return new GraphProto(properties); - }; - - /** - * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. - * @function encode - * @memberof onnx.GraphProto - * @static - * @param {onnx.IGraphProto} message GraphProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - GraphProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.node != null && message.node.length) - for (var i = 0; i < message.node.length; ++i) - $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); - if (message.initializer != null && message.initializer.length) - for (var i = 0; i < message.initializer.length; ++i) - $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); - if (message.input != null && message.input.length) - for (var i = 0; i < message.input.length; ++i) - $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); - if (message.output != null && message.output.length) - for (var i = 0; i < message.output.length; ++i) - $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); - if (message.valueInfo != null && message.valueInfo.length) - for (var i = 0; i < message.valueInfo.length; ++i) - $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); - if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) - for (var i = 0; i < message.quantizationAnnotation.length; ++i) - $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); - if (message.sparseInitializer != null && message.sparseInitializer.length) - for (var i = 0; i < message.sparseInitializer.length; ++i) - $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.GraphProto - * @static - * @param {onnx.IGraphProto} message GraphProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - GraphProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a GraphProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.GraphProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.GraphProto} GraphProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - GraphProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.node && message.node.length)) - message.node = []; - message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); - break; - } - case 2: { - message.name = reader.string(); - break; - } - case 5: { - if (!(message.initializer && message.initializer.length)) - message.initializer = []; - message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); - break; - } - case 15: { - if (!(message.sparseInitializer && message.sparseInitializer.length)) - message.sparseInitializer = []; - message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); - break; - } - case 10: { - message.docString = reader.string(); - break; - } - case 11: { - if (!(message.input && message.input.length)) - message.input = []; - message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); - break; - } - case 12: { - if (!(message.output && message.output.length)) - message.output = []; - message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); - break; - } - case 13: { - if (!(message.valueInfo && message.valueInfo.length)) - message.valueInfo = []; - message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); - break; - } - case 14: { - if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) - message.quantizationAnnotation = []; - message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a GraphProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.GraphProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.GraphProto} GraphProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - GraphProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a GraphProto message. 
- * @function verify - * @memberof onnx.GraphProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - GraphProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.node != null && message.hasOwnProperty("node")) { - if (!Array.isArray(message.node)) - return "node: array expected"; - for (var i = 0; i < message.node.length; ++i) { - var error = $root.onnx.NodeProto.verify(message.node[i]); - if (error) - return "node." + error; - } - } - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.initializer != null && message.hasOwnProperty("initializer")) { - if (!Array.isArray(message.initializer)) - return "initializer: array expected"; - for (var i = 0; i < message.initializer.length; ++i) { - var error = $root.onnx.TensorProto.verify(message.initializer[i]); - if (error) - return "initializer." + error; - } - } - if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { - if (!Array.isArray(message.sparseInitializer)) - return "sparseInitializer: array expected"; - for (var i = 0; i < message.sparseInitializer.length; ++i) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); - if (error) - return "sparseInitializer." + error; - } - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.input != null && message.hasOwnProperty("input")) { - if (!Array.isArray(message.input)) - return "input: array expected"; - for (var i = 0; i < message.input.length; ++i) { - var error = $root.onnx.ValueInfoProto.verify(message.input[i]); - if (error) - return "input." + error; - } - } - if (message.output != null && message.hasOwnProperty("output")) { - if (!Array.isArray(message.output)) - return "output: array expected"; - for (var i = 0; i < message.output.length; ++i) { - var error = $root.onnx.ValueInfoProto.verify(message.output[i]); - if (error) - return "output." + error; - } - } - if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { - if (!Array.isArray(message.valueInfo)) - return "valueInfo: array expected"; - for (var i = 0; i < message.valueInfo.length; ++i) { - var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); - if (error) - return "valueInfo." + error; - } - } - if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { - if (!Array.isArray(message.quantizationAnnotation)) - return "quantizationAnnotation: array expected"; - for (var i = 0; i < message.quantizationAnnotation.length; ++i) { - var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); - if (error) - return "quantizationAnnotation." + error; - } - } - return null; - }; - - /** - * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.GraphProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.GraphProto} GraphProto - */ - GraphProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.GraphProto) - return object; - var message = new $root.onnx.GraphProto(); - if (object.node) { - if (!Array.isArray(object.node)) - throw TypeError(".onnx.GraphProto.node: array expected"); - message.node = []; - for (var i = 0; i < object.node.length; ++i) { - if (typeof object.node[i] !== "object") - throw TypeError(".onnx.GraphProto.node: object expected"); - message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); - } - } - if (object.name != null) - message.name = String(object.name); - if (object.initializer) { - if (!Array.isArray(object.initializer)) - throw TypeError(".onnx.GraphProto.initializer: array expected"); - message.initializer = []; - for (var i = 0; i < object.initializer.length; ++i) { - if (typeof object.initializer[i] !== "object") - throw TypeError(".onnx.GraphProto.initializer: object expected"); - message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); - } - } - if (object.sparseInitializer) { - if (!Array.isArray(object.sparseInitializer)) - throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); - message.sparseInitializer = []; - for (var i = 0; i < object.sparseInitializer.length; ++i) { - if (typeof object.sparseInitializer[i] !== "object") - throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); - message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); - } - } - if (object.docString != null) - message.docString = String(object.docString); - if (object.input) { - if (!Array.isArray(object.input)) - throw TypeError(".onnx.GraphProto.input: array expected"); - message.input = []; - for (var i = 0; i < object.input.length; ++i) { - if (typeof object.input[i] !== "object") - throw TypeError(".onnx.GraphProto.input: object expected"); - message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); - } - } - if (object.output) { - if (!Array.isArray(object.output)) - throw TypeError(".onnx.GraphProto.output: array expected"); - message.output = []; - for (var i = 0; i < object.output.length; ++i) { - if (typeof object.output[i] !== "object") - throw TypeError(".onnx.GraphProto.output: object expected"); - message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); - } - } - if (object.valueInfo) { - if (!Array.isArray(object.valueInfo)) - throw TypeError(".onnx.GraphProto.valueInfo: array expected"); - message.valueInfo = []; - for (var i = 0; i < object.valueInfo.length; ++i) { - if (typeof object.valueInfo[i] !== "object") - throw TypeError(".onnx.GraphProto.valueInfo: object expected"); - message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); - } - } - if (object.quantizationAnnotation) { - if (!Array.isArray(object.quantizationAnnotation)) - throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); - message.quantizationAnnotation = []; - for (var i = 0; i < object.quantizationAnnotation.length; ++i) { - if (typeof object.quantizationAnnotation[i] !== "object") - throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); - message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a GraphProto 
message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.GraphProto - * @static - * @param {onnx.GraphProto} message GraphProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - GraphProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.node = []; - object.initializer = []; - object.input = []; - object.output = []; - object.valueInfo = []; - object.quantizationAnnotation = []; - object.sparseInitializer = []; - } - if (options.defaults) { - object.name = ""; - object.docString = ""; - } - if (message.node && message.node.length) { - object.node = []; - for (var j = 0; j < message.node.length; ++j) - object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.initializer && message.initializer.length) { - object.initializer = []; - for (var j = 0; j < message.initializer.length; ++j) - object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.input && message.input.length) { - object.input = []; - for (var j = 0; j < message.input.length; ++j) - object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); - } - if (message.output && message.output.length) { - object.output = []; - for (var j = 0; j < message.output.length; ++j) - object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); - } - if (message.valueInfo && message.valueInfo.length) { - object.valueInfo = []; - for (var j = 0; j < message.valueInfo.length; ++j) - object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); - } - if (message.quantizationAnnotation && message.quantizationAnnotation.length) { - object.quantizationAnnotation = []; - for (var j = 0; j < message.quantizationAnnotation.length; ++j) - object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); - } - if (message.sparseInitializer && message.sparseInitializer.length) { - object.sparseInitializer = []; - for (var j = 0; j < message.sparseInitializer.length; ++j) - object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); - } - return object; - }; - - /** - * Converts this GraphProto to JSON. - * @function toJSON - * @memberof onnx.GraphProto - * @instance - * @returns {Object.} JSON object - */ - GraphProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for GraphProto - * @function getTypeUrl - * @memberof onnx.GraphProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.GraphProto"; - }; + onnx.OperatorSetIdProto = (function () { + /** + * Properties of an OperatorSetIdProto. 
+ * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ - return GraphProto; - })(); + /** + * Constructs a new OperatorSetIdProto. + * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. + * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } - onnx.TensorProto = (function() { - - /** - * Properties of a TensorProto. - * @memberof onnx - * @interface ITensorProto - * @property {Array.|null} [dims] TensorProto dims - * @property {number|null} [dataType] TensorProto dataType - * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment - * @property {Array.|null} [floatData] TensorProto floatData - * @property {Array.|null} [int32Data] TensorProto int32Data - * @property {Array.|null} [stringData] TensorProto stringData - * @property {Array.|null} [int64Data] TensorProto int64Data - * @property {string|null} [name] TensorProto name - * @property {string|null} [docString] TensorProto docString - * @property {Uint8Array|null} [rawData] TensorProto rawData - * @property {Array.|null} [externalData] TensorProto externalData - * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation - * @property {Array.|null} [doubleData] TensorProto doubleData - * @property {Array.|null} [uint64Data] TensorProto uint64Data - */ - - /** - * Constructs a new TensorProto. - * @memberof onnx - * @classdesc Represents a TensorProto. - * @implements ITensorProto - * @constructor - * @param {onnx.ITensorProto=} [properties] Properties to set - */ - function TensorProto(properties) { - this.dims = []; - this.floatData = []; - this.int32Data = []; - this.stringData = []; - this.int64Data = []; - this.externalData = []; - this.doubleData = []; - this.uint64Data = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ''; - /** - * TensorProto dims. - * @member {Array.} dims - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.dims = $util.emptyArray; - - /** - * TensorProto dataType. - * @member {number} dataType - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.dataType = 0; - - /** - * TensorProto segment. - * @member {onnx.TensorProto.ISegment|null|undefined} segment - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.segment = null; - - /** - * TensorProto floatData. - * @member {Array.} floatData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.floatData = $util.emptyArray; - - /** - * TensorProto int32Data. - * @member {Array.} int32Data - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.int32Data = $util.emptyArray; - - /** - * TensorProto stringData. - * @member {Array.} stringData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.stringData = $util.emptyArray; - - /** - * TensorProto int64Data. 
- * @member {Array.} int64Data - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.int64Data = $util.emptyArray; - - /** - * TensorProto name. - * @member {string} name - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.name = ""; - - /** - * TensorProto docString. - * @member {string} docString - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.docString = ""; - - /** - * TensorProto rawData. - * @member {Uint8Array} rawData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.rawData = $util.newBuffer([]); - - /** - * TensorProto externalData. - * @member {Array.} externalData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.externalData = $util.emptyArray; - - /** - * TensorProto dataLocation. - * @member {onnx.TensorProto.DataLocation} dataLocation - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.dataLocation = 0; - - /** - * TensorProto doubleData. - * @member {Array.} doubleData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.doubleData = $util.emptyArray; - - /** - * TensorProto uint64Data. - * @member {Array.} uint64Data - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.uint64Data = $util.emptyArray; - - /** - * Creates a new TensorProto instance using the specified properties. - * @function create - * @memberof onnx.TensorProto - * @static - * @param {onnx.ITensorProto=} [properties] Properties to set - * @returns {onnx.TensorProto} TensorProto instance - */ - TensorProto.create = function create(properties) { - return new TensorProto(properties); - }; - - /** - * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.TensorProto - * @static - * @param {onnx.ITensorProto} message TensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.dims != null && message.dims.length) { - writer.uint32(/* id 1, wireType 2 =*/10).fork(); - for (var i = 0; i < message.dims.length; ++i) - writer.int64(message.dims[i]); - writer.ldelim(); - } - if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) - writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); - if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) - $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); - if (message.floatData != null && message.floatData.length) { - writer.uint32(/* id 4, wireType 2 =*/34).fork(); - for (var i = 0; i < message.floatData.length; ++i) - writer.float(message.floatData[i]); - writer.ldelim(); - } - if (message.int32Data != null && message.int32Data.length) { - writer.uint32(/* id 5, wireType 2 =*/42).fork(); - for (var i = 0; i < message.int32Data.length; ++i) - writer.int32(message.int32Data[i]); - writer.ldelim(); - } - if (message.stringData != null && message.stringData.length) - for (var i = 0; i < message.stringData.length; ++i) - writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); - if (message.int64Data != null && message.int64Data.length) { - writer.uint32(/* id 7, wireType 2 =*/58).fork(); - for (var i = 0; i < message.int64Data.length; ++i) - writer.int64(message.int64Data[i]); - writer.ldelim(); - } - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); - if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) - writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); - if (message.doubleData != null && message.doubleData.length) { - writer.uint32(/* id 10, wireType 2 =*/82).fork(); - for (var i = 0; i < message.doubleData.length; ++i) - writer.double(message.doubleData[i]); - writer.ldelim(); - } - if (message.uint64Data != null && message.uint64Data.length) { - writer.uint32(/* id 11, wireType 2 =*/90).fork(); - for (var i = 0; i < message.uint64Data.length; ++i) - writer.uint64(message.uint64Data[i]); - writer.ldelim(); - } - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); - if (message.externalData != null && message.externalData.length) - for (var i = 0; i < message.externalData.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); - if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) - writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); - return writer; - }; - - /** - * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.TensorProto - * @static - * @param {onnx.ITensorProto} message TensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TensorProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorProto} TensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.dims && message.dims.length)) - message.dims = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.dims.push(reader.int64()); - } else - message.dims.push(reader.int64()); - break; - } - case 2: { - message.dataType = reader.int32(); - break; - } - case 3: { - message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); - break; - } - case 4: { - if (!(message.floatData && message.floatData.length)) - message.floatData = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.floatData.push(reader.float()); - } else - message.floatData.push(reader.float()); - break; - } - case 5: { - if (!(message.int32Data && message.int32Data.length)) - message.int32Data = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.int32Data.push(reader.int32()); - } else - message.int32Data.push(reader.int32()); - break; - } - case 6: { - if (!(message.stringData && message.stringData.length)) - message.stringData = []; - message.stringData.push(reader.bytes()); - break; - } - case 7: { - if (!(message.int64Data && message.int64Data.length)) - message.int64Data = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.int64Data.push(reader.int64()); - } else - message.int64Data.push(reader.int64()); - break; - } - case 8: { - message.name = reader.string(); - break; - } - case 12: { - message.docString = reader.string(); - break; - } - case 9: { - message.rawData = reader.bytes(); - break; - } - case 13: { - if (!(message.externalData && message.externalData.length)) - message.externalData = []; - message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - case 14: { - message.dataLocation = reader.int32(); - break; - } - case 10: { - if (!(message.doubleData && message.doubleData.length)) - message.doubleData = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.doubleData.push(reader.double()); - } else - message.doubleData.push(reader.double()); - break; - } - case 11: { - if (!(message.uint64Data && message.uint64Data.length)) - message.uint64Data = []; - if ((tag & 7) === 
2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.uint64Data.push(reader.uint64()); - } else - message.uint64Data.push(reader.uint64()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TensorProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorProto} TensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TensorProto message. - * @function verify - * @memberof onnx.TensorProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TensorProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.dims != null && message.hasOwnProperty("dims")) { - if (!Array.isArray(message.dims)) - return "dims: array expected"; - for (var i = 0; i < message.dims.length; ++i) - if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) - return "dims: integer|Long[] expected"; - } - if (message.dataType != null && message.hasOwnProperty("dataType")) - if (!$util.isInteger(message.dataType)) - return "dataType: integer expected"; - if (message.segment != null && message.hasOwnProperty("segment")) { - var error = $root.onnx.TensorProto.Segment.verify(message.segment); - if (error) - return "segment." 
+ error; - } - if (message.floatData != null && message.hasOwnProperty("floatData")) { - if (!Array.isArray(message.floatData)) - return "floatData: array expected"; - for (var i = 0; i < message.floatData.length; ++i) - if (typeof message.floatData[i] !== "number") - return "floatData: number[] expected"; - } - if (message.int32Data != null && message.hasOwnProperty("int32Data")) { - if (!Array.isArray(message.int32Data)) - return "int32Data: array expected"; - for (var i = 0; i < message.int32Data.length; ++i) - if (!$util.isInteger(message.int32Data[i])) - return "int32Data: integer[] expected"; - } - if (message.stringData != null && message.hasOwnProperty("stringData")) { - if (!Array.isArray(message.stringData)) - return "stringData: array expected"; - for (var i = 0; i < message.stringData.length; ++i) - if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) - return "stringData: buffer[] expected"; - } - if (message.int64Data != null && message.hasOwnProperty("int64Data")) { - if (!Array.isArray(message.int64Data)) - return "int64Data: array expected"; - for (var i = 0; i < message.int64Data.length; ++i) - if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) - return "int64Data: integer|Long[] expected"; - } - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.rawData != null && message.hasOwnProperty("rawData")) - if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) - return "rawData: buffer expected"; - if (message.externalData != null && message.hasOwnProperty("externalData")) { - if (!Array.isArray(message.externalData)) - return "externalData: array expected"; - for (var i = 0; i < message.externalData.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); - if (error) - return "externalData." + error; - } - } - if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) - switch (message.dataLocation) { - default: - return "dataLocation: enum value expected"; - case 0: - case 1: - break; - } - if (message.doubleData != null && message.hasOwnProperty("doubleData")) { - if (!Array.isArray(message.doubleData)) - return "doubleData: array expected"; - for (var i = 0; i < message.doubleData.length; ++i) - if (typeof message.doubleData[i] !== "number") - return "doubleData: number[] expected"; - } - if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { - if (!Array.isArray(message.uint64Data)) - return "uint64Data: array expected"; - for (var i = 0; i < message.uint64Data.length; ++i) - if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) - return "uint64Data: integer|Long[] expected"; - } - return null; - }; - - /** - * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.TensorProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorProto} TensorProto - */ - TensorProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorProto) - return object; - var message = new $root.onnx.TensorProto(); - if (object.dims) { - if (!Array.isArray(object.dims)) - throw TypeError(".onnx.TensorProto.dims: array expected"); - message.dims = []; - for (var i = 0; i < object.dims.length; ++i) - if ($util.Long) - (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; - else if (typeof object.dims[i] === "string") - message.dims[i] = parseInt(object.dims[i], 10); - else if (typeof object.dims[i] === "number") - message.dims[i] = object.dims[i]; - else if (typeof object.dims[i] === "object") - message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); - } - if (object.dataType != null) - message.dataType = object.dataType | 0; - if (object.segment != null) { - if (typeof object.segment !== "object") - throw TypeError(".onnx.TensorProto.segment: object expected"); - message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); - } - if (object.floatData) { - if (!Array.isArray(object.floatData)) - throw TypeError(".onnx.TensorProto.floatData: array expected"); - message.floatData = []; - for (var i = 0; i < object.floatData.length; ++i) - message.floatData[i] = Number(object.floatData[i]); - } - if (object.int32Data) { - if (!Array.isArray(object.int32Data)) - throw TypeError(".onnx.TensorProto.int32Data: array expected"); - message.int32Data = []; - for (var i = 0; i < object.int32Data.length; ++i) - message.int32Data[i] = object.int32Data[i] | 0; - } - if (object.stringData) { - if (!Array.isArray(object.stringData)) - throw TypeError(".onnx.TensorProto.stringData: array expected"); - message.stringData = []; - for (var i = 0; i < object.stringData.length; ++i) - if (typeof object.stringData[i] === "string") - $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); - else if (object.stringData[i].length >= 0) - message.stringData[i] = object.stringData[i]; - } - if (object.int64Data) { - if (!Array.isArray(object.int64Data)) - throw TypeError(".onnx.TensorProto.int64Data: array expected"); - message.int64Data = []; - for (var i = 0; i < object.int64Data.length; ++i) - if ($util.Long) - (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; - else if (typeof object.int64Data[i] === "string") - message.int64Data[i] = parseInt(object.int64Data[i], 10); - else if (typeof object.int64Data[i] === "number") - message.int64Data[i] = object.int64Data[i]; - else if (typeof object.int64Data[i] === "object") - message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); - } - if (object.name != null) - message.name = String(object.name); - if (object.docString != null) - message.docString = String(object.docString); - if (object.rawData != null) - if (typeof object.rawData === "string") - $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); - else if (object.rawData.length >= 0) - message.rawData = object.rawData; - if (object.externalData) { - if (!Array.isArray(object.externalData)) - throw TypeError(".onnx.TensorProto.externalData: array expected"); - message.externalData = []; - for (var i = 0; 
i < object.externalData.length; ++i) { - if (typeof object.externalData[i] !== "object") - throw TypeError(".onnx.TensorProto.externalData: object expected"); - message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); - } - } - switch (object.dataLocation) { - default: - if (typeof object.dataLocation === "number") { - message.dataLocation = object.dataLocation; - break; - } - break; - case "DEFAULT": - case 0: - message.dataLocation = 0; - break; - case "EXTERNAL": - case 1: - message.dataLocation = 1; - break; - } - if (object.doubleData) { - if (!Array.isArray(object.doubleData)) - throw TypeError(".onnx.TensorProto.doubleData: array expected"); - message.doubleData = []; - for (var i = 0; i < object.doubleData.length; ++i) - message.doubleData[i] = Number(object.doubleData[i]); - } - if (object.uint64Data) { - if (!Array.isArray(object.uint64Data)) - throw TypeError(".onnx.TensorProto.uint64Data: array expected"); - message.uint64Data = []; - for (var i = 0; i < object.uint64Data.length; ++i) - if ($util.Long) - (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; - else if (typeof object.uint64Data[i] === "string") - message.uint64Data[i] = parseInt(object.uint64Data[i], 10); - else if (typeof object.uint64Data[i] === "number") - message.uint64Data[i] = object.uint64Data[i]; - else if (typeof object.uint64Data[i] === "object") - message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); - } - return message; - }; - - /** - * Creates a plain object from a TensorProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorProto - * @static - * @param {onnx.TensorProto} message TensorProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TensorProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.dims = []; - object.floatData = []; - object.int32Data = []; - object.stringData = []; - object.int64Data = []; - object.doubleData = []; - object.uint64Data = []; - object.externalData = []; - } - if (options.defaults) { - object.dataType = 0; - object.segment = null; - object.name = ""; - if (options.bytes === String) - object.rawData = ""; - else { - object.rawData = []; - if (options.bytes !== Array) - object.rawData = $util.newBuffer(object.rawData); - } - object.docString = ""; - object.dataLocation = options.enums === String ? "DEFAULT" : 0; - } - if (message.dims && message.dims.length) { - object.dims = []; - for (var j = 0; j < message.dims.length; ++j) - if (typeof message.dims[j] === "number") - object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; - else - object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? 
new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; - } - if (message.dataType != null && message.hasOwnProperty("dataType")) - object.dataType = message.dataType; - if (message.segment != null && message.hasOwnProperty("segment")) - object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); - if (message.floatData && message.floatData.length) { - object.floatData = []; - for (var j = 0; j < message.floatData.length; ++j) - object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; - } - if (message.int32Data && message.int32Data.length) { - object.int32Data = []; - for (var j = 0; j < message.int32Data.length; ++j) - object.int32Data[j] = message.int32Data[j]; - } - if (message.stringData && message.stringData.length) { - object.stringData = []; - for (var j = 0; j < message.stringData.length; ++j) - object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; - } - if (message.int64Data && message.int64Data.length) { - object.int64Data = []; - for (var j = 0; j < message.int64Data.length; ++j) - if (typeof message.int64Data[j] === "number") - object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; - else - object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.rawData != null && message.hasOwnProperty("rawData")) - object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; - if (message.doubleData && message.doubleData.length) { - object.doubleData = []; - for (var j = 0; j < message.doubleData.length; ++j) - object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; - } - if (message.uint64Data && message.uint64Data.length) { - object.uint64Data = []; - for (var j = 0; j < message.uint64Data.length; ++j) - if (typeof message.uint64Data[j] === "number") - object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; - else - object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.externalData && message.externalData.length) { - object.externalData = []; - for (var j = 0; j < message.externalData.length; ++j) - object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); - } - if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) - object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? 
message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; - return object; - }; - - /** - * Converts this TensorProto to JSON. - * @function toJSON - * @memberof onnx.TensorProto - * @instance - * @returns {Object.} JSON object - */ - TensorProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TensorProto - * @function getTypeUrl - * @memberof onnx.TensorProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorProto"; - }; - - /** - * DataType enum. - * @name onnx.TensorProto.DataType - * @enum {number} - * @property {number} UNDEFINED=0 UNDEFINED value - * @property {number} FLOAT=1 FLOAT value - * @property {number} UINT8=2 UINT8 value - * @property {number} INT8=3 INT8 value - * @property {number} UINT16=4 UINT16 value - * @property {number} INT16=5 INT16 value - * @property {number} INT32=6 INT32 value - * @property {number} INT64=7 INT64 value - * @property {number} STRING=8 STRING value - * @property {number} BOOL=9 BOOL value - * @property {number} FLOAT16=10 FLOAT16 value - * @property {number} DOUBLE=11 DOUBLE value - * @property {number} UINT32=12 UINT32 value - * @property {number} UINT64=13 UINT64 value - * @property {number} COMPLEX64=14 COMPLEX64 value - * @property {number} COMPLEX128=15 COMPLEX128 value - * @property {number} BFLOAT16=16 BFLOAT16 value - * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value - * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value - * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value - * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value - */ - TensorProto.DataType = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "UNDEFINED"] = 0; - values[valuesById[1] = "FLOAT"] = 1; - values[valuesById[2] = "UINT8"] = 2; - values[valuesById[3] = "INT8"] = 3; - values[valuesById[4] = "UINT16"] = 4; - values[valuesById[5] = "INT16"] = 5; - values[valuesById[6] = "INT32"] = 6; - values[valuesById[7] = "INT64"] = 7; - values[valuesById[8] = "STRING"] = 8; - values[valuesById[9] = "BOOL"] = 9; - values[valuesById[10] = "FLOAT16"] = 10; - values[valuesById[11] = "DOUBLE"] = 11; - values[valuesById[12] = "UINT32"] = 12; - values[valuesById[13] = "UINT64"] = 13; - values[valuesById[14] = "COMPLEX64"] = 14; - values[valuesById[15] = "COMPLEX128"] = 15; - values[valuesById[16] = "BFLOAT16"] = 16; - values[valuesById[17] = "FLOAT8E4M3FN"] = 17; - values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; - values[valuesById[19] = "FLOAT8E5M2"] = 19; - values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; - return values; - })(); - - TensorProto.Segment = (function() { - - /** - * Properties of a Segment. - * @memberof onnx.TensorProto - * @interface ISegment - * @property {number|Long|null} [begin] Segment begin - * @property {number|Long|null} [end] Segment end - */ - - /** - * Constructs a new Segment. - * @memberof onnx.TensorProto - * @classdesc Represents a Segment. 
- * @implements ISegment - * @constructor - * @param {onnx.TensorProto.ISegment=} [properties] Properties to set - */ - function Segment(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; - /** - * Segment begin. - * @member {number|Long} begin - * @memberof onnx.TensorProto.Segment - * @instance - */ - Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * Segment end. - * @member {number|Long} end - * @memberof onnx.TensorProto.Segment - * @instance - */ - Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * Creates a new Segment instance using the specified properties. - * @function create - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.ISegment=} [properties] Properties to set - * @returns {onnx.TensorProto.Segment} Segment instance - */ - Segment.create = function create(properties) { - return new Segment(properties); - }; - - /** - * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. - * @function encode - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Segment.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) - writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); - if (message.end != null && Object.hasOwnProperty.call(message, "end")) - writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); - return writer; - }; - - /** - * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Segment.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Segment message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorProto.Segment - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorProto.Segment} Segment - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Segment.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.begin = reader.int64(); - break; - } - case 2: { - message.end = reader.int64(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Segment message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorProto.Segment - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorProto.Segment} Segment - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Segment.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Segment message. - * @function verify - * @memberof onnx.TensorProto.Segment - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Segment.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.begin != null && message.hasOwnProperty("begin")) - if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) - return "begin: integer|Long expected"; - if (message.end != null && message.hasOwnProperty("end")) - if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) - return "end: integer|Long expected"; - return null; - }; - - /** - * Creates a Segment message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TensorProto.Segment - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorProto.Segment} Segment - */ - Segment.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorProto.Segment) - return object; - var message = new $root.onnx.TensorProto.Segment(); - if (object.begin != null) - if ($util.Long) - (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; - else if (typeof object.begin === "string") - message.begin = parseInt(object.begin, 10); - else if (typeof object.begin === "number") - message.begin = object.begin; - else if (typeof object.begin === "object") - message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); - if (object.end != null) - if ($util.Long) - (message.end = $util.Long.fromValue(object.end)).unsigned = false; - else if (typeof object.end === "string") - message.end = parseInt(object.end, 10); - else if (typeof object.end === "number") - message.end = object.end; - else if (typeof object.end === "object") - message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); - return message; - }; - - /** - * Creates a plain object from a Segment message. Also converts values to other types if specified. 
- * @function toObject - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.Segment} message Segment - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Segment.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.begin = options.longs === String ? "0" : 0; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.end = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.end = options.longs === String ? "0" : 0; - } - if (message.begin != null && message.hasOwnProperty("begin")) - if (typeof message.begin === "number") - object.begin = options.longs === String ? String(message.begin) : message.begin; - else - object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; - if (message.end != null && message.hasOwnProperty("end")) - if (typeof message.end === "number") - object.end = options.longs === String ? String(message.end) : message.end; - else - object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; - return object; - }; - - /** - * Converts this Segment to JSON. - * @function toJSON - * @memberof onnx.TensorProto.Segment - * @instance - * @returns {Object.} JSON object - */ - Segment.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Segment - * @function getTypeUrl - * @memberof onnx.TensorProto.Segment - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorProto.Segment"; - }; - - return Segment; - })(); - - /** - * DataLocation enum. - * @name onnx.TensorProto.DataLocation - * @enum {number} - * @property {number} DEFAULT=0 DEFAULT value - * @property {number} EXTERNAL=1 EXTERNAL value - */ - TensorProto.DataLocation = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "DEFAULT"] = 0; - values[valuesById[1] = "EXTERNAL"] = 1; - return values; - })(); - - return TensorProto; - })(); + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, 'version')) + writer.uint32(/* id 2, wireType 0 =*/ 16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; - onnx.SparseTensorProto = (function() { - - /** - * Properties of a SparseTensorProto. - * @memberof onnx - * @interface ISparseTensorProto - * @property {onnx.ITensorProto|null} [values] SparseTensorProto values - * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices - * @property {Array.|null} [dims] SparseTensorProto dims - */ - - /** - * Constructs a new SparseTensorProto. - * @memberof onnx - * @classdesc Represents a SparseTensorProto. - * @implements ISparseTensorProto - * @constructor - * @param {onnx.ISparseTensorProto=} [properties] Properties to set - */ - function SparseTensorProto(properties) { - this.dims = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * SparseTensorProto values. - * @member {onnx.ITensorProto|null|undefined} values - * @memberof onnx.SparseTensorProto - * @instance - */ - SparseTensorProto.prototype.values = null; - - /** - * SparseTensorProto indices. - * @member {onnx.ITensorProto|null|undefined} indices - * @memberof onnx.SparseTensorProto - * @instance - */ - SparseTensorProto.prototype.indices = null; - - /** - * SparseTensorProto dims. 
- * @member {Array.} dims - * @memberof onnx.SparseTensorProto - * @instance - */ - SparseTensorProto.prototype.dims = $util.emptyArray; - - /** - * Creates a new SparseTensorProto instance using the specified properties. - * @function create - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.ISparseTensorProto=} [properties] Properties to set - * @returns {onnx.SparseTensorProto} SparseTensorProto instance - */ - SparseTensorProto.create = function create(properties) { - return new SparseTensorProto(properties); - }; - - /** - * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. - * @function encode - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensorProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.values != null && Object.hasOwnProperty.call(message, "values")) - $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) - $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - if (message.dims != null && message.dims.length) { - writer.uint32(/* id 3, wireType 2 =*/26).fork(); - for (var i = 0; i < message.dims.length; ++i) - writer.int64(message.dims[i]); - writer.ldelim(); - } - return writer; - }; - - /** - * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a SparseTensorProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.SparseTensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.SparseTensorProto} SparseTensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensorProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); - break; - } - case 2: { - message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); - break; - } - case 3: { - if (!(message.dims && message.dims.length)) - message.dims = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.dims.push(reader.int64()); - } else - message.dims.push(reader.int64()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.SparseTensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.SparseTensorProto} SparseTensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a SparseTensorProto message. - * @function verify - * @memberof onnx.SparseTensorProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - SparseTensorProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.values != null && message.hasOwnProperty("values")) { - var error = $root.onnx.TensorProto.verify(message.values); - if (error) - return "values." + error; - } - if (message.indices != null && message.hasOwnProperty("indices")) { - var error = $root.onnx.TensorProto.verify(message.indices); - if (error) - return "indices." + error; - } - if (message.dims != null && message.hasOwnProperty("dims")) { - if (!Array.isArray(message.dims)) - return "dims: array expected"; - for (var i = 0; i < message.dims.length; ++i) - if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) - return "dims: integer|Long[] expected"; - } - return null; - }; - - /** - * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.SparseTensorProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.SparseTensorProto} SparseTensorProto - */ - SparseTensorProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.SparseTensorProto) - return object; - var message = new $root.onnx.SparseTensorProto(); - if (object.values != null) { - if (typeof object.values !== "object") - throw TypeError(".onnx.SparseTensorProto.values: object expected"); - message.values = $root.onnx.TensorProto.fromObject(object.values); - } - if (object.indices != null) { - if (typeof object.indices !== "object") - throw TypeError(".onnx.SparseTensorProto.indices: object expected"); - message.indices = $root.onnx.TensorProto.fromObject(object.indices); - } - if (object.dims) { - if (!Array.isArray(object.dims)) - throw TypeError(".onnx.SparseTensorProto.dims: array expected"); - message.dims = []; - for (var i = 0; i < object.dims.length; ++i) - if ($util.Long) - (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; - else if (typeof object.dims[i] === "string") - message.dims[i] = parseInt(object.dims[i], 10); - else if (typeof object.dims[i] === "number") - message.dims[i] = object.dims[i]; - else if (typeof object.dims[i] === "object") - message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); - } - return message; - }; - - /** - * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.SparseTensorProto} message SparseTensorProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - SparseTensorProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) - object.dims = []; - if (options.defaults) { - object.values = null; - object.indices = null; - } - if (message.values != null && message.hasOwnProperty("values")) - object.values = $root.onnx.TensorProto.toObject(message.values, options); - if (message.indices != null && message.hasOwnProperty("indices")) - object.indices = $root.onnx.TensorProto.toObject(message.indices, options); - if (message.dims && message.dims.length) { - object.dims = []; - for (var j = 0; j < message.dims.length; ++j) - if (typeof message.dims[j] === "number") - object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; - else - object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; - } - return object; - }; - - /** - * Converts this SparseTensorProto to JSON. 
- * @function toJSON - * @memberof onnx.SparseTensorProto - * @instance - * @returns {Object.} JSON object - */ - SparseTensorProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for SparseTensorProto - * @function getTypeUrl - * @memberof onnx.SparseTensorProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.SparseTensorProto"; - }; + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return SparseTensorProto; - })(); + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + if (message.version != null && message.hasOwnProperty('version')) + if ( + !$util.isInteger(message.version) && + !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high)) + ) + return 'version: integer|Long expected'; + return null; + }; - onnx.TensorShapeProto = (function() { - - /** - * Properties of a TensorShapeProto. - * @memberof onnx - * @interface ITensorShapeProto - * @property {Array.|null} [dim] TensorShapeProto dim - */ - - /** - * Constructs a new TensorShapeProto. - * @memberof onnx - * @classdesc Represents a TensorShapeProto. - * @implements ITensorShapeProto - * @constructor - * @param {onnx.ITensorShapeProto=} [properties] Properties to set - */ - function TensorShapeProto(properties) { - this.dim = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === 'string') message.version = parseInt(object.version, 10); + else if (typeof object.version === 'number') message.version = object.version; + else if (typeof object.version === 'object') + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; - /** - * TensorShapeProto dim. - * @member {Array.} dim - * @memberof onnx.TensorShapeProto - * @instance - */ - TensorShapeProto.prototype.dim = $util.emptyArray; - - /** - * Creates a new TensorShapeProto instance using the specified properties. - * @function create - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.ITensorShapeProto=} [properties] Properties to set - * @returns {onnx.TensorShapeProto} TensorShapeProto instance - */ - TensorShapeProto.create = function create(properties) { - return new TensorShapeProto(properties); - }; - - /** - * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. - * @function encode - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorShapeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.dim != null && message.dim.length) - for (var i = 0; i < message.dim.length; ++i) - $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TensorShapeProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorShapeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorShapeProto} TensorShapeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorShapeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.dim && message.dim.length)) - message.dim = []; - message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorShapeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorShapeProto} TensorShapeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TensorShapeProto message. - * @function verify - * @memberof onnx.TensorShapeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TensorShapeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.dim != null && message.hasOwnProperty("dim")) { - if (!Array.isArray(message.dim)) - return "dim: array expected"; - for (var i = 0; i < message.dim.length; ++i) { - var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); - if (error) - return "dim." + error; - } - } - return null; - }; - - /** - * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TensorShapeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorShapeProto} TensorShapeProto - */ - TensorShapeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorShapeProto) - return object; - var message = new $root.onnx.TensorShapeProto(); - if (object.dim) { - if (!Array.isArray(object.dim)) - throw TypeError(".onnx.TensorShapeProto.dim: array expected"); - message.dim = []; - for (var i = 0; i < object.dim.length; ++i) { - if (typeof object.dim[i] !== "object") - throw TypeError(".onnx.TensorShapeProto.dim: object expected"); - message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.TensorShapeProto} message TensorShapeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TensorShapeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) - object.dim = []; - if (message.dim && message.dim.length) { - object.dim = []; - for (var j = 0; j < message.dim.length; ++j) - object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); - } - return object; - }; - - /** - * Converts this TensorShapeProto to JSON. 
- * @function toJSON - * @memberof onnx.TensorShapeProto - * @instance - * @returns {Object.} JSON object - */ - TensorShapeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TensorShapeProto - * @function getTypeUrl - * @memberof onnx.TensorShapeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorShapeProto"; - }; - - TensorShapeProto.Dimension = (function() { - - /** - * Properties of a Dimension. - * @memberof onnx.TensorShapeProto - * @interface IDimension - * @property {number|Long|null} [dimValue] Dimension dimValue - * @property {string|null} [dimParam] Dimension dimParam - * @property {string|null} [denotation] Dimension denotation - */ - - /** - * Constructs a new Dimension. - * @memberof onnx.TensorShapeProto - * @classdesc Represents a Dimension. - * @implements IDimension - * @constructor - * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set - */ - function Dimension(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.domain = ''; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.version = options.longs === String ? '0' : 0; + } + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.version != null && message.hasOwnProperty('version')) + if (typeof message.version === 'number') + object.version = options.longs === String ? String(message.version) : message.version; + else + object.version = + options.longs === String + ? $util.Long.prototype.toString.call(message.version) + : options.longs === Number + ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() + : message.version; + return object; + }; - /** - * Dimension dimValue. - * @member {number|Long|null|undefined} dimValue - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Dimension.prototype.dimValue = null; - - /** - * Dimension dimParam. - * @member {string|null|undefined} dimParam - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Dimension.prototype.dimParam = null; - - /** - * Dimension denotation. - * @member {string} denotation - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Dimension.prototype.denotation = ""; - - // OneOf field names bound to virtual getters and setters - var $oneOfFields; - - /** - * Dimension value. 
- * @member {"dimValue"|"dimParam"|undefined} value - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Object.defineProperty(Dimension.prototype, "value", { - get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), - set: $util.oneOfSetter($oneOfFields) - }); - - /** - * Creates a new Dimension instance using the specified properties. - * @function create - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set - * @returns {onnx.TensorShapeProto.Dimension} Dimension instance - */ - Dimension.create = function create(properties) { - return new Dimension(properties); - }; - - /** - * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. - * @function encode - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Dimension.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) - writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); - if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); - if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); - return writer; - }; - - /** - * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Dimension.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Dimension message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorShapeProto.Dimension} Dimension - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Dimension.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.dimValue = reader.int64(); - break; - } - case 2: { - message.dimParam = reader.string(); - break; - } - case 3: { - message.denotation = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Dimension message from the specified reader or buffer, length delimited. 
- * @function decodeDelimited - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorShapeProto.Dimension} Dimension - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Dimension.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Dimension message. - * @function verify - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Dimension.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - var properties = {}; - if (message.dimValue != null && message.hasOwnProperty("dimValue")) { - properties.value = 1; - if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) - return "dimValue: integer|Long expected"; - } - if (message.dimParam != null && message.hasOwnProperty("dimParam")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - if (!$util.isString(message.dimParam)) - return "dimParam: string expected"; - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - if (!$util.isString(message.denotation)) - return "denotation: string expected"; - return null; - }; - - /** - * Creates a Dimension message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorShapeProto.Dimension} Dimension - */ - Dimension.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorShapeProto.Dimension) - return object; - var message = new $root.onnx.TensorShapeProto.Dimension(); - if (object.dimValue != null) - if ($util.Long) - (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; - else if (typeof object.dimValue === "string") - message.dimValue = parseInt(object.dimValue, 10); - else if (typeof object.dimValue === "number") - message.dimValue = object.dimValue; - else if (typeof object.dimValue === "object") - message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); - if (object.dimParam != null) - message.dimParam = String(object.dimParam); - if (object.denotation != null) - message.denotation = String(object.denotation); - return message; - }; - - /** - * Creates a plain object from a Dimension message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.Dimension} message Dimension - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Dimension.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.denotation = ""; - if (message.dimValue != null && message.hasOwnProperty("dimValue")) { - if (typeof message.dimValue === "number") - object.dimValue = options.longs === String ? 
String(message.dimValue) : message.dimValue; - else - object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; - if (options.oneofs) - object.value = "dimValue"; - } - if (message.dimParam != null && message.hasOwnProperty("dimParam")) { - object.dimParam = message.dimParam; - if (options.oneofs) - object.value = "dimParam"; - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - object.denotation = message.denotation; - return object; - }; - - /** - * Converts this Dimension to JSON. - * @function toJSON - * @memberof onnx.TensorShapeProto.Dimension - * @instance - * @returns {Object.} JSON object - */ - Dimension.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Dimension - * @function getTypeUrl - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; - }; - - return Dimension; - })(); - - return TensorShapeProto; - })(); + /** + * Converts this OperatorSetIdProto to JSON. + * @function toJSON + * @memberof onnx.OperatorSetIdProto + * @instance + * @returns {Object.} JSON object + */ + OperatorSetIdProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; - onnx.TypeProto = (function() { - - /** - * Properties of a TypeProto. - * @memberof onnx - * @interface ITypeProto - * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType - * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType - * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType - * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType - * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType - * @property {string|null} [denotation] TypeProto denotation - */ - - /** - * Constructs a new TypeProto. - * @memberof onnx - * @classdesc Represents a TypeProto. - * @implements ITypeProto - * @constructor - * @param {onnx.ITypeProto=} [properties] Properties to set - */ - function TypeProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Gets the default type url for OperatorSetIdProto + * @function getTypeUrl + * @memberof onnx.OperatorSetIdProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.OperatorSetIdProto'; + }; + + return OperatorSetIdProto; + })(); + + /** + * OperatorStatus enum. 
+ * @name onnx.OperatorStatus + * @enum {number} + * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value + * @property {number} STABLE=1 STABLE value + */ + onnx.OperatorStatus = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = 'EXPERIMENTAL')] = 0; + values[(valuesById[1] = 'STABLE')] = 1; + return values; + })(); + + onnx.FunctionProto = (function () { + /** + * Properties of a FunctionProto. + * @memberof onnx + * @interface IFunctionProto + * @property {string|null} [name] FunctionProto name + * @property {Array.|null} [input] FunctionProto input + * @property {Array.|null} [output] FunctionProto output + * @property {Array.|null} [attribute] FunctionProto attribute + * @property {Array.|null} [attributeProto] FunctionProto attributeProto + * @property {Array.|null} [node] FunctionProto node + * @property {string|null} [docString] FunctionProto docString + * @property {Array.|null} [opsetImport] FunctionProto opsetImport + * @property {string|null} [domain] FunctionProto domain + */ - /** - * TypeProto tensorType. - * @member {onnx.TypeProto.ITensor|null|undefined} tensorType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.tensorType = null; - - /** - * TypeProto sequenceType. - * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.sequenceType = null; - - /** - * TypeProto mapType. - * @member {onnx.TypeProto.IMap|null|undefined} mapType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.mapType = null; - - /** - * TypeProto optionalType. - * @member {onnx.TypeProto.IOptional|null|undefined} optionalType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.optionalType = null; - - /** - * TypeProto sparseTensorType. - * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.sparseTensorType = null; - - /** - * TypeProto denotation. - * @member {string} denotation - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.denotation = ""; - - // OneOf field names bound to virtual getters and setters - var $oneOfFields; - - /** - * TypeProto value. - * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value - * @memberof onnx.TypeProto - * @instance - */ - Object.defineProperty(TypeProto.prototype, "value", { - get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), - set: $util.oneOfSetter($oneOfFields) - }); - - /** - * Creates a new TypeProto instance using the specified properties. - * @function create - * @memberof onnx.TypeProto - * @static - * @param {onnx.ITypeProto=} [properties] Properties to set - * @returns {onnx.TypeProto} TypeProto instance - */ - TypeProto.create = function create(properties) { - return new TypeProto(properties); - }; - - /** - * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.TypeProto - * @static - * @param {onnx.ITypeProto} message TypeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TypeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) - $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) - $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); - if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) - $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); - if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, "sparseTensorType")) - $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); - if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) - $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto - * @static - * @param {onnx.ITypeProto} message TypeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TypeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TypeProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto} TypeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TypeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); - break; - } - case 4: { - message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); - break; - } - case 5: { - message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); - break; - } - case 9: { - message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); - break; - } - case 8: { - message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); - break; - } - case 6: { - message.denotation = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TypeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto} TypeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TypeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TypeProto message. - * @function verify - * @memberof onnx.TypeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TypeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - var properties = {}; - if (message.tensorType != null && message.hasOwnProperty("tensorType")) { - properties.value = 1; - { - var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); - if (error) - return "tensorType." + error; - } - } - if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); - if (error) - return "sequenceType." + error; - } - } - if (message.mapType != null && message.hasOwnProperty("mapType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.Map.verify(message.mapType); - if (error) - return "mapType." + error; - } - } - if (message.optionalType != null && message.hasOwnProperty("optionalType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); - if (error) - return "optionalType." + error; - } - } - if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); - if (error) - return "sparseTensorType." + error; - } - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - if (!$util.isString(message.denotation)) - return "denotation: string expected"; - return null; - }; - - /** - * Creates a TypeProto message from a plain object. 
Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto} TypeProto - */ - TypeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto) - return object; - var message = new $root.onnx.TypeProto(); - if (object.tensorType != null) { - if (typeof object.tensorType !== "object") - throw TypeError(".onnx.TypeProto.tensorType: object expected"); - message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); - } - if (object.sequenceType != null) { - if (typeof object.sequenceType !== "object") - throw TypeError(".onnx.TypeProto.sequenceType: object expected"); - message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); - } - if (object.mapType != null) { - if (typeof object.mapType !== "object") - throw TypeError(".onnx.TypeProto.mapType: object expected"); - message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); - } - if (object.optionalType != null) { - if (typeof object.optionalType !== "object") - throw TypeError(".onnx.TypeProto.optionalType: object expected"); - message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); - } - if (object.sparseTensorType != null) { - if (typeof object.sparseTensorType !== "object") - throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); - message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); - } - if (object.denotation != null) - message.denotation = String(object.denotation); - return message; - }; - - /** - * Creates a plain object from a TypeProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto - * @static - * @param {onnx.TypeProto} message TypeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TypeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.denotation = ""; - if (message.tensorType != null && message.hasOwnProperty("tensorType")) { - object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); - if (options.oneofs) - object.value = "tensorType"; - } - if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { - object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); - if (options.oneofs) - object.value = "sequenceType"; - } - if (message.mapType != null && message.hasOwnProperty("mapType")) { - object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); - if (options.oneofs) - object.value = "mapType"; - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - object.denotation = message.denotation; - if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { - object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); - if (options.oneofs) - object.value = "sparseTensorType"; - } - if (message.optionalType != null && message.hasOwnProperty("optionalType")) { - object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); - if (options.oneofs) - object.value = "optionalType"; - } - return object; - }; - - /** - * Converts this TypeProto to JSON. 
- * @function toJSON - * @memberof onnx.TypeProto - * @instance - * @returns {Object.} JSON object - */ - TypeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TypeProto - * @function getTypeUrl - * @memberof onnx.TypeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto"; - }; - - TypeProto.Tensor = (function() { - - /** - * Properties of a Tensor. - * @memberof onnx.TypeProto - * @interface ITensor - * @property {number|null} [elemType] Tensor elemType - * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape - */ - - /** - * Constructs a new Tensor. - * @memberof onnx.TypeProto - * @classdesc Represents a Tensor. - * @implements ITensor - * @constructor - * @param {onnx.TypeProto.ITensor=} [properties] Properties to set - */ - function Tensor(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Constructs a new FunctionProto. + * @memberof onnx + * @classdesc Represents a FunctionProto. + * @implements IFunctionProto + * @constructor + * @param {onnx.IFunctionProto=} [properties] Properties to set + */ + function FunctionProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + this.attributeProto = []; + this.node = []; + this.opsetImport = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } - /** - * Tensor elemType. - * @member {number} elemType - * @memberof onnx.TypeProto.Tensor - * @instance - */ - Tensor.prototype.elemType = 0; - - /** - * Tensor shape. - * @member {onnx.ITensorShapeProto|null|undefined} shape - * @memberof onnx.TypeProto.Tensor - * @instance - */ - Tensor.prototype.shape = null; - - /** - * Creates a new Tensor instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.ITensor=} [properties] Properties to set - * @returns {onnx.TypeProto.Tensor} Tensor instance - */ - Tensor.create = function create(properties) { - return new Tensor(properties); - }; - - /** - * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Tensor.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); - if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) - $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Tensor message, length delimited. 
Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Tensor.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Tensor message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.Tensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Tensor} Tensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Tensor.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = reader.int32(); - break; - } - case 2: { - message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Tensor message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.Tensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Tensor} Tensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Tensor.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Tensor message. - * @function verify - * @memberof onnx.TypeProto.Tensor - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Tensor.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) - if (!$util.isInteger(message.elemType)) - return "elemType: integer expected"; - if (message.shape != null && message.hasOwnProperty("shape")) { - var error = $root.onnx.TensorShapeProto.verify(message.shape); - if (error) - return "shape." + error; - } - return null; - }; - - /** - * Creates a Tensor message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.TypeProto.Tensor - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Tensor} Tensor - */ - Tensor.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Tensor) - return object; - var message = new $root.onnx.TypeProto.Tensor(); - if (object.elemType != null) - message.elemType = object.elemType | 0; - if (object.shape != null) { - if (typeof object.shape !== "object") - throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); - message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); - } - return message; - }; - - /** - * Creates a plain object from a Tensor message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.Tensor} message Tensor - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Tensor.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.elemType = 0; - object.shape = null; - } - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = message.elemType; - if (message.shape != null && message.hasOwnProperty("shape")) - object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); - return object; - }; - - /** - * Converts this Tensor to JSON. - * @function toJSON - * @memberof onnx.TypeProto.Tensor - * @instance - * @returns {Object.} JSON object - */ - Tensor.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Tensor - * @function getTypeUrl - * @memberof onnx.TypeProto.Tensor - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Tensor"; - }; - - return Tensor; - })(); - - TypeProto.Sequence = (function() { - - /** - * Properties of a Sequence. - * @memberof onnx.TypeProto - * @interface ISequence - * @property {onnx.ITypeProto|null} [elemType] Sequence elemType - */ - - /** - * Constructs a new Sequence. - * @memberof onnx.TypeProto - * @classdesc Represents a Sequence. - * @implements ISequence - * @constructor - * @param {onnx.TypeProto.ISequence=} [properties] Properties to set - */ - function Sequence(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto name. + * @member {string} name + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.name = ''; - /** - * Sequence elemType. - * @member {onnx.ITypeProto|null|undefined} elemType - * @memberof onnx.TypeProto.Sequence - * @instance - */ - Sequence.prototype.elemType = null; - - /** - * Creates a new Sequence instance using the specified properties. 
- * @function create - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.ISequence=} [properties] Properties to set - * @returns {onnx.TypeProto.Sequence} Sequence instance - */ - Sequence.create = function create(properties) { - return new Sequence(properties); - }; - - /** - * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Sequence.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Sequence.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Sequence message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.Sequence - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Sequence} Sequence - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Sequence.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Sequence message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.Sequence - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Sequence} Sequence - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Sequence.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Sequence message. 
- * @function verify - * @memberof onnx.TypeProto.Sequence - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Sequence.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) { - var error = $root.onnx.TypeProto.verify(message.elemType); - if (error) - return "elemType." + error; - } - return null; - }; - - /** - * Creates a Sequence message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.Sequence - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Sequence} Sequence - */ - Sequence.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Sequence) - return object; - var message = new $root.onnx.TypeProto.Sequence(); - if (object.elemType != null) { - if (typeof object.elemType !== "object") - throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); - message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); - } - return message; - }; - - /** - * Creates a plain object from a Sequence message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.Sequence} message Sequence - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Sequence.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.elemType = null; - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); - return object; - }; - - /** - * Converts this Sequence to JSON. - * @function toJSON - * @memberof onnx.TypeProto.Sequence - * @instance - * @returns {Object.} JSON object - */ - Sequence.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Sequence - * @function getTypeUrl - * @memberof onnx.TypeProto.Sequence - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Sequence"; - }; - - return Sequence; - })(); - - TypeProto.Map = (function() { - - /** - * Properties of a Map. - * @memberof onnx.TypeProto - * @interface IMap - * @property {number|null} [keyType] Map keyType - * @property {onnx.ITypeProto|null} [valueType] Map valueType - */ - - /** - * Constructs a new Map. - * @memberof onnx.TypeProto - * @classdesc Represents a Map. - * @implements IMap - * @constructor - * @param {onnx.TypeProto.IMap=} [properties] Properties to set - */ - function Map(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto input. + * @member {Array.} input + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.input = $util.emptyArray; - /** - * Map keyType. 
- * @member {number} keyType - * @memberof onnx.TypeProto.Map - * @instance - */ - Map.prototype.keyType = 0; - - /** - * Map valueType. - * @member {onnx.ITypeProto|null|undefined} valueType - * @memberof onnx.TypeProto.Map - * @instance - */ - Map.prototype.valueType = null; - - /** - * Creates a new Map instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.IMap=} [properties] Properties to set - * @returns {onnx.TypeProto.Map} Map instance - */ - Map.create = function create(properties) { - return new Map(properties); - }; - - /** - * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.IMap} message Map message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Map.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) - writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); - if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) - $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.IMap} message Map message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Map.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Map message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.Map - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Map} Map - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Map.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.keyType = reader.int32(); - break; - } - case 2: { - message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Map message from the specified reader or buffer, length delimited. 
- * @function decodeDelimited - * @memberof onnx.TypeProto.Map - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Map} Map - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Map.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Map message. - * @function verify - * @memberof onnx.TypeProto.Map - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Map.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.keyType != null && message.hasOwnProperty("keyType")) - if (!$util.isInteger(message.keyType)) - return "keyType: integer expected"; - if (message.valueType != null && message.hasOwnProperty("valueType")) { - var error = $root.onnx.TypeProto.verify(message.valueType); - if (error) - return "valueType." + error; - } - return null; - }; - - /** - * Creates a Map message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.Map - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Map} Map - */ - Map.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Map) - return object; - var message = new $root.onnx.TypeProto.Map(); - if (object.keyType != null) - message.keyType = object.keyType | 0; - if (object.valueType != null) { - if (typeof object.valueType !== "object") - throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); - message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); - } - return message; - }; - - /** - * Creates a plain object from a Map message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.Map} message Map - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Map.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.keyType = 0; - object.valueType = null; - } - if (message.keyType != null && message.hasOwnProperty("keyType")) - object.keyType = message.keyType; - if (message.valueType != null && message.hasOwnProperty("valueType")) - object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); - return object; - }; - - /** - * Converts this Map to JSON. 
- * @function toJSON - * @memberof onnx.TypeProto.Map - * @instance - * @returns {Object.} JSON object - */ - Map.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Map - * @function getTypeUrl - * @memberof onnx.TypeProto.Map - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Map"; - }; - - return Map; - })(); - - TypeProto.Optional = (function() { - - /** - * Properties of an Optional. - * @memberof onnx.TypeProto - * @interface IOptional - * @property {onnx.ITypeProto|null} [elemType] Optional elemType - */ - - /** - * Constructs a new Optional. - * @memberof onnx.TypeProto - * @classdesc Represents an Optional. - * @implements IOptional - * @constructor - * @param {onnx.TypeProto.IOptional=} [properties] Properties to set - */ - function Optional(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto output. + * @member {Array.} output + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.output = $util.emptyArray; - /** - * Optional elemType. - * @member {onnx.ITypeProto|null|undefined} elemType - * @memberof onnx.TypeProto.Optional - * @instance - */ - Optional.prototype.elemType = null; - - /** - * Creates a new Optional instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.IOptional=} [properties] Properties to set - * @returns {onnx.TypeProto.Optional} Optional instance - */ - Optional.create = function create(properties) { - return new Optional(properties); - }; - - /** - * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Optional.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Optional.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes an Optional message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.TypeProto.Optional - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Optional} Optional - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Optional.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes an Optional message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.Optional - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Optional} Optional - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Optional.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies an Optional message. - * @function verify - * @memberof onnx.TypeProto.Optional - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Optional.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) { - var error = $root.onnx.TypeProto.verify(message.elemType); - if (error) - return "elemType." + error; - } - return null; - }; - - /** - * Creates an Optional message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.Optional - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Optional} Optional - */ - Optional.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Optional) - return object; - var message = new $root.onnx.TypeProto.Optional(); - if (object.elemType != null) { - if (typeof object.elemType !== "object") - throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); - message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); - } - return message; - }; - - /** - * Creates a plain object from an Optional message. Also converts values to other types if specified. 
- * @function toObject - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.Optional} message Optional - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Optional.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.elemType = null; - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); - return object; - }; - - /** - * Converts this Optional to JSON. - * @function toJSON - * @memberof onnx.TypeProto.Optional - * @instance - * @returns {Object.} JSON object - */ - Optional.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Optional - * @function getTypeUrl - * @memberof onnx.TypeProto.Optional - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Optional"; - }; - - return Optional; - })(); - - TypeProto.SparseTensor = (function() { - - /** - * Properties of a SparseTensor. - * @memberof onnx.TypeProto - * @interface ISparseTensor - * @property {number|null} [elemType] SparseTensor elemType - * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape - */ - - /** - * Constructs a new SparseTensor. - * @memberof onnx.TypeProto - * @classdesc Represents a SparseTensor. - * @implements ISparseTensor - * @constructor - * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set - */ - function SparseTensor(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto attribute. + * @member {Array.} attribute + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attribute = $util.emptyArray; - /** - * SparseTensor elemType. - * @member {number} elemType - * @memberof onnx.TypeProto.SparseTensor - * @instance - */ - SparseTensor.prototype.elemType = 0; - - /** - * SparseTensor shape. - * @member {onnx.ITensorShapeProto|null|undefined} shape - * @memberof onnx.TypeProto.SparseTensor - * @instance - */ - SparseTensor.prototype.shape = null; - - /** - * Creates a new SparseTensor instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set - * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance - */ - SparseTensor.create = function create(properties) { - return new SparseTensor(properties); - }; - - /** - * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. 
- * @function encode - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensor.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); - if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) - $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a SparseTensor message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.SparseTensor} SparseTensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensor.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = reader.int32(); - break; - } - case 2: { - message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a SparseTensor message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.SparseTensor} SparseTensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensor.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a SparseTensor message. 
- * @function verify - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - SparseTensor.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) - if (!$util.isInteger(message.elemType)) - return "elemType: integer expected"; - if (message.shape != null && message.hasOwnProperty("shape")) { - var error = $root.onnx.TensorShapeProto.verify(message.shape); - if (error) - return "shape." + error; - } - return null; - }; - - /** - * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.SparseTensor} SparseTensor - */ - SparseTensor.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.SparseTensor) - return object; - var message = new $root.onnx.TypeProto.SparseTensor(); - if (object.elemType != null) - message.elemType = object.elemType | 0; - if (object.shape != null) { - if (typeof object.shape !== "object") - throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); - message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); - } - return message; - }; - - /** - * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.SparseTensor} message SparseTensor - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - SparseTensor.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.elemType = 0; - object.shape = null; - } - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = message.elemType; - if (message.shape != null && message.hasOwnProperty("shape")) - object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); - return object; - }; - - /** - * Converts this SparseTensor to JSON. - * @function toJSON - * @memberof onnx.TypeProto.SparseTensor - * @instance - * @returns {Object.} JSON object - */ - SparseTensor.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for SparseTensor - * @function getTypeUrl - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; - }; - - return SparseTensor; - })(); - - return TypeProto; - })(); + /** + * FunctionProto attributeProto. + * @member {Array.} attributeProto + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attributeProto = $util.emptyArray; - onnx.OperatorSetIdProto = (function() { - - /** - * Properties of an OperatorSetIdProto. 
- * @memberof onnx - * @interface IOperatorSetIdProto - * @property {string|null} [domain] OperatorSetIdProto domain - * @property {number|Long|null} [version] OperatorSetIdProto version - */ - - /** - * Constructs a new OperatorSetIdProto. - * @memberof onnx - * @classdesc Represents an OperatorSetIdProto. - * @implements IOperatorSetIdProto - * @constructor - * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set - */ - function OperatorSetIdProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto node. + * @member {Array.} node + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.node = $util.emptyArray; - /** - * OperatorSetIdProto domain. - * @member {string} domain - * @memberof onnx.OperatorSetIdProto - * @instance - */ - OperatorSetIdProto.prototype.domain = ""; - - /** - * OperatorSetIdProto version. - * @member {number|Long} version - * @memberof onnx.OperatorSetIdProto - * @instance - */ - OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * Creates a new OperatorSetIdProto instance using the specified properties. - * @function create - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance - */ - OperatorSetIdProto.create = function create(properties) { - return new OperatorSetIdProto(properties); - }; - - /** - * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. - * @function encode - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - OperatorSetIdProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); - if (message.version != null && Object.hasOwnProperty.call(message, "version")) - writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); - return writer; - }; - - /** - * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes an OperatorSetIdProto message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.OperatorSetIdProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - OperatorSetIdProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.domain = reader.string(); - break; - } - case 2: { - message.version = reader.int64(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.OperatorSetIdProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies an OperatorSetIdProto message. - * @function verify - * @memberof onnx.OperatorSetIdProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - OperatorSetIdProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - if (message.version != null && message.hasOwnProperty("version")) - if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) - return "version: integer|Long expected"; - return null; - }; - - /** - * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.OperatorSetIdProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto - */ - OperatorSetIdProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.OperatorSetIdProto) - return object; - var message = new $root.onnx.OperatorSetIdProto(); - if (object.domain != null) - message.domain = String(object.domain); - if (object.version != null) - if ($util.Long) - (message.version = $util.Long.fromValue(object.version)).unsigned = false; - else if (typeof object.version === "string") - message.version = parseInt(object.version, 10); - else if (typeof object.version === "number") - message.version = object.version; - else if (typeof object.version === "object") - message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); - return message; - }; - - /** - * Creates a plain object from an OperatorSetIdProto message. 
Also converts values to other types if specified. - * @function toObject - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - OperatorSetIdProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.domain = ""; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.version = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.version = options.longs === String ? "0" : 0; - } - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - if (message.version != null && message.hasOwnProperty("version")) - if (typeof message.version === "number") - object.version = options.longs === String ? String(message.version) : message.version; - else - object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version; - return object; - }; - - /** - * Converts this OperatorSetIdProto to JSON. - * @function toJSON - * @memberof onnx.OperatorSetIdProto - * @instance - * @returns {Object.} JSON object - */ - OperatorSetIdProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for OperatorSetIdProto - * @function getTypeUrl - * @memberof onnx.OperatorSetIdProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.OperatorSetIdProto"; - }; + /** + * FunctionProto docString. + * @member {string} docString + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.docString = ''; - return OperatorSetIdProto; - })(); + /** + * FunctionProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.opsetImport = $util.emptyArray; /** - * OperatorStatus enum. - * @name onnx.OperatorStatus - * @enum {number} - * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value - * @property {number} STABLE=1 STABLE value - */ - onnx.OperatorStatus = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "EXPERIMENTAL"] = 0; - values[valuesById[1] = "STABLE"] = 1; - return values; - })(); + * FunctionProto domain. + * @member {string} domain + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.domain = ''; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @function create + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto=} [properties] Properties to set + * @returns {onnx.FunctionProto} FunctionProto instance + */ + FunctionProto.create = function create(properties) { + return new FunctionProto(properties); + }; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 5, wireType 2 =*/ 42).string(message.output[i]); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.attribute[i]); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/ 58).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 8, wireType 2 =*/ 66).string(message.docString); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode( + message.opsetImport[i], + writer.uint32(/* id 9, wireType 2 =*/ 74).fork(), + ).ldelim(); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 10, wireType 2 =*/ 82).string(message.domain); + if (message.attributeProto != null && message.attributeProto.length) + for (var i = 0; i < message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode( + message.attributeProto[i], + writer.uint32(/* id 11, wireType 2 =*/ 90).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; - onnx.FunctionProto = (function() { - - /** - * Properties of a FunctionProto. - * @memberof onnx - * @interface IFunctionProto - * @property {string|null} [name] FunctionProto name - * @property {Array.|null} [input] FunctionProto input - * @property {Array.|null} [output] FunctionProto output - * @property {Array.|null} [attribute] FunctionProto attribute - * @property {Array.|null} [attributeProto] FunctionProto attributeProto - * @property {Array.|null} [node] FunctionProto node - * @property {string|null} [docString] FunctionProto docString - * @property {Array.|null} [opsetImport] FunctionProto opsetImport - * @property {string|null} [domain] FunctionProto domain - */ - - /** - * Constructs a new FunctionProto. - * @memberof onnx - * @classdesc Represents a FunctionProto. 
- * @implements IFunctionProto - * @constructor - * @param {onnx.IFunctionProto=} [properties] Properties to set - */ - function FunctionProto(properties) { - this.input = []; - this.output = []; - this.attribute = []; - this.attributeProto = []; - this.node = []; - this.opsetImport = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * FunctionProto name. - * @member {string} name - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.name = ""; - - /** - * FunctionProto input. - * @member {Array.} input - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.input = $util.emptyArray; - - /** - * FunctionProto output. - * @member {Array.} output - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.output = $util.emptyArray; - - /** - * FunctionProto attribute. - * @member {Array.} attribute - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.attribute = $util.emptyArray; - - /** - * FunctionProto attributeProto. - * @member {Array.} attributeProto - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.attributeProto = $util.emptyArray; - - /** - * FunctionProto node. 
- * @member {Array.} node - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.node = $util.emptyArray; - - /** - * FunctionProto docString. - * @member {string} docString - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.docString = ""; - - /** - * FunctionProto opsetImport. - * @member {Array.} opsetImport - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.opsetImport = $util.emptyArray; - - /** - * FunctionProto domain. - * @member {string} domain - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.domain = ""; - - /** - * Creates a new FunctionProto instance using the specified properties. - * @function create - * @memberof onnx.FunctionProto - * @static - * @param {onnx.IFunctionProto=} [properties] Properties to set - * @returns {onnx.FunctionProto} FunctionProto instance - */ - FunctionProto.create = function create(properties) { - return new FunctionProto(properties); - }; - - /** - * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. - * @function encode - * @memberof onnx.FunctionProto - * @static - * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - FunctionProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); - if (message.input != null && message.input.length) - for (var i = 0; i < message.input.length; ++i) - writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]); - if (message.output != null && message.output.length) - for (var i = 0; i < message.output.length; ++i) - writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]); - if (message.attribute != null && message.attribute.length) - for (var i = 0; i < message.attribute.length; ++i) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]); - if (message.node != null && message.node.length) - for (var i = 0; i < message.node.length; ++i) - $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString); - if (message.opsetImport != null && message.opsetImport.length) - for (var i = 0; i < message.opsetImport.length; ++i) - $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain); - if (message.attributeProto != null && message.attributeProto.length) - for (var i = 0; i < message.attributeProto.length; ++i) - $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.FunctionProto - * @static - * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a FunctionProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.FunctionProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.FunctionProto} FunctionProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - FunctionProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.name = reader.string(); - break; - } - case 4: { - if (!(message.input && message.input.length)) - message.input = []; - message.input.push(reader.string()); - break; - } - case 5: { - if (!(message.output && message.output.length)) - message.output = []; - message.output.push(reader.string()); - break; - } - case 6: { - if (!(message.attribute && message.attribute.length)) - message.attribute = []; - message.attribute.push(reader.string()); - break; - } - case 11: { - if (!(message.attributeProto && message.attributeProto.length)) - message.attributeProto = []; - message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); - break; - } - case 7: { - if (!(message.node && message.node.length)) - message.node = []; - message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); - break; - } - case 8: { - message.docString = reader.string(); - break; - } - case 9: { - if (!(message.opsetImport && message.opsetImport.length)) - message.opsetImport = []; - message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); - break; - } - case 10: { - message.domain = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a FunctionProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.FunctionProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.FunctionProto} FunctionProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - FunctionProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a FunctionProto message. 
- * @function verify - * @memberof onnx.FunctionProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - FunctionProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.input != null && message.hasOwnProperty("input")) { - if (!Array.isArray(message.input)) - return "input: array expected"; - for (var i = 0; i < message.input.length; ++i) - if (!$util.isString(message.input[i])) - return "input: string[] expected"; - } - if (message.output != null && message.hasOwnProperty("output")) { - if (!Array.isArray(message.output)) - return "output: array expected"; - for (var i = 0; i < message.output.length; ++i) - if (!$util.isString(message.output[i])) - return "output: string[] expected"; - } - if (message.attribute != null && message.hasOwnProperty("attribute")) { - if (!Array.isArray(message.attribute)) - return "attribute: array expected"; - for (var i = 0; i < message.attribute.length; ++i) - if (!$util.isString(message.attribute[i])) - return "attribute: string[] expected"; - } - if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { - if (!Array.isArray(message.attributeProto)) - return "attributeProto: array expected"; - for (var i = 0; i < message.attributeProto.length; ++i) { - var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); - if (error) - return "attributeProto." + error; - } - } - if (message.node != null && message.hasOwnProperty("node")) { - if (!Array.isArray(message.node)) - return "node: array expected"; - for (var i = 0; i < message.node.length; ++i) { - var error = $root.onnx.NodeProto.verify(message.node[i]); - if (error) - return "node." + error; - } - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { - if (!Array.isArray(message.opsetImport)) - return "opsetImport: array expected"; - for (var i = 0; i < message.opsetImport.length; ++i) { - var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); - if (error) - return "opsetImport." + error; - } - } - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - return null; - }; - - /** - * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.FunctionProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.FunctionProto} FunctionProto - */ - FunctionProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.FunctionProto) - return object; - var message = new $root.onnx.FunctionProto(); - if (object.name != null) - message.name = String(object.name); - if (object.input) { - if (!Array.isArray(object.input)) - throw TypeError(".onnx.FunctionProto.input: array expected"); - message.input = []; - for (var i = 0; i < object.input.length; ++i) - message.input[i] = String(object.input[i]); - } - if (object.output) { - if (!Array.isArray(object.output)) - throw TypeError(".onnx.FunctionProto.output: array expected"); - message.output = []; - for (var i = 0; i < object.output.length; ++i) - message.output[i] = String(object.output[i]); - } - if (object.attribute) { - if (!Array.isArray(object.attribute)) - throw TypeError(".onnx.FunctionProto.attribute: array expected"); - message.attribute = []; - for (var i = 0; i < object.attribute.length; ++i) - message.attribute[i] = String(object.attribute[i]); - } - if (object.attributeProto) { - if (!Array.isArray(object.attributeProto)) - throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); - message.attributeProto = []; - for (var i = 0; i < object.attributeProto.length; ++i) { - if (typeof object.attributeProto[i] !== "object") - throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); - message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); - } - } - if (object.node) { - if (!Array.isArray(object.node)) - throw TypeError(".onnx.FunctionProto.node: array expected"); - message.node = []; - for (var i = 0; i < object.node.length; ++i) { - if (typeof object.node[i] !== "object") - throw TypeError(".onnx.FunctionProto.node: object expected"); - message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); - } - } - if (object.docString != null) - message.docString = String(object.docString); - if (object.opsetImport) { - if (!Array.isArray(object.opsetImport)) - throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); - message.opsetImport = []; - for (var i = 0; i < object.opsetImport.length; ++i) { - if (typeof object.opsetImport[i] !== "object") - throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); - message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); - } - } - if (object.domain != null) - message.domain = String(object.domain); - return message; - }; - - /** - * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. 
- * @function toObject - * @memberof onnx.FunctionProto - * @static - * @param {onnx.FunctionProto} message FunctionProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - FunctionProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.input = []; - object.output = []; - object.attribute = []; - object.node = []; - object.opsetImport = []; - object.attributeProto = []; - } - if (options.defaults) { - object.name = ""; - object.docString = ""; - object.domain = ""; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.input && message.input.length) { - object.input = []; - for (var j = 0; j < message.input.length; ++j) - object.input[j] = message.input[j]; - } - if (message.output && message.output.length) { - object.output = []; - for (var j = 0; j < message.output.length; ++j) - object.output[j] = message.output[j]; - } - if (message.attribute && message.attribute.length) { - object.attribute = []; - for (var j = 0; j < message.attribute.length; ++j) - object.attribute[j] = message.attribute[j]; - } - if (message.node && message.node.length) { - object.node = []; - for (var j = 0; j < message.node.length; ++j) - object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.opsetImport && message.opsetImport.length) { - object.opsetImport = []; - for (var j = 0; j < message.opsetImport.length; ++j) - object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); - } - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - if (message.attributeProto && message.attributeProto.length) { - object.attributeProto = []; - for (var j = 0; j < message.attributeProto.length; ++j) - object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); - } - return object; - }; - - /** - * Converts this FunctionProto to JSON. - * @function toJSON - * @memberof onnx.FunctionProto - * @instance - * @returns {Object.} JSON object - */ - FunctionProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for FunctionProto - * @function getTypeUrl - * @memberof onnx.FunctionProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.FunctionProto"; - }; + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return FunctionProto; - })(); + /** + * Verifies a FunctionProto message. + * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.input != null && message.hasOwnProperty('input')) { + if (!Array.isArray(message.input)) return 'input: array expected'; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) return 'input: string[] expected'; + } + if (message.output != null && message.hasOwnProperty('output')) { + if (!Array.isArray(message.output)) return 'output: array expected'; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) return 'output: string[] expected'; + } + if (message.attribute != null && message.hasOwnProperty('attribute')) { + if (!Array.isArray(message.attribute)) return 'attribute: array expected'; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) return 'attribute: string[] expected'; + } + if (message.attributeProto != null && message.hasOwnProperty('attributeProto')) { + if (!Array.isArray(message.attributeProto)) return 'attributeProto: array expected'; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) return 'attributeProto.' + error; + } + } + if (message.node != null && message.hasOwnProperty('node')) { + if (!Array.isArray(message.node)) return 'node: array expected'; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) return 'node.' + error; + } + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.opsetImport != null && message.hasOwnProperty('opsetImport')) { + if (!Array.isArray(message.opsetImport)) return 'opsetImport: array expected'; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) return 'opsetImport.' + error; + } + } + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) throw TypeError('.onnx.FunctionProto.input: array expected'); + message.input = []; + for (var i = 0; i < object.input.length; ++i) message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) throw TypeError('.onnx.FunctionProto.output: array expected'); + message.output = []; + for (var i = 0; i < object.output.length; ++i) message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) throw TypeError('.onnx.FunctionProto.attribute: array expected'); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError('.onnx.FunctionProto.attributeProto: array expected'); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== 'object') + throw TypeError('.onnx.FunctionProto.attributeProto: object expected'); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) throw TypeError('.onnx.FunctionProto.node: array expected'); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== 'object') throw TypeError('.onnx.FunctionProto.node: object expected'); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.docString != null) message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) throw TypeError('.onnx.FunctionProto.opsetImport: array expected'); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== 'object') + throw TypeError('.onnx.FunctionProto.opsetImport: object expected'); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ''; + object.docString = ''; + object.domain = ''; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) object.attribute[j] = message.attribute[j]; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. + * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.FunctionProto'; + }; + + return FunctionProto; + })(); - return onnx; + return onnx; })(); module.exports = $root; diff --git a/js/node/test/test-main.ts b/js/node/test/test-main.ts index 35b5d0006fca9..fc792179d3373 100644 --- a/js/node/test/test-main.ts +++ b/js/node/test/test-main.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {NODE_TESTS_ROOT, warmup} from './test-utils'; +import { NODE_TESTS_ROOT, warmup } from './test-utils'; // require onnxruntime-node. 
require('..'); @@ -22,7 +22,7 @@ require('./e2e/simple-e2e-tests'); require('./e2e/inference-session-run'); // Test ONNX spec tests -import {run as runTestRunner} from './test-runner'; +import { run as runTestRunner } from './test-runner'; describe('ONNX spec tests', () => { runTestRunner(NODE_TESTS_ROOT); }); diff --git a/js/node/test/test-runner.ts b/js/node/test/test-runner.ts index 06ed0acfca36c..160fa17e80f5f 100644 --- a/js/node/test/test-runner.ts +++ b/js/node/test/test-runner.ts @@ -2,10 +2,10 @@ // Licensed under the MIT License. import * as fs from 'fs-extra'; -import {InferenceSession, Tensor} from 'onnxruntime-common'; +import { InferenceSession, Tensor } from 'onnxruntime-common'; import * as path from 'path'; -import {assertTensorEqual, atol, loadTensorFromFile, rtol, shouldSkipModel} from './test-utils'; +import { assertTensorEqual, atol, loadTensorFromFile, rtol, shouldSkipModel } from './test-utils'; export function run(testDataRoot: string): void { const opsets = fs.readdirSync(testDataRoot); @@ -19,7 +19,7 @@ export function run(testDataRoot: string): void { // read each model folders const modelFolder = path.join(testDataFolder, model); let modelPath: string; - const modelTestCases: Array<[Array, Array]> = []; + const modelTestCases: Array<[Array, Array]> = []; for (const currentFile of fs.readdirSync(modelFolder)) { const currentPath = path.join(modelFolder, currentFile); const stat = fs.lstatSync(currentPath); @@ -29,14 +29,14 @@ export function run(testDataRoot: string): void { modelPath = currentPath; } } else if (stat.isDirectory()) { - const inputs: Array = []; - const outputs: Array = []; + const inputs: Array = []; + const outputs: Array = []; for (const dataFile of fs.readdirSync(currentPath)) { const dataFileFullPath = path.join(currentPath, dataFile); const ext = path.extname(dataFile); if (ext.toLowerCase() === '.pb') { - let tensor: Tensor|undefined; + let tensor: Tensor | undefined; try { tensor = loadTensorFromFile(dataFileFullPath); } catch (e) { @@ -56,7 +56,7 @@ export function run(testDataRoot: string): void { // add cases describe(`${opset}/${model}`, () => { - let session: InferenceSession|null = null; + let session: InferenceSession | null = null; let skipModel = shouldSkipModel(model, opset, ['cpu']); if (!skipModel) { before(async () => { @@ -68,8 +68,10 @@ export function run(testDataRoot: string): void { // fails. Since this is by design such a failure is acceptable in the context of this test. Therefore we // simply skip this test. Setting env variable ALLOW_RELEASED_ONNX_OPSET_ONLY=0 allows loading a model // with opset > released onnx opset. - if (process.env.ALLOW_RELEASED_ONNX_OPSET_ONLY !== '0' && - e.message.includes('ValidateOpsetForDomain')) { + if ( + process.env.ALLOW_RELEASED_ONNX_OPSET_ONLY !== '0' && + e.message.includes('ValidateOpsetForDomain') + ) { session = null; console.log(`Skipping ${model}. 
To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY=0`); skipModel = true; @@ -86,7 +88,7 @@ export function run(testDataRoot: string): void { const testCase = modelTestCases[i]; const inputs = testCase[0]; const expectedOutputs = testCase[1]; - if (!skipModel && !inputs.some(t => t === undefined) && !expectedOutputs.some(t => t === undefined)) { + if (!skipModel && !inputs.some((t) => t === undefined) && !expectedOutputs.some((t) => t === undefined)) { it(`case${i}`, async () => { if (skipModel) { return; diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index 3eef90356a335..72ed2c3db2b6e 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -3,8 +3,8 @@ import assert from 'assert'; import * as fs from 'fs-extra'; -import {jsonc} from 'jsonc'; -import {InferenceSession, Tensor} from 'onnxruntime-common'; +import { jsonc } from 'jsonc'; +import { InferenceSession, Tensor } from 'onnxruntime-common'; import * as path from 'path'; import * as onnx_proto from './ort-schema/protobuf/onnx'; @@ -18,12 +18,15 @@ export const NODE_TESTS_ROOT = path.join(ORT_ROOT, 'js/test/data/node'); export const SQUEEZENET_INPUT0_DATA: number[] = require(path.join(TEST_DATA_ROOT, 'squeezenet.input0.json')); export const SQUEEZENET_OUTPUT0_DATA: number[] = require(path.join(TEST_DATA_ROOT, 'squeezenet.output0.json')); -const BACKEND_TEST_SERIES_FILTERS: {[name: string]: Array} = - jsonc.readSync(path.join(ORT_ROOT, 'onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc')); +const BACKEND_TEST_SERIES_FILTERS: { [name: string]: Array } = jsonc.readSync( + path.join(ORT_ROOT, 'onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc'), +); const OVERRIDES: { - atol_default: number; rtol_default: number; atol_overrides: {[name: string]: number}; - rtol_overrides: {[name: string]: number}; + atol_default: number; + rtol_default: number; + atol_overrides: { [name: string]: number }; + rtol_overrides: { [name: string]: number }; } = jsonc.readSync(path.join(ORT_ROOT, 'onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc')); const ATOL_DEFAULT = OVERRIDES.atol_default; @@ -55,14 +58,14 @@ export function createTestData(type: Tensor.Type, length: number): Tensor.DataTy } else { data = new (NUMERIC_TYPE_MAP.get(type)!)(length); for (let i = 0; i < length; i++) { - data[i] = (type === 'uint64' || type === 'int64') ? BigInt(i) : i; + data[i] = type === 'uint64' || type === 'int64' ? BigInt(i) : i; } } return data; } // a simple function to create a tensor for test -export function createTestTensor(type: Tensor.Type, lengthOrDims?: number|number[]): Tensor { +export function createTestTensor(type: Tensor.Type, lengthOrDims?: number | number[]): Tensor { let length = 100; let dims = [100]; if (typeof lengthOrDims === 'number') { @@ -78,28 +81,31 @@ export function createTestTensor(type: Tensor.Type, lengthOrDims?: number|number // call the addon directly to make sure DLL is loaded export function warmup(): void { - describe('Warmup', async function() { + describe('Warmup', async function () { // eslint-disable-next-line no-invalid-this this.timeout(0); // we have test cases to verify correctness in other place, so do no check here. 
try { const session = await InferenceSession.create(path.join(TEST_DATA_ROOT, 'test_types_int32.onnx')); - await session.run({input: new Tensor(new Float32Array(5), [1, 5])}, {output: null}, {}); - } catch (e) { - } + await session.run({ input: new Tensor(new Float32Array(5), [1, 5]) }, { output: null }, {}); + } catch (e) {} }); } export function assertFloatEqual( - actual: number[]|Float32Array|Float64Array, expected: number[]|Float32Array|Float64Array, atol?: number, - rtol?: number): void { + actual: number[] | Float32Array | Float64Array, + expected: number[] | Float32Array | Float64Array, + atol?: number, + rtol?: number, +): void { const absolute_tol: number = atol ?? 1.0e-4; const relative_tol: number = 1 + (rtol ?? 1.0e-6); assert.strictEqual(actual.length, expected.length); for (let i = actual.length - 1; i >= 0; i--) { - const a = actual[i], b = expected[i]; + const a = actual[i], + b = expected[i]; if (a === b) { continue; @@ -108,7 +114,7 @@ export function assertFloatEqual( // check for NaN // if (Number.isNaN(a) && Number.isNaN(b)) { - continue; // 2 numbers are NaN, treat as equal + continue; // 2 numbers are NaN, treat as equal } if (Number.isNaN(a) || Number.isNaN(b)) { // one is NaN and the other is not @@ -124,10 +130,10 @@ export function assertFloatEqual( // endif // if (Math.abs(a - b) < absolute_tol) { - continue; // absolute error check pass + continue; // absolute error check pass } if (a !== 0 && b !== 0 && a * b > 0 && a / b < relative_tol && b / a < relative_tol) { - continue; // relative error check pass + continue; // relative error check pass } // if code goes here, it means both (abs/rel) check failed. @@ -136,13 +142,21 @@ export function assertFloatEqual( } export function assertDataEqual( - type: Tensor.Type, actual: Tensor.DataType, expected: Tensor.DataType, atol?: number, rtol?: number): void { + type: Tensor.Type, + actual: Tensor.DataType, + expected: Tensor.DataType, + atol?: number, + rtol?: number, +): void { switch (type) { case 'float32': case 'float64': assertFloatEqual( - actual as number[] | Float32Array | Float64Array, expected as number[] | Float32Array | Float64Array, atol, - rtol); + actual as number[] | Float32Array | Float64Array, + expected as number[] | Float32Array | Float64Array, + atol, + rtol, + ); break; case 'uint8': @@ -186,11 +200,15 @@ export function loadTensorFromFile(pbFile: string): Tensor { const tensorProto = onnx_proto.onnx.TensorProto.decode(fs.readFileSync(pbFile)); let transferredTypedArray: Tensor.DataType; let type: Tensor.Type; - const dims = tensorProto.dims.map((dim) => typeof dim === 'number' ? dim : dim.toNumber()); - - - if (tensorProto.dataType === 8) { // string - return new Tensor('string', tensorProto.stringData.map(i => i.toString()), dims); + const dims = tensorProto.dims.map((dim) => (typeof dim === 'number' ? 
dim : dim.toNumber())); + + if (tensorProto.dataType === 8) { + // string + return new Tensor( + 'string', + tensorProto.stringData.map((i) => i.toString()), + dims, + ); } else { switch (tensorProto.dataType) { // FLOAT = 1, @@ -253,16 +271,19 @@ export function loadTensorFromFile(pbFile: string): Tensor { default: throw new Error(`not supported tensor type: ${tensorProto.dataType}`); } - const transferredTypedArrayRawDataView = - new Uint8Array(transferredTypedArray.buffer, transferredTypedArray.byteOffset, tensorProto.rawData.byteLength); + const transferredTypedArrayRawDataView = new Uint8Array( + transferredTypedArray.buffer, + transferredTypedArray.byteOffset, + tensorProto.rawData.byteLength, + ); transferredTypedArrayRawDataView.set(tensorProto.rawData); return new Tensor(type, transferredTypedArray, dims); } } -function loadFiltersRegex(): Array<{opset?: RegExp | undefined; name: RegExp}> { - const filters: Array = ['(FLOAT16)']; +function loadFiltersRegex(): Array<{ opset?: RegExp | undefined; name: RegExp }> { + const filters: Array = ['(FLOAT16)']; filters.push(...BACKEND_TEST_SERIES_FILTERS.current_failing_tests); if (process.arch === 'ia32') { @@ -276,9 +297,11 @@ function loadFiltersRegex(): Array<{opset?: RegExp | undefined; name: RegExp}> { filters.push(...BACKEND_TEST_SERIES_FILTERS.failing_permanently_nodejs_binding); - return filters.map( - filter => typeof filter === 'string' ? {name: new RegExp(filter)} : - {opset: new RegExp(filter[0]), name: new RegExp(filter[1])}); + return filters.map((filter) => + typeof filter === 'string' + ? { name: new RegExp(filter) } + : { opset: new RegExp(filter[0]), name: new RegExp(filter[1]) }, + ); } const BACKEND_TEST_SERIES_FILTERS_REGEX = loadFiltersRegex(); diff --git a/js/node/test/unittests/lib/inference-session.ts b/js/node/test/unittests/lib/inference-session.ts index d8d961cc94398..645f62cece135 100644 --- a/js/node/test/unittests/lib/inference-session.ts +++ b/js/node/test/unittests/lib/inference-session.ts @@ -3,10 +3,10 @@ import assert from 'assert'; import * as fs from 'fs'; -import {InferenceSession, Tensor, TypedTensor} from 'onnxruntime-common'; +import { InferenceSession, Tensor, TypedTensor } from 'onnxruntime-common'; import * as path from 'path'; -import {assertTensorEqual} from '../../test-utils'; +import { assertTensorEqual } from '../../test-utils'; const SQUEEZENET_INPUT0_DATA = require(path.join(__dirname, '../../testdata/squeezenet.input0.json')); const SQUEEZENET_OUTPUT0_DATA = require(path.join(__dirname, '../../testdata/squeezenet.output0.json')); @@ -18,55 +18,85 @@ describe('UnitTests - InferenceSession.create()', () => { // #region test bad arguments it('BAD CALL - no argument', async () => { - await assert.rejects(async () => { - await createAny(); - }, {name: 'TypeError', message: /argument\[0\]/}); + await assert.rejects( + async () => { + await createAny(); + }, + { name: 'TypeError', message: /argument\[0\]/ }, + ); }); it('BAD CALL - byteOffset negative number (ArrayBuffer, number)', async () => { - await assert.rejects(async () => { - await createAny(modelBuffer.buffer, -1); - }, {name: 'RangeError', message: /'byteOffset'/}); + await assert.rejects( + async () => { + await createAny(modelBuffer.buffer, -1); + }, + { name: 'RangeError', message: /'byteOffset'/ }, + ); }); it('BAD CALL - byteOffset out of range (ArrayBuffer, number)', async () => { - await assert.rejects(async () => { - await createAny(modelBuffer.buffer, 100000000); - }, {name: 'RangeError', message: /'byteOffset'/}); + await 
assert.rejects( + async () => { + await createAny(modelBuffer.buffer, 100000000); + }, + { name: 'RangeError', message: /'byteOffset'/ }, + ); }); it('BAD CALL - byteLength negative number (ArrayBuffer, number)', async () => { - await assert.rejects(async () => { - await createAny(modelBuffer.buffer, 0, -1); - }, {name: 'RangeError', message: /'byteLength'/}); + await assert.rejects( + async () => { + await createAny(modelBuffer.buffer, 0, -1); + }, + { name: 'RangeError', message: /'byteLength'/ }, + ); }); it('BAD CALL - byteLength out of range (ArrayBuffer, number)', async () => { - await assert.rejects(async () => { - await createAny(modelBuffer.buffer, 0, 100000000); - }, {name: 'RangeError', message: /'byteLength'/}); + await assert.rejects( + async () => { + await createAny(modelBuffer.buffer, 0, 100000000); + }, + { name: 'RangeError', message: /'byteLength'/ }, + ); }); it('BAD CALL - options type mismatch (string, string)', async () => { - await assert.rejects(async () => { - await createAny(modelPath, 'cpu'); - }, {name: 'TypeError', message: /'options'/}); + await assert.rejects( + async () => { + await createAny(modelPath, 'cpu'); + }, + { name: 'TypeError', message: /'options'/ }, + ); }); it('BAD CALL - options type mismatch (Uint8Array, string)', async () => { - await assert.rejects(async () => { - await createAny(modelBuffer, 'cpu'); - }, {name: 'TypeError', message: /'options'/}); + await assert.rejects( + async () => { + await createAny(modelBuffer, 'cpu'); + }, + { name: 'TypeError', message: /'options'/ }, + ); }); it('BAD CALL - options type mismatch (ArrayBuffer, number, number, string)', async () => { - await assert.rejects(async () => { - await createAny(modelBuffer.buffer, modelBuffer.byteOffset, modelBuffer.byteLength, 'cpu'); - }, {name: 'TypeError', message: /'options'/}); + await assert.rejects( + async () => { + await createAny(modelBuffer.buffer, modelBuffer.byteOffset, modelBuffer.byteLength, 'cpu'); + }, + { name: 'TypeError', message: /'options'/ }, + ); }); it('EXPECTED FAILURE - Load model failed', async () => { - await assert.rejects(async () => { - await InferenceSession.create('/this/is/an/invalid/path.onnx'); - }, {name: 'Error', message: /failed/}); + await assert.rejects( + async () => { + await InferenceSession.create('/this/is/an/invalid/path.onnx'); + }, + { name: 'Error', message: /failed/ }, + ); }); it('EXPECTED FAILURE - empty buffer', async () => { - await assert.rejects(async () => { - await InferenceSession.create(new Uint8Array(0)); - }, {name: 'Error', message: /No graph was found in the protobuf/}); + await assert.rejects( + async () => { + await InferenceSession.create(new Uint8Array(0)); + }, + { name: 'Error', message: /No graph was found in the protobuf/ }, + ); }); // #endregion @@ -81,7 +111,7 @@ describe('UnitTests - InferenceSession.create()', () => { }); describe('UnitTests - InferenceSession.run()', () => { - let session: InferenceSession|null = null; + let session: InferenceSession | null = null; let sessionAny: any; const input0 = new Tensor('float32', SQUEEZENET_INPUT0_DATA, [1, 3, 224, 224]); const expectedOutput0 = new Tensor('float32', SQUEEZENET_OUTPUT0_DATA, [1, 1000, 1, 1]); @@ -93,50 +123,67 @@ describe('UnitTests - InferenceSession.run()', () => { // #region test bad input(feeds) it('BAD CALL - input type mismatch (null)', async () => { - await assert.rejects(async () => { - await sessionAny.run(null); - }, {name: 'TypeError', message: /'feeds'/}); + await assert.rejects( + async () => { + await 
sessionAny.run(null); + }, + { name: 'TypeError', message: /'feeds'/ }, + ); }); it('BAD CALL - input type mismatch (single tensor)', async () => { - await assert.rejects(async () => { - await sessionAny.run(input0); - }, {name: 'TypeError', message: /'feeds'/}); + await assert.rejects( + async () => { + await sessionAny.run(input0); + }, + { name: 'TypeError', message: /'feeds'/ }, + ); }); it('BAD CALL - input type mismatch (tensor array)', async () => { - await assert.rejects(async () => { - await sessionAny.run([input0]); - }, {name: 'TypeError', message: /'feeds'/}); + await assert.rejects( + async () => { + await sessionAny.run([input0]); + }, + { name: 'TypeError', message: /'feeds'/ }, + ); }); it('EXPECTED FAILURE - input name missing', async () => { - await assert.rejects(async () => { - await sessionAny.run({}); - }, {name: 'Error', message: /input 'data_0' is missing/}); + await assert.rejects( + async () => { + await sessionAny.run({}); + }, + { name: 'Error', message: /input 'data_0' is missing/ }, + ); }); it('EXPECTED FAILURE - input name incorrect', async () => { - await assert.rejects(async () => { - await sessionAny.run({'data_1': input0}); // correct name should be 'data_0' - }, {name: 'Error', message: /input 'data_0' is missing/}); + await assert.rejects( + async () => { + await sessionAny.run({ data_1: input0 }); // correct name should be 'data_0' + }, + { name: 'Error', message: /input 'data_0' is missing/ }, + ); }); // #endregion // #region test fetches overrides it('run() - no fetches', async () => { - const result = await session!.run({'data_0': input0}); + const result = await session!.run({ data_0: input0 }); assertTensorEqual(result.softmaxout_1, expectedOutput0); }); it('run() - fetches names', async () => { - const result = await session!.run({'data_0': input0}, ['softmaxout_1']); + const result = await session!.run({ data_0: input0 }, ['softmaxout_1']); assertTensorEqual(result.softmaxout_1, expectedOutput0); }); it('run() - fetches object', async () => { - const result = await session!.run({'data_0': input0}, {'softmaxout_1': null}); + const result = await session!.run({ data_0: input0 }, { softmaxout_1: null }); assertTensorEqual(result.softmaxout_1, expectedOutput0); }); // TODO: enable after buffer reuse is implemented it.skip('run() - fetches object (pre-allocated)', async () => { const preAllocatedOutputBuffer = new Float32Array(expectedOutput0.size); const result = await session!.run( - {'data_0': input0}, {'softmaxout_1': new Tensor(preAllocatedOutputBuffer, expectedOutput0.dims)}); + { data_0: input0 }, + { softmaxout_1: new Tensor(preAllocatedOutputBuffer, expectedOutput0.dims) }, + ); const softmaxout_1 = result.softmaxout_1 as TypedTensor<'float32'>; assert.strictEqual(softmaxout_1.data.buffer, preAllocatedOutputBuffer.buffer); assert.strictEqual(softmaxout_1.data.byteOffset, preAllocatedOutputBuffer.byteOffset); @@ -146,42 +193,65 @@ describe('UnitTests - InferenceSession.run()', () => { // #region test bad output(fetches) it('BAD CALL - fetches type mismatch (null)', async () => { - await assert.rejects(async () => { - await sessionAny.run({'data_0': input0}, null); - }, {name: 'TypeError', message: /argument\[1\]/}); + await assert.rejects( + async () => { + await sessionAny.run({ data_0: input0 }, null); + }, + { name: 'TypeError', message: /argument\[1\]/ }, + ); }); it('BAD CALL - fetches type mismatch (number)', async () => { - await assert.rejects(async () => { - await sessionAny.run({'data_0': input0}, 1); - }, {name: 'TypeError', 
message: /argument\[1\]/}); + await assert.rejects( + async () => { + await sessionAny.run({ data_0: input0 }, 1); + }, + { name: 'TypeError', message: /argument\[1\]/ }, + ); }); it('BAD CALL - fetches type mismatch (Tensor)', async () => { - await assert.rejects(async () => { - await sessionAny.run( - {'data_0': input0}, new Tensor(new Float32Array(expectedOutput0.size), expectedOutput0.dims)); - }, {name: 'TypeError', message: /'fetches'/}); + await assert.rejects( + async () => { + await sessionAny.run( + { data_0: input0 }, + new Tensor(new Float32Array(expectedOutput0.size), expectedOutput0.dims), + ); + }, + { name: 'TypeError', message: /'fetches'/ }, + ); }); it('BAD CALL - fetches as array (empty array)', async () => { - await assert.rejects(async () => { - await sessionAny.run({'data_0': input0}, []); - }, {name: 'TypeError', message: /'fetches'/}); + await assert.rejects( + async () => { + await sessionAny.run({ data_0: input0 }, []); + }, + { name: 'TypeError', message: /'fetches'/ }, + ); }); it('BAD CALL - fetches as array (non-string elements)', async () => { - await assert.rejects(async () => { - await sessionAny.run({'data_0': input0}, [1, 2, 3]); - }, {name: 'TypeError', message: /'fetches'/}); + await assert.rejects( + async () => { + await sessionAny.run({ data_0: input0 }, [1, 2, 3]); + }, + { name: 'TypeError', message: /'fetches'/ }, + ); }); it('BAD CALL - fetches as array (invalid name)', async () => { - await assert.rejects(async () => { - await sessionAny.run({'data_0': input0}, ['im_a_wrong_output_name']); - }, {name: 'RangeError', message: /'fetches'/}); + await assert.rejects( + async () => { + await sessionAny.run({ data_0: input0 }, ['im_a_wrong_output_name']); + }, + { name: 'RangeError', message: /'fetches'/ }, + ); }); // #endregion it('BAD CALL - options type mismatch (number)', async () => { - await assert.rejects(async () => { - await sessionAny.run({'data_0': input0}, ['softmaxout_1'], 1); - }, {name: 'TypeError', message: /'options'/}); + await assert.rejects( + async () => { + await sessionAny.run({ data_0: input0 }, ['softmaxout_1'], 1); + }, + { name: 'TypeError', message: /'options'/ }, + ); }); }); @@ -190,134 +260,182 @@ describe('UnitTests - InferenceSession.SessionOptions', () => { const createAny: any = InferenceSession.create; it('BAD CALL - type mismatch', async () => { - await assert.rejects(async () => { - await createAny(modelPath, 'cpu'); - }, {name: 'TypeError', message: /'options'/}); + await assert.rejects( + async () => { + await createAny(modelPath, 'cpu'); + }, + { name: 'TypeError', message: /'options'/ }, + ); }); describe('executionProviders', () => { it.skip('BAD CALL - type mismatch', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {executionProviders: 'bad-EP-name'}); - }, {name: 'TypeError', message: /executionProviders/}); + await assert.rejects( + async () => { + await createAny(modelPath, { executionProviders: 'bad-EP-name' }); + }, + { name: 'TypeError', message: /executionProviders/ }, + ); }); it.skip('EXPECTED FAILURE - invalid EP name, string list', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {executionProviders: ['bad-EP-name']}); - }, {name: 'Error', message: /executionProviders.+bad-EP-name/}); + await assert.rejects( + async () => { + await createAny(modelPath, { executionProviders: ['bad-EP-name'] }); + }, + { name: 'Error', message: /executionProviders.+bad-EP-name/ }, + ); }); it.skip('EXPECTED FAILURE - invalid EP name, object list', async 
() => { - await assert.rejects(async () => { - await createAny(modelPath, {executionProviders: [{name: 'bad-EP-name'}]}); - }, {name: 'Error', message: /executionProviders.+bad-EP-name/}); + await assert.rejects( + async () => { + await createAny(modelPath, { executionProviders: [{ name: 'bad-EP-name' }] }); + }, + { name: 'Error', message: /executionProviders.+bad-EP-name/ }, + ); }); it('string list (CPU)', async () => { - await InferenceSession.create(modelPath, {executionProviders: ['cpu']}); + await InferenceSession.create(modelPath, { executionProviders: ['cpu'] }); }); it('object list (CPU)', async () => { - await InferenceSession.create(modelPath, {executionProviders: [{name: 'cpu'}]}); + await InferenceSession.create(modelPath, { executionProviders: [{ name: 'cpu' }] }); }); }); describe('intraOpNumThreads', () => { it('BAD CALL - type mismatch', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {intraOpNumThreads: 'bad-value'}); - }, {name: 'TypeError', message: /intraOpNumThreads/}); + await assert.rejects( + async () => { + await createAny(modelPath, { intraOpNumThreads: 'bad-value' }); + }, + { name: 'TypeError', message: /intraOpNumThreads/ }, + ); }); it('BAD CALL - non-integer', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {intraOpNumThreads: 1.5}); - }, {name: 'RangeError', message: /intraOpNumThreads/}); + await assert.rejects( + async () => { + await createAny(modelPath, { intraOpNumThreads: 1.5 }); + }, + { name: 'RangeError', message: /intraOpNumThreads/ }, + ); }); it('BAD CALL - negative integer', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {intraOpNumThreads: -1}); - }, {name: 'RangeError', message: /intraOpNumThreads/}); + await assert.rejects( + async () => { + await createAny(modelPath, { intraOpNumThreads: -1 }); + }, + { name: 'RangeError', message: /intraOpNumThreads/ }, + ); }); it('intraOpNumThreads = 1', async () => { - await InferenceSession.create(modelPath, {intraOpNumThreads: 1}); + await InferenceSession.create(modelPath, { intraOpNumThreads: 1 }); }); }); describe('interOpNumThreads', () => { it('BAD CALL - type mismatch', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {interOpNumThreads: 'bad-value'}); - }, {name: 'TypeError', message: /interOpNumThreads/}); + await assert.rejects( + async () => { + await createAny(modelPath, { interOpNumThreads: 'bad-value' }); + }, + { name: 'TypeError', message: /interOpNumThreads/ }, + ); }); it('BAD CALL - non-integer', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {interOpNumThreads: 1.5}); - }, {name: 'RangeError', message: /interOpNumThreads/}); + await assert.rejects( + async () => { + await createAny(modelPath, { interOpNumThreads: 1.5 }); + }, + { name: 'RangeError', message: /interOpNumThreads/ }, + ); }); it('BAD CALL - negative integer', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {interOpNumThreads: -1}); - }, {name: 'RangeError', message: /interOpNumThreads/}); + await assert.rejects( + async () => { + await createAny(modelPath, { interOpNumThreads: -1 }); + }, + { name: 'RangeError', message: /interOpNumThreads/ }, + ); }); it('interOpNumThreads = 1', async () => { - await InferenceSession.create(modelPath, {interOpNumThreads: 1}); + await InferenceSession.create(modelPath, { interOpNumThreads: 1 }); }); }); describe('graphOptimizationLevel', () => { it('BAD CALL - type mismatch', async () => { - 
await assert.rejects(async () => { - await createAny(modelPath, {graphOptimizationLevel: 0}); - }, {name: 'TypeError', message: /graphOptimizationLevel/}); + await assert.rejects( + async () => { + await createAny(modelPath, { graphOptimizationLevel: 0 }); + }, + { name: 'TypeError', message: /graphOptimizationLevel/ }, + ); }); it('BAD CALL - invalid config', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {graphOptimizationLevel: 'bad-value'}); - }, {name: 'TypeError', message: /graphOptimizationLevel/}); + await assert.rejects( + async () => { + await createAny(modelPath, { graphOptimizationLevel: 'bad-value' }); + }, + { name: 'TypeError', message: /graphOptimizationLevel/ }, + ); }); it('graphOptimizationLevel = basic', async () => { - await InferenceSession.create(modelPath, {graphOptimizationLevel: 'basic'}); + await InferenceSession.create(modelPath, { graphOptimizationLevel: 'basic' }); }); }); describe('enableCpuMemArena', () => { it('BAD CALL - type mismatch', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {enableCpuMemArena: 0}); - }, {name: 'TypeError', message: /enableCpuMemArena/}); + await assert.rejects( + async () => { + await createAny(modelPath, { enableCpuMemArena: 0 }); + }, + { name: 'TypeError', message: /enableCpuMemArena/ }, + ); }); it('enableCpuMemArena = true', async () => { - await InferenceSession.create(modelPath, {enableCpuMemArena: true}); + await InferenceSession.create(modelPath, { enableCpuMemArena: true }); }); }); describe('enableMemPattern', () => { it('BAD CALL - type mismatch', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {enableMemPattern: 0}); - }, {name: 'TypeError', message: /enableMemPattern/}); + await assert.rejects( + async () => { + await createAny(modelPath, { enableMemPattern: 0 }); + }, + { name: 'TypeError', message: /enableMemPattern/ }, + ); }); it('enableMemPattern = true', async () => { - await InferenceSession.create(modelPath, {enableMemPattern: true}); + await InferenceSession.create(modelPath, { enableMemPattern: true }); }); }); describe('executionMode', () => { it('BAD CALL - type mismatch', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {executionMode: 0}); - }, {name: 'TypeError', message: /executionMode/}); + await assert.rejects( + async () => { + await createAny(modelPath, { executionMode: 0 }); + }, + { name: 'TypeError', message: /executionMode/ }, + ); }); it('BAD CALL - invalid config', async () => { - await assert.rejects(async () => { - await createAny(modelPath, {executionMode: 'bad-value'}); - }, {name: 'TypeError', message: /executionMode/}); + await assert.rejects( + async () => { + await createAny(modelPath, { executionMode: 'bad-value' }); + }, + { name: 'TypeError', message: /executionMode/ }, + ); }); it('executionMode = sequential', async () => { - await InferenceSession.create(modelPath, {executionMode: 'sequential'}); + await InferenceSession.create(modelPath, { executionMode: 'sequential' }); }); }); }); describe('UnitTests - InferenceSession.RunOptions', () => { - let session: InferenceSession|null = null; + let session: InferenceSession | null = null; let sessionAny: any; const input0 = new Tensor('float32', [1, 2, 3, 4, 5], [1, 5]); const expectedOutput0 = new Tensor('float32', [1, 2, 3, 4, 5], [1, 5]); @@ -330,22 +448,31 @@ describe('UnitTests - InferenceSession.RunOptions', () => { describe('logSeverityLevel', () => { it('BAD CALL - type mismatch', async () => { - 
await assert.rejects(async () => { - await sessionAny.run({input: input0}, {logSeverityLevel: 'error'}); - }, {name: 'TypeError', message: /logSeverityLevel/}); + await assert.rejects( + async () => { + await sessionAny.run({ input: input0 }, { logSeverityLevel: 'error' }); + }, + { name: 'TypeError', message: /logSeverityLevel/ }, + ); }); it('BAD CALL - out of range', async () => { - await assert.rejects(async () => { - await sessionAny.run({input: input0}, {logSeverityLevel: 8}); - }, {name: 'RangeError', message: /logSeverityLevel/}); + await assert.rejects( + async () => { + await sessionAny.run({ input: input0 }, { logSeverityLevel: 8 }); + }, + { name: 'RangeError', message: /logSeverityLevel/ }, + ); }); it('BAD CALL - out of range', async () => { - await assert.rejects(async () => { - await sessionAny.run({input: input0}, {logSeverityLevel: 8}); - }, {name: 'RangeError', message: /logSeverityLevel/}); + await assert.rejects( + async () => { + await sessionAny.run({ input: input0 }, { logSeverityLevel: 8 }); + }, + { name: 'RangeError', message: /logSeverityLevel/ }, + ); }); it('logSeverityLevel = 4', async () => { - const result = await sessionAny.run({input: input0}, {logSeverityLevel: 4}); + const result = await sessionAny.run({ input: input0 }, { logSeverityLevel: 4 }); assertTensorEqual(result.output, expectedOutput0); }); }); diff --git a/js/node/test/unittests/lib/tensor.ts b/js/node/test/unittests/lib/tensor.ts index 49b73da2e87c1..9e09c4e816fba 100644 --- a/js/node/test/unittests/lib/tensor.ts +++ b/js/node/test/unittests/lib/tensor.ts @@ -3,17 +3,19 @@ import * as assert from 'assert'; // tensor with type information -import {Tensor} from 'onnxruntime-common'; +import { Tensor } from 'onnxruntime-common'; -import {createTestData, NUMERIC_TYPE_MAP} from '../../test-utils'; +import { createTestData, NUMERIC_TYPE_MAP } from '../../test-utils'; // tensor with no type information, used for testing type check const TensorAny = Tensor as any; function testAllTensortypes( - title: string, length: number, - funcNumerictypes: (passtypeParam: boolean, type: Tensor.Type, data: Tensor.DataType) => void, - funcStringtype?: (passtypeParam: boolean, data: string[]) => void): void { + title: string, + length: number, + funcNumerictypes: (passtypeParam: boolean, type: Tensor.Type, data: Tensor.DataType) => void, + funcStringtype?: (passtypeParam: boolean, data: string[]) => void, +): void { NUMERIC_TYPE_MAP.forEach((ctor, type) => { it(`${title} - (${type}, ${ctor.name})`, () => { funcNumerictypes(true, type, createTestData(type, length)); @@ -42,60 +44,78 @@ function testAllTensortypes( } describe('UnitTests - tensor', () => { - testAllTensortypes('check data and type', 100, (passtypeParam, type, data) => { // numeric and string tensors + testAllTensortypes('check data and type', 100, (passtypeParam, type, data) => { + // numeric and string tensors const tensor0 = passtypeParam ? new Tensor(type, data) : new Tensor(data); assert.strictEqual(tensor0.data, data, 'tensor.data and data should be the same object.'); assert.strictEqual(tensor0.type, type, 'tensor.type and type should be equal.'); }); - testAllTensortypes('check dims (omitted)', 200, (passtypeParam, type, data) => { // numeric and string tensors + testAllTensortypes('check dims (omitted)', 200, (passtypeParam, type, data) => { + // numeric and string tensors const tensor0 = passtypeParam ? 
new Tensor(type, data) : new Tensor(data); assert.deepStrictEqual( - tensor0.dims, [200], - 'tensor.dims should be a number array with exactly 1 item, with value of the array length.'); + tensor0.dims, + [200], + 'tensor.dims should be a number array with exactly 1 item, with value of the array length.', + ); }); - testAllTensortypes('check dims (specified)', 60, (passtypeParam, type, data) => { // numeric and string tensors + testAllTensortypes('check dims (specified)', 60, (passtypeParam, type, data) => { + // numeric and string tensors const tensor0 = passtypeParam ? new Tensor(type, data, [3, 4, 5]) : new Tensor(data, [3, 4, 5]); assert.deepStrictEqual(tensor0.dims, [3, 4, 5], 'tensor.dims should be a number array with the given 3 items.'); }); - testAllTensortypes( - 'BAD CALL - invalid dims type', 100, (passtypeParam, type, data) => { // numeric and string tensors - assert.throws(() => { - const badDims = {}; - passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); - }, {name: 'TypeError', message: /must be a number array/}); - }); - testAllTensortypes( - 'BAD CALL - invalid dims element type', 100, (passtypeParam, type, data) => { // numeric and string tensors - assert.throws(() => { - const badDims = [1, 2, '']; - passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); - }, {name: 'TypeError', message: /must be an integer/}); - }); - testAllTensortypes( - 'BAD CALL - invalid dims number type (negative)', 100, - (passtypeParam, type, data) => { // numeric and string tensors - assert.throws(() => { - const badDims = [1, 2, -1]; - passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); - }, {name: 'RangeError', message: /must be a non-negative integer/}); - }); - testAllTensortypes( - 'BAD CALL - invalid dims number type (non-integer)', 100, - (passtypeParam, type, data) => { // numeric and string tensors - assert.throws(() => { - const badDims = [1, 2, 1.5]; - passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); - }, {name: 'TypeError', message: /must be an integer/}); - }); + testAllTensortypes('BAD CALL - invalid dims type', 100, (passtypeParam, type, data) => { + // numeric and string tensors + assert.throws( + () => { + const badDims = {}; + passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); + }, + { name: 'TypeError', message: /must be a number array/ }, + ); + }); + testAllTensortypes('BAD CALL - invalid dims element type', 100, (passtypeParam, type, data) => { + // numeric and string tensors + assert.throws( + () => { + const badDims = [1, 2, '']; + passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); + }, + { name: 'TypeError', message: /must be an integer/ }, + ); + }); + testAllTensortypes('BAD CALL - invalid dims number type (negative)', 100, (passtypeParam, type, data) => { + // numeric and string tensors + assert.throws( + () => { + const badDims = [1, 2, -1]; + passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); + }, + { name: 'RangeError', message: /must be a non-negative integer/ }, + ); + }); + testAllTensortypes('BAD CALL - invalid dims number type (non-integer)', 100, (passtypeParam, type, data) => { + // numeric and string tensors + assert.throws( + () => { + const badDims = [1, 2, 1.5]; + passtypeParam ? 
new TensorAny(type, data, badDims) : new TensorAny(data, badDims); + }, + { name: 'TypeError', message: /must be an integer/ }, + ); + }); - testAllTensortypes( - 'BAD CALL - length and dims does not match', 100, (passtypeParam, type, data) => { // numeric and string tensors - assert.throws(() => { - const badDims = [10, 8]; - passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); - }, {name: 'Error', message: /does not match data length/}); - }); + testAllTensortypes('BAD CALL - length and dims does not match', 100, (passtypeParam, type, data) => { + // numeric and string tensors + assert.throws( + () => { + const badDims = [10, 8]; + passtypeParam ? new TensorAny(type, data, badDims) : new TensorAny(data, badDims); + }, + { name: 'Error', message: /does not match data length/ }, + ); + }); }); diff --git a/js/package-lock.json b/js/package-lock.json index fca482c7879d3..d3684dfdf9117 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -12,11 +12,11 @@ "@types/npmlog": "^4.1.4", "@typescript-eslint/eslint-plugin": "^7.4.0", "@typescript-eslint/parser": "^7.4.0", - "clang-format": "^1.8.0", "dir-compare": "^4.2.0", "esbuild": "^0.19.3", "esbuild-plugin-polyfill-node": "^0.3.0", "eslint": "^8.51.0", + "eslint-config-prettier": "^9.1.0", "eslint-plugin-header": "^3.1.1", "eslint-plugin-import": "^2.28.1", "eslint-plugin-jsdoc": "^46.8.2", @@ -26,7 +26,7 @@ "jszip": "^3.10.1", "mocha": "^10.2.0", "npmlog": "^7.0.1", - "prettier": "^3.0.3", + "prettier": "^3.3.3", "terser": "^5.31.0", "typescript": "^5.2.2" } @@ -1242,12 +1242,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/async": { - "version": "3.2.4", - "resolved": "https://registry.npmjs.org/async/-/async-3.2.4.tgz", - "integrity": "sha512-iAB+JbDEGXhyIUavoDl9WP/Jj106Kz9DEn1DPgYw5ruDn0e3Wgi3sKFm55sASdGBNOQB8F59d9qQ7deqrHA8wQ==", - "dev": true - }, "node_modules/available-typed-arrays": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/available-typed-arrays/-/available-typed-arrays-1.0.5.tgz", @@ -1469,22 +1463,6 @@ "node": ">=8" } }, - "node_modules/clang-format": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/clang-format/-/clang-format-1.8.0.tgz", - "integrity": "sha512-pK8gzfu55/lHzIpQ1givIbWfn3eXnU7SfxqIwVgnn5jEM6j4ZJYjpFqFs4iSBPNedzRMmfjYjuQhu657WAXHXw==", - "dev": true, - "dependencies": { - "async": "^3.2.3", - "glob": "^7.0.0", - "resolve": "^1.1.6" - }, - "bin": { - "check-clang-format": "bin/check-clang-format.js", - "clang-format": "index.js", - "git-clang-format": "bin/git-clang-format" - } - }, "node_modules/clean-regexp": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/clean-regexp/-/clean-regexp-1.0.0.tgz", @@ -1939,6 +1917,18 @@ "url": "https://opencollective.com/eslint" } }, + "node_modules/eslint-config-prettier": { + "version": "9.1.0", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-9.1.0.tgz", + "integrity": "sha512-NSWl5BFQWEPi1j4TjVNItzYV7dZXZ+wP6I6ZhrBGpChQhZRUaElihE9uRRkcbRnNb76UMKDF3r+WTmNcGPKsqw==", + "dev": true, + "bin": { + "eslint-config-prettier": "bin/cli.js" + }, + "peerDependencies": { + "eslint": ">=7.0.0" + } + }, "node_modules/eslint-import-resolver-node": { "version": "0.3.7", "resolved": "https://registry.npmjs.org/eslint-import-resolver-node/-/eslint-import-resolver-node-0.3.7.tgz", @@ -3782,9 +3772,9 @@ } }, "node_modules/prettier": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz", - "integrity": 
"sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", + "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", "dev": true, "bin": { "prettier": "bin/prettier.cjs" @@ -5574,12 +5564,6 @@ "is-shared-array-buffer": "^1.0.2" } }, - "async": { - "version": "3.2.4", - "resolved": "https://registry.npmjs.org/async/-/async-3.2.4.tgz", - "integrity": "sha512-iAB+JbDEGXhyIUavoDl9WP/Jj106Kz9DEn1DPgYw5ruDn0e3Wgi3sKFm55sASdGBNOQB8F59d9qQ7deqrHA8wQ==", - "dev": true - }, "available-typed-arrays": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/available-typed-arrays/-/available-typed-arrays-1.0.5.tgz", @@ -5716,17 +5700,6 @@ "integrity": "sha512-eXTggHWSooYhq49F2opQhuHWgzucfF2YgODK4e1566GQs5BIfP30B0oenwBJHfWxAs2fyPB1s7Mg949zLf61Yw==", "dev": true }, - "clang-format": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/clang-format/-/clang-format-1.8.0.tgz", - "integrity": "sha512-pK8gzfu55/lHzIpQ1givIbWfn3eXnU7SfxqIwVgnn5jEM6j4ZJYjpFqFs4iSBPNedzRMmfjYjuQhu657WAXHXw==", - "dev": true, - "requires": { - "async": "^3.2.3", - "glob": "^7.0.0", - "resolve": "^1.1.6" - } - }, "clean-regexp": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/clean-regexp/-/clean-regexp-1.0.0.tgz", @@ -6090,6 +6063,13 @@ "text-table": "^0.2.0" } }, + "eslint-config-prettier": { + "version": "9.1.0", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-9.1.0.tgz", + "integrity": "sha512-NSWl5BFQWEPi1j4TjVNItzYV7dZXZ+wP6I6ZhrBGpChQhZRUaElihE9uRRkcbRnNb76UMKDF3r+WTmNcGPKsqw==", + "dev": true, + "requires": {} + }, "eslint-import-resolver-node": { "version": "0.3.7", "resolved": "https://registry.npmjs.org/eslint-import-resolver-node/-/eslint-import-resolver-node-0.3.7.tgz", @@ -7446,9 +7426,9 @@ "dev": true }, "prettier": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz", - "integrity": "sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", + "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", "dev": true }, "process": { diff --git a/js/package.json b/js/package.json index 308d6931a927c..a3bd18adce98e 100644 --- a/js/package.json +++ b/js/package.json @@ -6,11 +6,11 @@ "@types/npmlog": "^4.1.4", "@typescript-eslint/eslint-plugin": "^7.4.0", "@typescript-eslint/parser": "^7.4.0", - "clang-format": "^1.8.0", "dir-compare": "^4.2.0", "esbuild": "^0.19.3", "esbuild-plugin-polyfill-node": "^0.3.0", "eslint": "^8.51.0", + "eslint-config-prettier": "^9.1.0", "eslint-plugin-header": "^3.1.1", "eslint-plugin-import": "^2.28.1", "eslint-plugin-jsdoc": "^46.8.2", @@ -20,19 +20,14 @@ "jszip": "^3.10.1", "mocha": "^10.2.0", "npmlog": "^7.0.1", - "prettier": "^3.0.3", + "prettier": "^3.3.3", "terser": "^5.31.0", "typescript": "^5.2.2" }, "scripts": { "prepare": "tsc --build scripts", - "lint": "eslint . 
--ext .ts --ext .tsx", - "format:ts": "clang-format --glob=\"{scripts/**/*.ts,common/{lib,test}/**/*.ts,node/{lib,script,test}/**/*.ts,web/{lib,script,test}/**/*.ts,react_native/{android,example,ios,lib}/**/*.{ts,tsx}}\" --style=file -i", - "format:js": "clang-format --glob=\"{{,common,node,web,react_native}/{*,.*}.{,m,c}js,web/test/e2e/**/*.{,m,c}js}\" --style=file -i", - "format:cf": "clang-format --glob=\"{node/src/**/*.{cc,h},react_native/{android,example,ios,lib}/**/*.{mm,java}}\" --style=file -i", - "format:json": "prettier \"**/*.{json,jsonc}\" --write", - "format:md": "prettier \"**/*.md\" --write", - "format": "npm run format:ts && npm run format:js && npm run format:cf && npm run format:json && npm run format:md", + "lint": "eslint .", + "format": "prettier \"**/*.{json,jsonc,js,mjs,cjs,ts,mts,cts,md}\" --write", "prepare-node-tests": "node ./scripts/prepare-onnx-node-tests", "update-version": "node ./scripts/update-version" }, diff --git a/js/react_native/android/src/main/cpp/cpp-adapter.cpp b/js/react_native/android/src/main/cpp/cpp-adapter.cpp index be1228bbfe959..d75a2f9c99d8b 100644 --- a/js/react_native/android/src/main/cpp/cpp-adapter.cpp +++ b/js/react_native/android/src/main/cpp/cpp-adapter.cpp @@ -6,17 +6,17 @@ using namespace facebook; typedef u_int8_t byte; -std::string jstring2string(JNIEnv *env, jstring jStr) { +std::string jstring2string(JNIEnv* env, jstring jStr) { if (!jStr) return ""; jclass stringClass = env->GetObjectClass(jStr); jmethodID getBytes = env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B"); - const auto stringJbytes = (jbyteArray) env->CallObjectMethod(jStr, getBytes, env->NewStringUTF("UTF-8")); + const auto stringJbytes = (jbyteArray)env->CallObjectMethod(jStr, getBytes, env->NewStringUTF("UTF-8")); - auto length = (size_t) env->GetArrayLength(stringJbytes); + auto length = (size_t)env->GetArrayLength(stringJbytes); jbyte* pBytes = env->GetByteArrayElements(stringJbytes, nullptr); - std::string ret = std::string((char *)pBytes, length); + std::string ret = std::string((char*)pBytes, length); env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT); env->DeleteLocalRef(stringJbytes); @@ -24,7 +24,7 @@ std::string jstring2string(JNIEnv *env, jstring jStr) { return ret; } -byte* getBytesFromBlob(JNIEnv *env, jobject instanceGlobal, const std::string& blobId, int offset, int size) { +byte* getBytesFromBlob(JNIEnv* env, jobject instanceGlobal, const std::string& blobId, int offset, int size) { if (!env) throw std::runtime_error("JNI Environment is gone!"); // get java class @@ -33,12 +33,12 @@ byte* getBytesFromBlob(JNIEnv *env, jobject instanceGlobal, const std::string& b jmethodID getBufferJava = env->GetMethodID(clazz, "getBlobBuffer", "(Ljava/lang/String;II)[B"); // call method auto jstring = env->NewStringUTF(blobId.c_str()); - auto boxedBytes = (jbyteArray) env->CallObjectMethod(instanceGlobal, - getBufferJava, - // arguments - jstring, - offset, - size); + auto boxedBytes = (jbyteArray)env->CallObjectMethod(instanceGlobal, + getBufferJava, + // arguments + jstring, + offset, + size); env->DeleteLocalRef(jstring); jboolean isCopy = true; @@ -47,7 +47,7 @@ byte* getBytesFromBlob(JNIEnv *env, jobject instanceGlobal, const std::string& b return reinterpret_cast(bytes); }; -std::string createBlob(JNIEnv *env, jobject instanceGlobal, byte* bytes, size_t size) { +std::string createBlob(JNIEnv* env, jobject instanceGlobal, byte* bytes, size_t size) { if (!env) throw std::runtime_error("JNI Environment is gone!"); // get java class 
@@ -57,15 +57,14 @@ std::string createBlob(JNIEnv *env, jobject instanceGlobal, byte* bytes, size_t // call method auto byteArray = env->NewByteArray(size); env->SetByteArrayRegion(byteArray, 0, size, reinterpret_cast(bytes)); - auto blobId = (jstring) env->CallObjectMethod(instanceGlobal, getBufferJava, byteArray); + auto blobId = (jstring)env->CallObjectMethod(instanceGlobal, getBufferJava, byteArray); env->DeleteLocalRef(byteArray); return jstring2string(env, blobId); }; -extern "C" -JNIEXPORT void JNICALL -Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv *env, jclass _, jlong jsiPtr, jobject instance) { +extern "C" JNIEXPORT void JNICALL +Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv* env, jclass _, jlong jsiPtr, jobject instance) { auto jsiRuntime = reinterpret_cast(jsiPtr); auto& runtime = *jsiRuntime; @@ -76,28 +75,28 @@ Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv *env, jsi::PropNameID::forAscii(runtime, "jsiOnnxruntimeResolveArrayBuffer"), 1, [=](jsi::Runtime& runtime, - const jsi::Value& thisValue, - const jsi::Value* arguments, - size_t count) -> jsi::Value { - if (count != 1) { - throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!"); - } - - jsi::Object data = arguments[0].asObject(runtime); - auto blobId = data.getProperty(runtime, "blobId").asString(runtime); - auto offset = data.getProperty(runtime, "offset").asNumber(); - auto size = data.getProperty(runtime, "size").asNumber(); - - auto bytes = getBytesFromBlob(env, instanceGlobal, blobId.utf8(runtime), offset, size); - - size_t totalSize = size - offset; - jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer"); - jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int) totalSize).getObject(runtime); - jsi::ArrayBuffer buf = o.getArrayBuffer(runtime); - memcpy(buf.data(runtime), reinterpret_cast(bytes), totalSize); - - return buf; - }); + const jsi::Value& thisValue, + const jsi::Value* arguments, + size_t count) -> jsi::Value { + if (count != 1) { + throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!"); + } + + jsi::Object data = arguments[0].asObject(runtime); + auto blobId = data.getProperty(runtime, "blobId").asString(runtime); + auto offset = data.getProperty(runtime, "offset").asNumber(); + auto size = data.getProperty(runtime, "size").asNumber(); + + auto bytes = getBytesFromBlob(env, instanceGlobal, blobId.utf8(runtime), offset, size); + + size_t totalSize = size - offset; + jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer"); + jsi::Object o = arrayBufferCtor.callAsConstructor(runtime, (int)totalSize).getObject(runtime); + jsi::ArrayBuffer buf = o.getArrayBuffer(runtime); + memcpy(buf.data(runtime), reinterpret_cast(bytes), totalSize); + + return buf; + }); runtime.global().setProperty(runtime, "jsiOnnxruntimeResolveArrayBuffer", std::move(resolveArrayBuffer)); auto storeArrayBuffer = jsi::Function::createFromHostFunction(runtime, @@ -107,21 +106,21 @@ Java_ai_onnxruntime_reactnative_OnnxruntimeJSIHelper_nativeInstall(JNIEnv *env, const jsi::Value& thisValue, const jsi::Value* arguments, size_t count) -> jsi::Value { - if (count != 1) { - throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) 
expects one argument (object)!"); - } - - auto arrayBuffer = arguments[0].asObject(runtime).getArrayBuffer(runtime); - auto size = arrayBuffer.size(runtime); - - std::string blobId = createBlob(env, instanceGlobal, arrayBuffer.data(runtime), size); - - jsi::Object result(runtime); - auto blobIdString = jsi::String::createFromUtf8(runtime, blobId); - result.setProperty(runtime, "blobId", blobIdString); - result.setProperty(runtime, "offset", jsi::Value(0)); - result.setProperty(runtime, "size", jsi::Value(static_cast(size))); - return result; - }); + if (count != 1) { + throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!"); + } + + auto arrayBuffer = arguments[0].asObject(runtime).getArrayBuffer(runtime); + auto size = arrayBuffer.size(runtime); + + std::string blobId = createBlob(env, instanceGlobal, arrayBuffer.data(runtime), size); + + jsi::Object result(runtime); + auto blobIdString = jsi::String::createFromUtf8(runtime, blobId); + result.setProperty(runtime, "blobId", blobIdString); + result.setProperty(runtime, "offset", jsi::Value(0)); + result.setProperty(runtime, "size", jsi::Value(static_cast(size))); + return result; + }); runtime.global().setProperty(runtime, "jsiOnnxruntimeStoreArrayBuffer", std::move(storeArrayBuffer)); } diff --git a/js/react_native/app.plugin.js b/js/react_native/app.plugin.js index ed4cfe48563bd..2fa117b1a14e5 100644 --- a/js/react_native/app.plugin.js +++ b/js/react_native/app.plugin.js @@ -8,16 +8,14 @@ const withOrt = (config) => { // Add build dependency to gradle file config = configPlugin.withAppBuildGradle(config, (config) => { if (config.modResults.language === 'groovy') { - config.modResults.contents = generateCode - .mergeContents({ - src: config.modResults.contents, - newSrc: ' implementation project(\':onnxruntime-react-native\')', - tag: 'onnxruntime-react-native', - anchor: /^dependencies[ \t]*\{$/, - offset: 1, - comment: ' // onnxruntime-react-native' - }) - .contents; + config.modResults.contents = generateCode.mergeContents({ + src: config.modResults.contents, + newSrc: " implementation project(':onnxruntime-react-native')", + tag: 'onnxruntime-react-native', + anchor: /^dependencies[ \t]*\{$/, + offset: 1, + comment: ' // onnxruntime-react-native', + }).contents; } else { throw new Error('Cannot add ONNX Runtime maven gradle because the build.gradle is not groovy'); } @@ -30,24 +28,21 @@ const withOrt = (config) => { 'ios', (config) => { const podFilePath = path.join(config.modRequest.platformProjectRoot, 'Podfile'); - const contents = fs.readFileSync(podFilePath, {encoding: 'utf-8'}); - const updatedContents = - generateCode - .mergeContents({ - src: contents, - newSrc: ' pod \'onnxruntime-react-native\', :path => \'../node_modules/onnxruntime-react-native\'', - tag: 'onnxruntime-react-native', - anchor: /^target.+do$/, - offset: 1, - comment: ' # onnxruntime-react-native' - }) - .contents; + const contents = fs.readFileSync(podFilePath, { encoding: 'utf-8' }); + const updatedContents = generateCode.mergeContents({ + src: contents, + newSrc: " pod 'onnxruntime-react-native', :path => '../node_modules/onnxruntime-react-native'", + tag: 'onnxruntime-react-native', + anchor: /^target.+do$/, + offset: 1, + comment: ' # onnxruntime-react-native', + }).contents; fs.writeFileSync(podFilePath, updatedContents); return config; - } + }, ]); return config; }; -exports.default = configPlugin.createRunOncePlugin(withOrt, pkg.name, pkg.version) +exports.default = configPlugin.createRunOncePlugin(withOrt, 
pkg.name, pkg.version); diff --git a/js/react_native/babel.config.js b/js/react_native/babel.config.js index b667f9a55a389..e2240f1f51f8b 100644 --- a/js/react_native/babel.config.js +++ b/js/react_native/babel.config.js @@ -1,5 +1,5 @@ 'use strict'; module.exports = { - presets : ['module:metro-react-native-babel-preset'], + presets: ['module:metro-react-native-babel-preset'], }; diff --git a/js/react_native/e2e/.detoxrc.js b/js/react_native/e2e/.detoxrc.js index 94ff7272972c4..e24833a1d09c9 100644 --- a/js/react_native/e2e/.detoxrc.js +++ b/js/react_native/e2e/.detoxrc.js @@ -2,82 +2,82 @@ module.exports = { testRunner: { args: { - '$0': 'jest', - config: 'test/jest.config.js' + $0: 'jest', + config: 'test/jest.config.js', }, jest: { - setupTimeout: 120000 - } + setupTimeout: 120000, + }, }, apps: { 'ios.debug': { type: 'ios.app', binaryPath: 'ios/build/Build/Products/Debug-iphonesimulator/OnnxruntimeModuleExample.app', - build: 'xcodebuild ARCHS=x86_64 ONLY_ACTIVE_ARCH=NO -workspace ios/OnnxruntimeModuleExample.xcworkspace -scheme OnnxruntimeModuleExample -configuration Debug -sdk iphonesimulator -derivedDataPath ios/build' + build: + 'xcodebuild ARCHS=x86_64 ONLY_ACTIVE_ARCH=NO -workspace ios/OnnxruntimeModuleExample.xcworkspace -scheme OnnxruntimeModuleExample -configuration Debug -sdk iphonesimulator -derivedDataPath ios/build', }, 'ios.release': { type: 'ios.app', binaryPath: 'ios/build/Build/Products/Release-iphonesimulator/OnnxruntimeModuleExample.app', - build: 'xcodebuild ARCHS=x86_64 ONLY_ACTIVE_ARCH=NO -workspace ios/OnnxruntimeModuleExample.xcworkspace -scheme OnnxruntimeModuleExample -configuration Release -sdk iphonesimulator -derivedDataPath ios/build' + build: + 'xcodebuild ARCHS=x86_64 ONLY_ACTIVE_ARCH=NO -workspace ios/OnnxruntimeModuleExample.xcworkspace -scheme OnnxruntimeModuleExample -configuration Release -sdk iphonesimulator -derivedDataPath ios/build', }, 'android.debug': { type: 'android.apk', binaryPath: 'android/app/build/outputs/apk/debug/app-debug.apk', build: 'cd android && ./gradlew assembleDebug assembleAndroidTest -DtestBuildType=debug', - reversePorts: [ - 8081 - ] + reversePorts: [8081], }, 'android.release': { type: 'android.apk', binaryPath: 'android/app/build/outputs/apk/release/app-release.apk', - build: 'cd android && ./gradlew assembleRelease assembleAndroidTest -DtestBuildType=release' - } + build: 'cd android && ./gradlew assembleRelease assembleAndroidTest -DtestBuildType=release', + }, }, devices: { simulator: { type: 'ios.simulator', device: { - type: 'iPhone 13' - } + type: 'iPhone 13', + }, }, attached: { type: 'android.attached', device: { - adbName: '.*' - } + adbName: '.*', + }, }, emulator: { type: 'android.emulator', device: { - avdName: 'ort_android' - } - } + avdName: 'ort_android', + }, + }, }, configurations: { 'ios.sim.debug': { device: 'simulator', - app: 'ios.debug' + app: 'ios.debug', }, 'ios.sim.release': { device: 'simulator', - app: 'ios.release' + app: 'ios.release', }, 'android.att.debug': { device: 'attached', - app: 'android.debug' + app: 'android.debug', }, 'android.att.release': { device: 'attached', - app: 'android.release' + app: 'android.release', }, 'android.emu.debug': { device: 'emulator', - app: 'android.debug' + app: 'android.debug', }, 'android.emu.release': { device: 'emulator', - app: 'android.release' - } - } + app: 'android.release', + }, + }, }; diff --git a/js/react_native/e2e/ios/MNISTDataHandler.h b/js/react_native/e2e/ios/MNISTDataHandler.h index 1112eb31c8559..da05843e8a41f 100644 --- 
a/js/react_native/e2e/ios/MNISTDataHandler.h +++ b/js/react_native/e2e/ios/MNISTDataHandler.h @@ -6,7 +6,7 @@ #import -@interface MNISTDataHandler : NSObject +@interface MNISTDataHandler : NSObject @end #endif /* MNISTDataHandler_h */ diff --git a/js/react_native/e2e/ios/MNISTDataHandler.mm b/js/react_native/e2e/ios/MNISTDataHandler.mm index b935a91b63503..54a4b629865d0 100644 --- a/js/react_native/e2e/ios/MNISTDataHandler.mm +++ b/js/react_native/e2e/ios/MNISTDataHandler.mm @@ -17,14 +17,14 @@ @implementation MNISTDataHandler // so that onnxruntime is able to load a model using a given path. RCT_EXPORT_METHOD(getLocalModelPath : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) { @try { - NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"mnist" ofType:@"ort"]; - NSFileManager *fileManager = [NSFileManager defaultManager]; + NSString* modelPath = [[NSBundle mainBundle] pathForResource:@"mnist" ofType:@"ort"]; + NSFileManager* fileManager = [NSFileManager defaultManager]; if ([fileManager fileExistsAtPath:modelPath]) { resolve(modelPath); } else { reject(@"mnist", @"no such a model", nil); } - } @catch (NSException *exception) { + } @catch (NSException* exception) { reject(@"mnist", @"no such a model", nil); } } @@ -32,14 +32,14 @@ @implementation MNISTDataHandler // It returns image path. RCT_EXPORT_METHOD(getImagePath : (RCTPromiseResolveBlock)resolve reject : (RCTPromiseRejectBlock)reject) { @try { - NSString *imagePath = [[NSBundle mainBundle] pathForResource:@"3" ofType:@"jpg"]; - NSFileManager *fileManager = [NSFileManager defaultManager]; + NSString* imagePath = [[NSBundle mainBundle] pathForResource:@"3" ofType:@"jpg"]; + NSFileManager* fileManager = [NSFileManager defaultManager]; if ([fileManager fileExistsAtPath:imagePath]) { resolve(imagePath); } else { reject(@"mnist", @"no such an image", nil); } - } @catch (NSException *exception) { + } @catch (NSException* exception) { reject(@"mnist", @"no such an image", nil); } } @@ -47,13 +47,13 @@ @implementation MNISTDataHandler // It gets raw input data, which can be uri or byte array and others, // returns cooked data formatted as input of a model. RCT_EXPORT_METHOD(preprocess - : (NSString *)uri resolve + : (NSString*)uri resolve : (RCTPromiseResolveBlock)resolve reject : (RCTPromiseRejectBlock)reject) { @try { - NSDictionary *inputDataMap = [self preprocess:uri]; + NSDictionary* inputDataMap = [self preprocess:uri]; resolve(inputDataMap); - } @catch (NSException *exception) { + } @catch (NSException* exception) { reject(@"mnist", @"can't load an image", nil); } } @@ -61,24 +61,24 @@ @implementation MNISTDataHandler // It gets a result from onnxruntime and a duration of session time for input data, // returns output data formatted as React Native map. 
RCT_EXPORT_METHOD(postprocess - : (NSDictionary *)result resolve + : (NSDictionary*)result resolve : (RCTPromiseResolveBlock)resolve reject : (RCTPromiseRejectBlock)reject) { @try { - NSDictionary *cookedMap = [self postprocess:result]; + NSDictionary* cookedMap = [self postprocess:result]; resolve(cookedMap); - } @catch (NSException *exception) { + } @catch (NSException* exception) { reject(@"mnist", @"can't pose-process an image", nil); } } -- (NSDictionary *)preprocess:(NSString *)uri { - UIImage *image = [UIImage imageNamed:@"3.jpg"]; +- (NSDictionary*)preprocess:(NSString*)uri { + UIImage* image = [UIImage imageNamed:@"3.jpg"]; CGSize scale = CGSizeMake(28, 28); UIGraphicsBeginImageContextWithOptions(scale, NO, 1.0); [image drawInRect:CGRectMake(0, 0, scale.width, scale.height)]; - UIImage *scaledImage = UIGraphicsGetImageFromCurrentImageContext(); + UIImage* scaledImage = UIGraphicsGetImageFromCurrentImageContext(); UIGraphicsEndImageContext(); CGImageRef imageRef = [scaledImage CGImage]; @@ -100,23 +100,23 @@ - (NSDictionary *)preprocess:(NSString *)uri { const NSInteger dimSize = height * width; const NSInteger byteBufferSize = dimSize * sizeof(float); - unsigned char *byteBuffer = static_cast(malloc(byteBufferSize)); - NSData *byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; - float *floatPtr = (float *)[byteBufferRef bytes]; + unsigned char* byteBuffer = static_cast(malloc(byteBufferSize)); + NSData* byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; + float* floatPtr = (float*)[byteBufferRef bytes]; for (NSUInteger h = 0; h < height; ++h) { for (NSUInteger w = 0; w < width; ++w) { NSUInteger byteIndex = (bytesPerRow * h) + w * bytesPerPixel; *floatPtr++ = rawData[byteIndex]; } } - floatPtr = (float *)[byteBufferRef bytes]; + floatPtr = (float*)[byteBufferRef bytes]; - NSMutableDictionary *inputDataMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* inputDataMap = [NSMutableDictionary dictionary]; - NSMutableDictionary *inputTensorMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* inputTensorMap = [NSMutableDictionary dictionary]; // dims - NSArray *dims = @[ + NSArray* dims = @[ [NSNumber numberWithInt:1], [NSNumber numberWithInt:1], [NSNumber numberWithInt:static_cast(height)], @@ -128,7 +128,7 @@ - (NSDictionary *)preprocess:(NSString *)uri { inputTensorMap[@"type"] = JsTensorTypeFloat; // encoded data - NSString *data = [byteBufferRef base64EncodedStringWithOptions:0]; + NSString* data = [byteBufferRef base64EncodedStringWithOptions:0]; inputTensorMap[@"data"] = data; inputDataMap[@"Input3"] = inputTensorMap; @@ -136,14 +136,14 @@ - (NSDictionary *)preprocess:(NSString *)uri { return inputDataMap; } -- (NSDictionary *)postprocess:(NSDictionary *)result { - NSMutableString *detectionResult = [NSMutableString string]; +- (NSDictionary*)postprocess:(NSDictionary*)result { + NSMutableString* detectionResult = [NSMutableString string]; - NSDictionary *outputTensor = [result objectForKey:@"Plus214_Output_0"]; + NSDictionary* outputTensor = [result objectForKey:@"Plus214_Output_0"]; - NSString *data = [outputTensor objectForKey:@"data"]; - NSData *buffer = [[NSData alloc] initWithBase64EncodedString:data options:0]; - float *values = (float *)[buffer bytes]; + NSString* data = [outputTensor objectForKey:@"data"]; + NSData* buffer = [[NSData alloc] initWithBase64EncodedString:data options:0]; + float* values = (float*)[buffer bytes]; int count = (int)[buffer length] / 4; int argmax = 0; @@ -161,7 +161,7 
@@ - (NSDictionary *)postprocess:(NSDictionary *)result { detectionResult = [NSMutableString stringWithFormat:@"%d", argmax]; } - NSDictionary *cookedMap = @{@"result" : detectionResult}; + NSDictionary* cookedMap = @{@"result" : detectionResult}; return cookedMap; } diff --git a/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.h b/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.h index 2726d5e13c723..ad01d3fff4d4c 100644 --- a/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.h +++ b/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.h @@ -10,6 +10,6 @@ @interface AppDelegate : UIResponder -@property (nonatomic, strong) UIWindow *window; +@property(nonatomic, strong) UIWindow* window; @end diff --git a/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.m b/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.m index c184b705e9e7d..44bfc81f4ad79 100644 --- a/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.m +++ b/js/react_native/e2e/ios/OnnxruntimeModuleExample/AppDelegate.m @@ -18,9 +18,9 @@ #import #import #import -static void InitializeFlipper(UIApplication *application) { - FlipperClient *client = [FlipperClient sharedClient]; - SKDescriptorMapper *layoutDescriptorMapper = [[SKDescriptorMapper alloc] initWithDefaults]; +static void InitializeFlipper(UIApplication* application) { + FlipperClient* client = [FlipperClient sharedClient]; + SKDescriptorMapper* layoutDescriptorMapper = [[SKDescriptorMapper alloc] initWithDefaults]; [client addPlugin:[[FlipperKitLayoutPlugin alloc] initWithRootNode:application withDescriptorMapper:layoutDescriptorMapper]]; [client addPlugin:[[FKUserDefaultsPlugin alloc] initWithSuiteName:nil]]; [client addPlugin:[FlipperKitReactPlugin new]]; @@ -31,28 +31,26 @@ static void InitializeFlipper(UIApplication *application) { @implementation AppDelegate -- (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launchOptions -{ - #ifdef FB_SONARKIT_ENABLED - InitializeFlipper(application); - #endif - RCTBridge *bridge = [[RCTBridge alloc] initWithDelegate:self launchOptions:launchOptions]; - RCTRootView *rootView = [[RCTRootView alloc] initWithBridge:bridge +- (BOOL)application:(UIApplication*)application didFinishLaunchingWithOptions:(NSDictionary*)launchOptions { +#ifdef FB_SONARKIT_ENABLED + InitializeFlipper(application); +#endif + RCTBridge* bridge = [[RCTBridge alloc] initWithDelegate:self launchOptions:launchOptions]; + RCTRootView* rootView = [[RCTRootView alloc] initWithBridge:bridge moduleName:@"OnnxruntimeModuleExample" initialProperties:nil]; rootView.backgroundColor = [[UIColor alloc] initWithRed:1.0f green:1.0f blue:1.0f alpha:1]; self.window = [[UIWindow alloc] initWithFrame:[UIScreen mainScreen].bounds]; - UIViewController *rootViewController = [UIViewController new]; + UIViewController* rootViewController = [UIViewController new]; rootViewController.view = rootView; self.window.rootViewController = rootViewController; [self.window makeKeyAndVisible]; return YES; } -- (NSURL *)sourceURLForBridge:(RCTBridge *)bridge -{ +- (NSURL*)sourceURLForBridge:(RCTBridge*)bridge { #if DEBUG return [[RCTBundleURLProvider sharedSettings] jsBundleURLForBundleRoot:@"index"]; #else diff --git a/js/react_native/e2e/ios/OnnxruntimeModuleExample/main.m b/js/react_native/e2e/ios/OnnxruntimeModuleExample/main.m index c316cf816e736..3ed24eae1b104 100644 --- a/js/react_native/e2e/ios/OnnxruntimeModuleExample/main.m +++ 
b/js/react_native/e2e/ios/OnnxruntimeModuleExample/main.m @@ -9,7 +9,7 @@ #import "AppDelegate.h" -int main(int argc, char * argv[]) { +int main(int argc, char* argv[]) { @autoreleasepool { return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); } diff --git a/js/react_native/e2e/metro.config.js b/js/react_native/e2e/metro.config.js index 56941aa01458c..9e7fb1c73d9cf 100644 --- a/js/react_native/e2e/metro.config.js +++ b/js/react_native/e2e/metro.config.js @@ -19,10 +19,7 @@ module.exports = { // So we exclusionlist them at the root, and alias them to the versions in example's node_modules resolver: { exclusionlistRE: exclusionlist( - modules.map( - (m) => - new RegExp(`^${escape(path.join(root, 'node_modules', m))}\\/.*$`) - ) + modules.map((m) => new RegExp(`^${escape(path.join(root, 'node_modules', m))}\\/.*$`)), ), extraNodeModules: modules.reduce((acc, name) => { diff --git a/js/react_native/e2e/src/mnist-data-handler.ts b/js/react_native/e2e/src/mnist-data-handler.ts index cde5aa8b1fefe..906e8e0ac15e8 100644 --- a/js/react_native/e2e/src/mnist-data-handler.ts +++ b/js/react_native/e2e/src/mnist-data-handler.ts @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {NativeModules} from 'react-native'; +import { NativeModules } from 'react-native'; export interface MNISTInput { [name: string]: { - dims: number[]; type: string; data: string; // encoded tensor data + dims: number[]; + type: string; + data: string; // encoded tensor data }; } export interface MNISTOutput { [name: string]: { - data: string; // encoded tensor data + data: string; // encoded tensor data }; } @@ -20,7 +22,9 @@ export interface MNISTResult { } type MNISTType = { - getLocalModelPath(): Promise; getImagePath(): Promise; preprocess(uri: string): Promise; + getLocalModelPath(): Promise; + getImagePath(): Promise; + preprocess(uri: string): Promise; postprocess(result: MNISTOutput): Promise; }; diff --git a/js/react_native/e2e/test/OnnxruntimeModuleExample.test.js b/js/react_native/e2e/test/OnnxruntimeModuleExample.test.js index 5b524039ca4e1..2e8a7446b6330 100644 --- a/js/react_native/e2e/test/OnnxruntimeModuleExample.test.js +++ b/js/react_native/e2e/test/OnnxruntimeModuleExample.test.js @@ -24,4 +24,4 @@ describe('OnnxruntimeModuleExample', () => { await expect(element(by.label('output'))).toHaveText('Result: 3'); } }); -}); \ No newline at end of file +}); diff --git a/js/react_native/ios/OnnxruntimeJSIHelper.mm b/js/react_native/ios/OnnxruntimeJSIHelper.mm index f6ce63c172fc5..7d93eaf1742fd 100644 --- a/js/react_native/ios/OnnxruntimeJSIHelper.mm +++ b/js/react_native/ios/OnnxruntimeJSIHelper.mm @@ -9,27 +9,27 @@ @implementation OnnxruntimeJSIHelper RCT_EXPORT_MODULE() -- (void)setBridge:(RCTBridge *)bridge { +- (void)setBridge:(RCTBridge*)bridge { _bridge = bridge; } RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(install) { - RCTCxxBridge *cxxBridge = (RCTCxxBridge *)_bridge; + RCTCxxBridge* cxxBridge = (RCTCxxBridge*)_bridge; if (cxxBridge == nil) { return @false; } using namespace facebook; - auto jsiRuntime = (jsi::Runtime *)cxxBridge.runtime; + auto jsiRuntime = (jsi::Runtime*)cxxBridge.runtime; if (jsiRuntime == nil) { return @false; } - auto &runtime = *jsiRuntime; + auto& runtime = *jsiRuntime; auto resolveArrayBuffer = jsi::Function::createFromHostFunction( runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeResolveArrayBuffer"), 1, - [](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, 
size_t count) -> jsi::Value { + [](jsi::Runtime& runtime, const jsi::Value& thisArg, const jsi::Value* args, size_t count) -> jsi::Value { if (count != 1) { throw jsi::JSError(runtime, "jsiOnnxruntimeResolveArrayBuffer(..) expects one argument (object)!"); } @@ -39,12 +39,12 @@ - (void)setBridge:(RCTBridge *)bridge { auto size = data.getProperty(runtime, "size").asNumber(); auto offset = data.getProperty(runtime, "offset").asNumber(); - RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; + RCTBlobManager* blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; if (blobManager == nil) { throw jsi::JSError(runtime, "RCTBlobManager is not initialized"); } - NSString *blobIdStr = [NSString stringWithUTF8String:blobId.c_str()]; + NSString* blobIdStr = [NSString stringWithUTF8String:blobId.c_str()]; auto blob = [blobManager resolve:blobIdStr offset:(long)offset size:(long)size]; jsi::Function arrayBufferCtor = runtime.global().getPropertyAsFunction(runtime, "ArrayBuffer"); @@ -58,21 +58,21 @@ - (void)setBridge:(RCTBridge *)bridge { auto storeArrayBuffer = jsi::Function::createFromHostFunction( runtime, jsi::PropNameID::forUtf8(runtime, "jsiOnnxruntimeStoreArrayBuffer"), 1, - [](jsi::Runtime &runtime, const jsi::Value &thisArg, const jsi::Value *args, size_t count) -> jsi::Value { + [](jsi::Runtime& runtime, const jsi::Value& thisArg, const jsi::Value* args, size_t count) -> jsi::Value { if (count != 1) { throw jsi::JSError(runtime, "jsiOnnxruntimeStoreArrayBuffer(..) expects one argument (object)!"); } auto arrayBuffer = args[0].asObject(runtime).getArrayBuffer(runtime); auto size = arrayBuffer.length(runtime); - NSData *data = [NSData dataWithBytesNoCopy:arrayBuffer.data(runtime) length:size freeWhenDone:NO]; + NSData* data = [NSData dataWithBytesNoCopy:arrayBuffer.data(runtime) length:size freeWhenDone:NO]; - RCTBlobManager *blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; + RCTBlobManager* blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; if (blobManager == nil) { throw jsi::JSError(runtime, "RCTBlobManager is not initialized"); } - NSString *blobId = [blobManager store:data]; + NSString* blobId = [blobManager store:data]; jsi::Object result(runtime); auto blobIdString = jsi::String::createFromUtf8(runtime, [blobId cStringUsingEncoding:NSUTF8StringEncoding]); diff --git a/js/react_native/ios/OnnxruntimeModule.h b/js/react_native/ios/OnnxruntimeModule.h index 24603cc648525..2abdd39f019d1 100644 --- a/js/react_native/ios/OnnxruntimeModule.h +++ b/js/react_native/ios/OnnxruntimeModule.h @@ -7,22 +7,22 @@ #import #import -@interface OnnxruntimeModule : NSObject +@interface OnnxruntimeModule : NSObject -- (void)setBlobManager:(RCTBlobManager *)manager; +- (void)setBlobManager:(RCTBlobManager*)manager; --(NSDictionary*)loadModel:(NSString*)modelPath - options:(NSDictionary*)options; +- (NSDictionary*)loadModel:(NSString*)modelPath + options:(NSDictionary*)options; --(NSDictionary*)loadModelFromBuffer:(NSData*)modelData - options:(NSDictionary*)options; +- (NSDictionary*)loadModelFromBuffer:(NSData*)modelData + options:(NSDictionary*)options; --(void)dispose:(NSString*)key; +- (void)dispose:(NSString*)key; --(NSDictionary*)run:(NSString*)url - input:(NSDictionary*)input - output:(NSArray*)output - options:(NSDictionary*)options; +- (NSDictionary*)run:(NSString*)url + input:(NSDictionary*)input + output:(NSArray*)output + options:(NSDictionary*)options; @end diff --git 
a/js/react_native/ios/OnnxruntimeModule.mm b/js/react_native/ios/OnnxruntimeModule.mm index 040e1dc29ef24..9da76034fc1ad 100644 --- a/js/react_native/ios/OnnxruntimeModule.mm +++ b/js/react_native/ios/OnnxruntimeModule.mm @@ -29,26 +29,26 @@ @implementation OnnxruntimeModule struct SessionInfo { std::unique_ptr session; - std::vector inputNames; + std::vector inputNames; std::vector inputNames_ptrs; - std::vector outputNames; + std::vector outputNames; std::vector outputNames_ptrs; }; -static Ort::Env *ortEnv = new Ort::Env(ORT_LOGGING_LEVEL_INFO, "Default"); -static NSMutableDictionary *sessionMap = [NSMutableDictionary dictionary]; +static Ort::Env* ortEnv = new Ort::Env(ORT_LOGGING_LEVEL_INFO, "Default"); +static NSMutableDictionary* sessionMap = [NSMutableDictionary dictionary]; static Ort::AllocatorWithDefaultOptions ortAllocator; static int nextSessionId = 0; -- (NSString *)getNextSessionKey { - NSString *key = @(nextSessionId).stringValue; +- (NSString*)getNextSessionKey { + NSString* key = @(nextSessionId).stringValue; nextSessionId++; return key; } RCT_EXPORT_MODULE(Onnxruntime) -RCTBlobManager *blobManager = nil; +RCTBlobManager* blobManager = nil; - (void)checkBlobManager { if (blobManager == nil) { @@ -59,7 +59,7 @@ - (void)checkBlobManager { } } -- (void)setBlobManager:(RCTBlobManager *)manager { +- (void)setBlobManager:(RCTBlobManager*)manager { blobManager = manager; } @@ -74,12 +74,12 @@ - (void)setBlobManager:(RCTBlobManager *)manager { * @note when run() is called, the same modelPath must be passed into the first parameter. */ RCT_EXPORT_METHOD(loadModel - : (NSString *)modelPath options - : (NSDictionary *)options resolver + : (NSString*)modelPath options + : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) { @try { - NSDictionary *resultMap = [self loadModel:modelPath options:options]; + NSDictionary* resultMap = [self loadModel:modelPath options:options]; resolve(resultMap); } @catch (...) { reject(@"onnxruntime", @"failed to load model", nil); @@ -96,17 +96,17 @@ - (void)setBlobManager:(RCTBlobManager *)manager { * @note when run() is called, the same modelPath must be passed into the first parameter. */ RCT_EXPORT_METHOD(loadModelFromBlob - : (NSDictionary *)modelDataBlob options - : (NSDictionary *)options resolver + : (NSDictionary*)modelDataBlob options + : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) { @try { [self checkBlobManager]; - NSString *blobId = [modelDataBlob objectForKey:@"blobId"]; + NSString* blobId = [modelDataBlob objectForKey:@"blobId"]; long size = [[modelDataBlob objectForKey:@"size"] longValue]; long offset = [[modelDataBlob objectForKey:@"offset"] longValue]; auto modelData = [blobManager resolve:blobId offset:offset size:size]; - NSDictionary *resultMap = [self loadModelFromBuffer:modelData options:options]; + NSDictionary* resultMap = [self loadModelFromBuffer:modelData options:options]; [blobManager remove:blobId]; resolve(resultMap); } @catch (...) 
{ @@ -122,7 +122,7 @@ - (void)setBlobManager:(RCTBlobManager *)manager { * @param reject callback for returning an error back to react native js */ RCT_EXPORT_METHOD(dispose - : (NSString *)key resolver + : (NSString*)key resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) { @try { @@ -144,14 +144,14 @@ - (void)setBlobManager:(RCTBlobManager *)manager { * @param reject callback for returning an error back to react native js */ RCT_EXPORT_METHOD(run - : (NSString *)url input - : (NSDictionary *)input output - : (NSArray *)output options - : (NSDictionary *)options resolver + : (NSString*)url input + : (NSDictionary*)input output + : (NSArray*)output options + : (NSDictionary*)options resolver : (RCTPromiseResolveBlock)resolve rejecter : (RCTPromiseRejectBlock)reject) { @try { - NSDictionary *resultMap = [self run:url input:input output:output options:options]; + NSDictionary* resultMap = [self run:url input:input output:output options:options]; resolve(resultMap); } @catch (...) { reject(@"onnxruntime", @"failed to run model", nil); @@ -165,7 +165,7 @@ - (void)setBlobManager:(RCTBlobManager *)manager { * @param options onnxruntime session options. * @note when run() is called, the same modelPath must be passed into the first parameter. */ -- (NSDictionary *)loadModel:(NSString *)modelPath options:(NSDictionary *)options { +- (NSDictionary*)loadModel:(NSString*)modelPath options:(NSDictionary*)options { return [self loadModelImpl:modelPath modelData:nil options:options]; } @@ -175,7 +175,7 @@ - (NSDictionary *)loadModel:(NSString *)modelPath options:(NSDictionary *)option * @param modelData the model data buffer. * @param options onnxruntime session options */ -- (NSDictionary *)loadModelFromBuffer:(NSData *)modelData options:(NSDictionary *)options { +- (NSDictionary*)loadModelFromBuffer:(NSData*)modelData options:(NSDictionary*)options { return [self loadModelImpl:@"" modelData:modelData options:options]; } @@ -186,8 +186,8 @@ - (NSDictionary *)loadModelFromBuffer:(NSData *)modelData options:(NSDictionary * @param modelData the model data buffer. * @param options onnxruntime session options. 
*/ -- (NSDictionary *)loadModelImpl:(NSString *)modelPath modelData:(NSData *)modelData options:(NSDictionary *)options { - SessionInfo *sessionInfo = nullptr; +- (NSDictionary*)loadModelImpl:(NSString*)modelPath modelData:(NSData*)modelData options:(NSDictionary*)options { + SessionInfo* sessionInfo = nullptr; sessionInfo = new SessionInfo(); Ort::SessionOptions sessionOptions = [self parseSessionOptions:options]; @@ -199,7 +199,7 @@ - (NSDictionary *)loadModelImpl:(NSString *)modelPath modelData:(NSData *)modelD sessionInfo->session.reset(new Ort::Session(*ortEnv, [modelPath UTF8String], sessionOptions)); } else { NSUInteger dataLength = [modelData length]; - Byte *modelBytes = (Byte *)[modelData bytes]; + Byte* modelBytes = (Byte*)[modelData bytes]; sessionInfo->session.reset(new Ort::Session(*ortEnv, modelBytes, (size_t)dataLength, sessionOptions)); } @@ -217,20 +217,20 @@ - (NSDictionary *)loadModelImpl:(NSString *)modelPath modelData:(NSData *)modelD sessionInfo->outputNames_ptrs.emplace_back(std::move(outputName)); } - NSString *key = [self getNextSessionKey]; - NSValue *value = [NSValue valueWithPointer:(void *)sessionInfo]; + NSString* key = [self getNextSessionKey]; + NSValue* value = [NSValue valueWithPointer:(void*)sessionInfo]; sessionMap[key] = value; - NSMutableDictionary *resultMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* resultMap = [NSMutableDictionary dictionary]; resultMap[@"key"] = key; - NSMutableArray *inputNames = [NSMutableArray array]; + NSMutableArray* inputNames = [NSMutableArray array]; for (auto inputName : sessionInfo->inputNames) { [inputNames addObject:[NSString stringWithCString:inputName encoding:NSUTF8StringEncoding]]; } resultMap[@"inputNames"] = inputNames; - NSMutableArray *outputNames = [NSMutableArray array]; + NSMutableArray* outputNames = [NSMutableArray array]; for (auto outputName : sessionInfo->outputNames) { [outputNames addObject:[NSString stringWithCString:outputName encoding:NSUTF8StringEncoding]]; } @@ -244,16 +244,16 @@ - (NSDictionary *)loadModelImpl:(NSString *)modelPath modelData:(NSData *)modelD * * @param key a session key returned from loadModel() */ -- (void)dispose:(NSString *)key { - NSValue *value = [sessionMap objectForKey:key]; +- (void)dispose:(NSString*)key { + NSValue* value = [sessionMap objectForKey:key]; if (value == nil) { - NSException *exception = [NSException exceptionWithName:@"onnxruntime" + NSException* exception = [NSException exceptionWithName:@"onnxruntime" reason:@"can't find onnxruntime session" userInfo:nil]; @throw exception; } [sessionMap removeObjectForKey:key]; - SessionInfo *sessionInfo = (SessionInfo *)[value pointerValue]; + SessionInfo* sessionInfo = (SessionInfo*)[value pointerValue]; delete sessionInfo; sessionInfo = nullptr; } @@ -266,18 +266,18 @@ - (void)dispose:(NSString *)key { * @param output an output names to be returned * @param options onnxruntime run options */ -- (NSDictionary *)run:(NSString *)url - input:(NSDictionary *)input - output:(NSArray *)output - options:(NSDictionary *)options { - NSValue *value = [sessionMap objectForKey:url]; +- (NSDictionary*)run:(NSString*)url + input:(NSDictionary*)input + output:(NSArray*)output + options:(NSDictionary*)options { + NSValue* value = [sessionMap objectForKey:url]; if (value == nil) { - NSException *exception = [NSException exceptionWithName:@"onnxruntime" + NSException* exception = [NSException exceptionWithName:@"onnxruntime" reason:@"can't find onnxruntime session" userInfo:nil]; @throw exception; } - SessionInfo 
*sessionInfo = (SessionInfo *)[value pointerValue]; + SessionInfo* sessionInfo = (SessionInfo*)[value pointerValue]; [self checkBlobManager]; @@ -285,9 +285,9 @@ - (NSDictionary *)run:(NSString *)url std::vector allocations; feeds.reserve(sessionInfo->inputNames.size()); for (auto inputName : sessionInfo->inputNames) { - NSDictionary *inputTensor = [input objectForKey:[NSString stringWithUTF8String:inputName]]; + NSDictionary* inputTensor = [input objectForKey:[NSString stringWithUTF8String:inputName]]; if (inputTensor == nil) { - NSException *exception = [NSException exceptionWithName:@"onnxruntime" reason:@"can't find input" userInfo:nil]; + NSException* exception = [NSException exceptionWithName:@"onnxruntime" reason:@"can't find input" userInfo:nil]; @throw exception; } @@ -298,9 +298,9 @@ - (NSDictionary *)run:(NSString *)url feeds.emplace_back(std::move(value)); } - std::vector requestedOutputs; + std::vector requestedOutputs; requestedOutputs.reserve(output.count); - for (NSString *outputName : output) { + for (NSString* outputName : output) { requestedOutputs.emplace_back([outputName UTF8String]); } Ort::RunOptions runOptions = [self parseRunOptions:options]; @@ -309,21 +309,21 @@ - (NSDictionary *)run:(NSString *)url sessionInfo->session->Run(runOptions, sessionInfo->inputNames.data(), feeds.data(), sessionInfo->inputNames.size(), requestedOutputs.data(), requestedOutputs.size()); - NSDictionary *resultMap = [TensorHelper createOutputTensor:blobManager outputNames:requestedOutputs values:result]; + NSDictionary* resultMap = [TensorHelper createOutputTensor:blobManager outputNames:requestedOutputs values:result]; return resultMap; } -static NSDictionary *graphOptimizationLevelTable = @{ +static NSDictionary* graphOptimizationLevelTable = @{ @"disabled" : @(ORT_DISABLE_ALL), @"basic" : @(ORT_ENABLE_BASIC), @"extended" : @(ORT_ENABLE_EXTENDED), @"all" : @(ORT_ENABLE_ALL) }; -static NSDictionary *executionModeTable = @{@"sequential" : @(ORT_SEQUENTIAL), @"parallel" : @(ORT_PARALLEL)}; +static NSDictionary* executionModeTable = @{@"sequential" : @(ORT_SEQUENTIAL), @"parallel" : @(ORT_PARALLEL)}; -- (Ort::SessionOptions)parseSessionOptions:(NSDictionary *)options { +- (Ort::SessionOptions)parseSessionOptions:(NSDictionary*)options { Ort::SessionOptions sessionOptions; if ([options objectForKey:@"intraOpNumThreads"]) { @@ -341,7 +341,7 @@ - (NSDictionary *)run:(NSString *)url } if ([options objectForKey:@"graphOptimizationLevel"]) { - NSString *graphOptimizationLevel = [[options objectForKey:@"graphOptimizationLevel"] stringValue]; + NSString* graphOptimizationLevel = [[options objectForKey:@"graphOptimizationLevel"] stringValue]; if ([graphOptimizationLevelTable objectForKey:graphOptimizationLevel]) { sessionOptions.SetGraphOptimizationLevel( (GraphOptimizationLevel)[[graphOptimizationLevelTable objectForKey:graphOptimizationLevel] intValue]); @@ -367,19 +367,19 @@ - (NSDictionary *)run:(NSString *)url } if ([options objectForKey:@"executionMode"]) { - NSString *executionMode = [[options objectForKey:@"executionMode"] stringValue]; + NSString* executionMode = [[options objectForKey:@"executionMode"] stringValue]; if ([executionModeTable objectForKey:executionMode]) { sessionOptions.SetExecutionMode((ExecutionMode)[[executionModeTable objectForKey:executionMode] intValue]); } } if ([options objectForKey:@"executionProviders"]) { - NSArray *executionProviders = [options objectForKey:@"executionProviders"]; - for (auto *executionProvider in executionProviders) { - NSString *epName = nil; + 
NSArray* executionProviders = [options objectForKey:@"executionProviders"]; + for (auto* executionProvider in executionProviders) { + NSString* epName = nil; bool useOptions = false; if ([executionProvider isKindOfClass:[NSString class]]) { - epName = (NSString *)executionProvider; + epName = (NSString*)executionProvider; } else { epName = [executionProvider objectForKey:@"name"]; useOptions = true; @@ -403,7 +403,7 @@ - (NSDictionary *)run:(NSString *)url } else if ([epName isEqualToString:@"cpu"]) { continue; } else { - NSException *exception = [NSException exceptionWithName:@"onnxruntime" + NSException* exception = [NSException exceptionWithName:@"onnxruntime" reason:@"unsupported execution provider" userInfo:nil]; @throw exception; @@ -412,7 +412,7 @@ - (NSDictionary *)run:(NSString *)url } if ([options objectForKey:@"logId"]) { - NSString *logId = [[options objectForKey:@"logId"] stringValue]; + NSString* logId = [[options objectForKey:@"logId"] stringValue]; sessionOptions.SetLogId([logId UTF8String]); } @@ -424,7 +424,7 @@ - (NSDictionary *)run:(NSString *)url return sessionOptions; } -- (Ort::RunOptions)parseRunOptions:(NSDictionary *)options { +- (Ort::RunOptions)parseRunOptions:(NSDictionary*)options { Ort::RunOptions runOptions; if ([options objectForKey:@"logSeverityLevel"]) { @@ -433,7 +433,7 @@ - (NSDictionary *)run:(NSString *)url } if ([options objectForKey:@"tag"]) { - NSString *tag = [[options objectForKey:@"tag"] stringValue]; + NSString* tag = [[options objectForKey:@"tag"] stringValue]; runOptions.SetRunTag([tag UTF8String]); } @@ -441,8 +441,8 @@ - (NSDictionary *)run:(NSString *)url } - (void)dealloc { - NSEnumerator *iterator = [sessionMap keyEnumerator]; - while (NSString *key = [iterator nextObject]) { + NSEnumerator* iterator = [sessionMap keyEnumerator]; + while (NSString* key = [iterator nextObject]) { [self dispose:key]; } blobManager = nullptr; diff --git a/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h index c6069b1a1d26d..f1f6c0004ff2f 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h +++ b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.h @@ -8,15 +8,15 @@ @interface FakeRCTBlobManager : RCTBlobManager -@property (nonatomic, strong) NSMutableDictionary *blobs; +@property(nonatomic, strong) NSMutableDictionary* blobs; -- (NSString *)store:(NSData *)data; +- (NSString*)store:(NSData*)data; -- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size; +- (NSData*)resolve:(NSString*)blobId offset:(long)offset size:(long)size; -- (NSDictionary *)testCreateData:(NSData *)buffer; +- (NSDictionary*)testCreateData:(NSData*)buffer; -- (NSString *)testGetData:(NSDictionary *)data; +- (NSString*)testGetData:(NSDictionary*)data; @end diff --git a/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m index 5df902df03534..156df7b232503 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m +++ b/js/react_native/ios/OnnxruntimeModuleTest/FakeRCTBlobManager.m @@ -13,31 +13,31 @@ - (instancetype)init { return self; } -- (NSString *)store:(NSData *)data { - NSString *blobId = [[NSUUID UUID] UUIDString]; +- (NSString*)store:(NSData*)data { + NSString* blobId = [[NSUUID UUID] UUIDString]; _blobs[blobId] = data; return blobId; } -- (NSData *)resolve:(NSString *)blobId offset:(long)offset size:(long)size { - NSData *data = _blobs[blobId]; +- 
(NSData*)resolve:(NSString*)blobId offset:(long)offset size:(long)size { + NSData* data = _blobs[blobId]; if (data == nil) { return nil; } return [data subdataWithRange:NSMakeRange(offset, size)]; } -- (NSDictionary *)testCreateData:(NSData *)buffer { +- (NSDictionary*)testCreateData:(NSData*)buffer { NSString* blobId = [self store:buffer]; return @{ - @"blobId": blobId, - @"offset": @0, - @"size": @(buffer.length), + @"blobId" : blobId, + @"offset" : @0, + @"size" : @(buffer.length), }; } -- (NSString *)testGetData:(NSDictionary *)data { - NSString *blobId = [data objectForKey:@"blobId"]; +- (NSString*)testGetData:(NSDictionary*)data { + NSString* blobId = [data objectForKey:@"blobId"]; long size = [[data objectForKey:@"size"] longValue]; long offset = [[data objectForKey:@"offset"] longValue]; [self resolve:blobId offset:offset size:size]; diff --git a/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm b/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm index f5805717f6615..7059177400f3c 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm +++ b/js/react_native/ios/OnnxruntimeModuleTest/OnnxruntimeModuleTest.mm @@ -14,7 +14,7 @@ @interface OnnxruntimeModuleTest : XCTestCase @implementation OnnxruntimeModuleTest -FakeRCTBlobManager *fakeBlobManager = nil; +FakeRCTBlobManager* fakeBlobManager = nil; + (void)initialize { if (self == [OnnxruntimeModuleTest class]) { @@ -23,45 +23,45 @@ + (void)initialize { } - (void)testOnnxruntimeModule { - NSBundle *bundle = [NSBundle bundleForClass:[OnnxruntimeModuleTest class]]; - NSString *dataPath = [bundle pathForResource:@"test_types_float" ofType:@"ort"]; - NSString *sessionKey = @""; - NSString *sessionKey2 = @""; + NSBundle* bundle = [NSBundle bundleForClass:[OnnxruntimeModuleTest class]]; + NSString* dataPath = [bundle pathForResource:@"test_types_float" ofType:@"ort"]; + NSString* sessionKey = @""; + NSString* sessionKey2 = @""; - OnnxruntimeModule *onnxruntimeModule = [OnnxruntimeModule new]; + OnnxruntimeModule* onnxruntimeModule = [OnnxruntimeModule new]; [onnxruntimeModule setBlobManager:fakeBlobManager]; { // test loadModelFromBuffer() - NSMutableDictionary *options = [NSMutableDictionary dictionary]; - NSData *fileData = [NSData dataWithContentsOfFile:dataPath]; + NSMutableDictionary* options = [NSMutableDictionary dictionary]; + NSData* fileData = [NSData dataWithContentsOfFile:dataPath]; - NSDictionary *resultMap = [onnxruntimeModule loadModelFromBuffer:fileData options:options]; + NSDictionary* resultMap = [onnxruntimeModule loadModelFromBuffer:fileData options:options]; sessionKey = resultMap[@"key"]; - NSArray *inputNames = resultMap[@"inputNames"]; + NSArray* inputNames = resultMap[@"inputNames"]; XCTAssertEqual([inputNames count], 1); XCTAssertEqualObjects(inputNames[0], @"input"); - NSArray *outputNames = resultMap[@"outputNames"]; + NSArray* outputNames = resultMap[@"outputNames"]; XCTAssertEqual([outputNames count], 1); XCTAssertEqualObjects(outputNames[0], @"output"); // test loadModel() - NSDictionary *resultMap2 = [onnxruntimeModule loadModel:dataPath options:options]; + NSDictionary* resultMap2 = [onnxruntimeModule loadModel:dataPath options:options]; sessionKey2 = resultMap2[@"key"]; - NSArray *inputNames2 = resultMap2[@"inputNames"]; + NSArray* inputNames2 = resultMap2[@"inputNames"]; XCTAssertEqual([inputNames2 count], 1); XCTAssertEqualObjects(inputNames2[0], @"input"); - NSArray *outputNames2 = resultMap2[@"outputNames"]; + NSArray* outputNames2 = 
resultMap2[@"outputNames"]; XCTAssertEqual([outputNames2 count], 1); XCTAssertEqualObjects(outputNames2[0], @"output"); } // test run() { - NSMutableDictionary *inputTensorMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* inputTensorMap = [NSMutableDictionary dictionary]; // dims - NSArray *dims = @[ [NSNumber numberWithLong:1], [NSNumber numberWithLong:5] ]; + NSArray* dims = @[ [NSNumber numberWithLong:1], [NSNumber numberWithLong:5] ]; inputTensorMap[@"dims"] = dims; // type @@ -72,27 +72,27 @@ - (void)testOnnxruntimeModule { std::numeric_limits::max()}; const NSInteger byteBufferSize = outValues.size() * sizeof(float); - unsigned char *byteBuffer = static_cast(malloc(byteBufferSize)); - NSData *byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; - float *floatPtr = (float *)[byteBufferRef bytes]; + unsigned char* byteBuffer = static_cast(malloc(byteBufferSize)); + NSData* byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; + float* floatPtr = (float*)[byteBufferRef bytes]; for (NSUInteger i = 0; i < outValues.size(); ++i) { *floatPtr++ = outValues[i]; } - floatPtr = (float *)[byteBufferRef bytes]; + floatPtr = (float*)[byteBufferRef bytes]; XCTAssertNotNil(fakeBlobManager); inputTensorMap[@"data"] = [fakeBlobManager testCreateData:byteBufferRef]; - NSMutableDictionary *inputDataMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* inputDataMap = [NSMutableDictionary dictionary]; inputDataMap[@"input"] = inputTensorMap; - NSMutableDictionary *options = [NSMutableDictionary dictionary]; + NSMutableDictionary* options = [NSMutableDictionary dictionary]; - NSMutableArray *output = [NSMutableArray array]; + NSMutableArray* output = [NSMutableArray array]; [output addObject:@"output"]; - NSDictionary *resultMap = [onnxruntimeModule run:sessionKey input:inputDataMap output:output options:options]; - NSDictionary *resultMap2 = [onnxruntimeModule run:sessionKey2 input:inputDataMap output:output options:options]; + NSDictionary* resultMap = [onnxruntimeModule run:sessionKey input:inputDataMap output:output options:options]; + NSDictionary* resultMap2 = [onnxruntimeModule run:sessionKey2 input:inputDataMap output:output options:options]; // Compare output & input, but data.blobId is different // dims @@ -116,30 +116,30 @@ - (void)testOnnxruntimeModule { } - (void)testOnnxruntimeModule_AppendCoreml { - NSBundle *bundle = [NSBundle bundleForClass:[OnnxruntimeModuleTest class]]; - NSString *dataPath = [bundle pathForResource:@"test_types_float" ofType:@"ort"]; - NSString *sessionKey = @""; + NSBundle* bundle = [NSBundle bundleForClass:[OnnxruntimeModuleTest class]]; + NSString* dataPath = [bundle pathForResource:@"test_types_float" ofType:@"ort"]; + NSString* sessionKey = @""; - OnnxruntimeModule *onnxruntimeModule = [OnnxruntimeModule new]; + OnnxruntimeModule* onnxruntimeModule = [OnnxruntimeModule new]; [onnxruntimeModule setBlobManager:fakeBlobManager]; { // test loadModel() with coreml options - NSMutableDictionary *options = [NSMutableDictionary dictionary]; + NSMutableDictionary* options = [NSMutableDictionary dictionary]; // register coreml ep options - NSMutableArray *epList = [NSMutableArray array]; + NSMutableArray* epList = [NSMutableArray array]; [epList addObject:@"coreml"]; - NSArray *immutableEpList = [NSArray arrayWithArray:epList]; + NSArray* immutableEpList = [NSArray arrayWithArray:epList]; [options setObject:immutableEpList forKey:@"executionProviders"]; - NSDictionary *resultMap = [onnxruntimeModule 
loadModel:dataPath options:options]; + NSDictionary* resultMap = [onnxruntimeModule loadModel:dataPath options:options]; sessionKey = resultMap[@"key"]; - NSArray *inputNames = resultMap[@"inputNames"]; + NSArray* inputNames = resultMap[@"inputNames"]; XCTAssertEqual([inputNames count], 1); XCTAssertEqualObjects(inputNames[0], @"input"); - NSArray *outputNames = resultMap[@"outputNames"]; + NSArray* outputNames = resultMap[@"outputNames"]; XCTAssertEqual([outputNames count], 1); XCTAssertEqualObjects(outputNames[0], @"output"); } diff --git a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm index edd476d03914c..7b307a5bb26fd 100644 --- a/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm +++ b/js/react_native/ios/OnnxruntimeModuleTest/TensorHelperTest.mm @@ -14,7 +14,7 @@ @interface TensorHelperTest : XCTestCase @implementation TensorHelperTest -FakeRCTBlobManager *testBlobManager = nil; +FakeRCTBlobManager* testBlobManager = nil; + (void)initialize { if (self == [TensorHelperTest class]) { @@ -23,12 +23,12 @@ + (void)initialize { } template -static void testCreateInputTensorT(const std::array &outValues, std::function &convert, - ONNXTensorElementDataType onnxType, NSString *jsTensorType) { - NSMutableDictionary *inputTensorMap = [NSMutableDictionary dictionary]; +static void testCreateInputTensorT(const std::array& outValues, std::function& convert, + ONNXTensorElementDataType onnxType, NSString* jsTensorType) { + NSMutableDictionary* inputTensorMap = [NSMutableDictionary dictionary]; // dims - NSArray *dims = @[ [NSNumber numberWithLong:outValues.size()] ]; + NSArray* dims = @[ [NSNumber numberWithLong:outValues.size()] ]; inputTensorMap[@"dims"] = dims; // type @@ -36,9 +36,9 @@ static void testCreateInputTensorT(const std::array &outValues, std::funct // encoded data size_t byteBufferSize = sizeof(T) * outValues.size(); - unsigned char *byteBuffer = static_cast(malloc(byteBufferSize)); - NSData *byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; - T *typePtr = (T *)[byteBufferRef bytes]; + unsigned char* byteBuffer = static_cast(malloc(byteBufferSize)); + NSData* byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; + T* typePtr = (T*)[byteBufferRef bytes]; for (size_t i = 0; i < outValues.size(); ++i) { typePtr[i] = outValues[i]; } @@ -67,25 +67,25 @@ static void testCreateInputTensorT(const std::array &outValues, std::funct - (void)testCreateInputTensorFloat { std::array outValues{std::numeric_limits::min(), 2.0f, std::numeric_limits::max()}; - std::function convert = [](float value) { return [NSNumber numberWithFloat:value]; }; + std::function convert = [](float value) { return [NSNumber numberWithFloat:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, JsTensorTypeFloat); } - (void)testCreateInputTensorDouble { std::array outValues{std::numeric_limits::min(), 2.0f, std::numeric_limits::max()}; - std::function convert = [](double_t value) { return [NSNumber numberWithDouble:value]; }; + std::function convert = [](double_t value) { return [NSNumber numberWithDouble:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, JsTensorTypeDouble); } - (void)testCreateInputTensorBool { std::array outValues{false, true, true}; - std::function convert = [](bool value) { return [NSNumber numberWithBool:value]; }; + std::function convert = [](bool value) { return [NSNumber 
numberWithBool:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, JsTensorTypeBool); } - (void)testCreateInputTensorUInt8 { std::array outValues{std::numeric_limits::min(), 2, std::numeric_limits::max()}; - std::function convert = [](uint8_t value) { + std::function convert = [](uint8_t value) { return [NSNumber numberWithUnsignedChar:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, JsTensorTypeUnsignedByte); @@ -93,42 +93,42 @@ - (void)testCreateInputTensorUInt8 { - (void)testCreateInputTensorInt8 { std::array outValues{std::numeric_limits::min(), 2, std::numeric_limits::max()}; - std::function convert = [](int8_t value) { return [NSNumber numberWithChar:value]; }; + std::function convert = [](int8_t value) { return [NSNumber numberWithChar:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, JsTensorTypeByte); } - (void)testCreateInputTensorInt16 { std::array outValues{std::numeric_limits::min(), 2, std::numeric_limits::max()}; - std::function convert = [](int16_t value) { return [NSNumber numberWithShort:value]; }; + std::function convert = [](int16_t value) { return [NSNumber numberWithShort:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, JsTensorTypeShort); } - (void)testCreateInputTensorInt32 { std::array outValues{std::numeric_limits::min(), 2, std::numeric_limits::max()}; - std::function convert = [](int32_t value) { return [NSNumber numberWithInt:value]; }; + std::function convert = [](int32_t value) { return [NSNumber numberWithInt:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, JsTensorTypeInt); } - (void)testCreateInputTensorInt64 { std::array outValues{std::numeric_limits::min(), 2, std::numeric_limits::max()}; - std::function convert = [](int64_t value) { return [NSNumber numberWithLongLong:value]; }; + std::function convert = [](int64_t value) { return [NSNumber numberWithLongLong:value]; }; testCreateInputTensorT(outValues, convert, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, JsTensorTypeLong); } - (void)testCreateInputTensorString { std::array outValues{"a", "b", "c"}; - NSMutableDictionary *inputTensorMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* inputTensorMap = [NSMutableDictionary dictionary]; // dims - NSArray *dims = @[ [NSNumber numberWithLong:outValues.size()] ]; + NSArray* dims = @[ [NSNumber numberWithLong:outValues.size()] ]; inputTensorMap[@"dims"] = dims; // type inputTensorMap[@"type"] = JsTensorTypeString; // data - NSMutableArray *data = [NSMutableArray array]; + NSMutableArray* data = [NSMutableArray array]; for (auto value : outValues) { [data addObject:[NSString stringWithUTF8String:value.c_str()]]; } @@ -150,17 +150,17 @@ - (void)testCreateInputTensorString { for (int i = 0; i < inputTensor.GetTensorTypeAndShapeInfo().GetElementCount(); ++i) { size_t elementLength = inputTensor.GetStringTensorElementLength(i); std::string element(elementLength, '\0'); - inputTensor.GetStringTensorElement(elementLength, i, (void *)element.data()); + inputTensor.GetStringTensorElement(elementLength, i, (void*)element.data()); XCTAssertEqual(element, outValues[i]); } } template -static void testCreateOutputTensorT(const std::array &outValues, std::function &convert, - NSString *jsTensorType, NSString *testDataFileName, - NSString *testDataFileExtension) { - NSBundle *bundle = [NSBundle bundleForClass:[TensorHelperTest class]]; - NSString *dataPath = [bundle 
pathForResource:testDataFileName ofType:testDataFileExtension]; +static void testCreateOutputTensorT(const std::array& outValues, std::function& convert, + NSString* jsTensorType, NSString* testDataFileName, + NSString* testDataFileExtension) { + NSBundle* bundle = [NSBundle bundleForClass:[TensorHelperTest class]]; + NSString* dataPath = [bundle pathForResource:testDataFileName ofType:testDataFileExtension]; Ort::Env ortEnv{ORT_LOGGING_LEVEL_INFO, "Default"}; Ort::SessionOptions sessionOptions; @@ -171,7 +171,7 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func names.reserve(session.GetInputCount() + session.GetOutputCount()); - std::vector inputNames; + std::vector inputNames; inputNames.reserve(session.GetInputCount()); for (size_t i = 0; i < session.GetInputCount(); ++i) { auto inputName = session.GetInputNameAllocated(i, ortAllocator); @@ -179,7 +179,7 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func names.emplace_back(std::move(inputName)); } - std::vector outputNames; + std::vector outputNames; outputNames.reserve(session.GetOutputCount()); for (size_t i = 0; i < session.GetOutputCount(); ++i) { auto outputName = session.GetOutputNameAllocated(i, ortAllocator); @@ -187,10 +187,10 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func names.emplace_back(std::move(outputName)); } - NSMutableDictionary *inputTensorMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* inputTensorMap = [NSMutableDictionary dictionary]; // dims - NSArray *dims = @[ [NSNumber numberWithLong:1], [NSNumber numberWithLong:outValues.size()] ]; + NSArray* dims = @[ [NSNumber numberWithLong:1], [NSNumber numberWithLong:outValues.size()] ]; inputTensorMap[@"dims"] = dims; // type @@ -198,9 +198,9 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func // encoded data size_t byteBufferSize = sizeof(T) * outValues.size(); - unsigned char *byteBuffer = static_cast(malloc(byteBufferSize)); - NSData *byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; - T *typePtr = (T *)[byteBufferRef bytes]; + unsigned char* byteBuffer = static_cast(malloc(byteBufferSize)); + NSData* byteBufferRef = [NSData dataWithBytesNoCopy:byteBuffer length:byteBufferSize]; + T* typePtr = (T*)[byteBufferRef bytes]; for (size_t i = 0; i < outValues.size(); ++i) { typePtr[i] = outValues[i]; } @@ -220,11 +220,11 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func auto output = session.Run(runOptions, inputNames.data(), feeds.data(), inputNames.size(), outputNames.data(), outputNames.size()); - NSDictionary *resultMap = [TensorHelper createOutputTensor:testBlobManager outputNames:outputNames values:output]; + NSDictionary* resultMap = [TensorHelper createOutputTensor:testBlobManager outputNames:outputNames values:output]; // Compare output & input, but data.blobId is different - NSDictionary *outputMap = [resultMap objectForKey:@"output"]; + NSDictionary* outputMap = [resultMap objectForKey:@"output"]; // dims XCTAssertTrue([outputMap[@"dims"] isEqualToArray:inputTensorMap[@"dims"]]); @@ -233,7 +233,7 @@ static void testCreateOutputTensorT(const std::array &outValues, std::func XCTAssertEqual(outputMap[@"type"], jsTensorType); // data ({ blobId, offset, size }) - NSDictionary *data = outputMap[@"data"]; + NSDictionary* data = outputMap[@"data"]; XCTAssertNotNil(data[@"blobId"]); XCTAssertEqual([data[@"offset"] longValue], 0); @@ -243,26 +243,26 @@ static void testCreateOutputTensorT(const 
std::array &outValues, std::func - (void)testCreateOutputTensorFloat { std::array outValues{std::numeric_limits::min(), 1.0f, 2.0f, 3.0f, std::numeric_limits::max()}; - std::function convert = [](float value) { return [NSNumber numberWithFloat:value]; }; + std::function convert = [](float value) { return [NSNumber numberWithFloat:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeFloat, @"test_types_float", @"ort"); } - (void)testCreateOutputTensorDouble { std::array outValues{std::numeric_limits::min(), 1.0f, 2.0f, 3.0f, std::numeric_limits::max()}; - std::function convert = [](double_t value) { return [NSNumber numberWithDouble:value]; }; + std::function convert = [](double_t value) { return [NSNumber numberWithDouble:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeDouble, @"test_types_double", @"onnx"); } - (void)testCreateOutputTensorBool { std::array outValues{false, true, true, false, true}; - std::function convert = [](bool value) { return [NSNumber numberWithBool:value]; }; + std::function convert = [](bool value) { return [NSNumber numberWithBool:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeBool, @"test_types_bool", @"onnx"); } - (void)testCreateOutputTensorUInt8 { std::array outValues{std::numeric_limits::min(), 1, 2, 3, std::numeric_limits::max()}; - std::function convert = [](uint8_t value) { + std::function convert = [](uint8_t value) { return [NSNumber numberWithUnsignedChar:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeUnsignedByte, @"test_types_uint8", @"ort"); @@ -270,19 +270,19 @@ - (void)testCreateOutputTensorUInt8 { - (void)testCreateOutputTensorInt8 { std::array outValues{std::numeric_limits::min(), 1, -2, 3, std::numeric_limits::max()}; - std::function convert = [](int8_t value) { return [NSNumber numberWithChar:value]; }; + std::function convert = [](int8_t value) { return [NSNumber numberWithChar:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeByte, @"test_types_int8", @"ort"); } - (void)testCreateOutputTensorInt32 { std::array outValues{std::numeric_limits::min(), 1, -2, 3, std::numeric_limits::max()}; - std::function convert = [](int32_t value) { return [NSNumber numberWithInt:value]; }; + std::function convert = [](int32_t value) { return [NSNumber numberWithInt:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeInt, @"test_types_int32", @"ort"); } - (void)testCreateOutputTensorInt64 { std::array outValues{std::numeric_limits::min(), 1, -2, 3, std::numeric_limits::max()}; - std::function convert = [](int64_t value) { return [NSNumber numberWithLongLong:value]; }; + std::function convert = [](int64_t value) { return [NSNumber numberWithLongLong:value]; }; testCreateOutputTensorT(outValues, convert, JsTensorTypeLong, @"test_types_int64", @"ort"); } diff --git a/js/react_native/ios/TensorHelper.h b/js/react_native/ios/TensorHelper.h index c7c7fa8fd9f45..d0fdb5eb3a04e 100644 --- a/js/react_native/ios/TensorHelper.h +++ b/js/react_native/ios/TensorHelper.h @@ -39,18 +39,18 @@ FOUNDATION_EXPORT NSString* const JsTensorTypeString; * It creates an input tensor from a map passed by react native js. * 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor. 
*/ -+(Ort::Value)createInputTensor:(RCTBlobManager *)blobManager - input:(NSDictionary*)input - ortAllocator:(OrtAllocator*)ortAllocator - allocations:(std::vector&)allocations; ++ (Ort::Value)createInputTensor:(RCTBlobManager*)blobManager + input:(NSDictionary*)input + ortAllocator:(OrtAllocator*)ortAllocator + allocations:(std::vector&)allocations; /** * It creates an output map from an output tensor. * a data array is store in RCTBlobManager. */ -+(NSDictionary*)createOutputTensor:(RCTBlobManager *)blobManager - outputNames:(const std::vector&)outputNames - values:(const std::vector&)values; ++ (NSDictionary*)createOutputTensor:(RCTBlobManager*)blobManager + outputNames:(const std::vector&)outputNames + values:(const std::vector&)values; @end diff --git a/js/react_native/ios/TensorHelper.mm b/js/react_native/ios/TensorHelper.mm index 8555dfec275f8..22c632a271c37 100644 --- a/js/react_native/ios/TensorHelper.mm +++ b/js/react_native/ios/TensorHelper.mm @@ -9,29 +9,29 @@ @implementation TensorHelper /** * Supported tensor data type */ -NSString *const JsTensorTypeBool = @"bool"; -NSString *const JsTensorTypeUnsignedByte = @"uint8"; -NSString *const JsTensorTypeByte = @"int8"; -NSString *const JsTensorTypeShort = @"int16"; -NSString *const JsTensorTypeInt = @"int32"; -NSString *const JsTensorTypeLong = @"int64"; -NSString *const JsTensorTypeFloat = @"float32"; -NSString *const JsTensorTypeDouble = @"float64"; -NSString *const JsTensorTypeString = @"string"; +NSString* const JsTensorTypeBool = @"bool"; +NSString* const JsTensorTypeUnsignedByte = @"uint8"; +NSString* const JsTensorTypeByte = @"int8"; +NSString* const JsTensorTypeShort = @"int16"; +NSString* const JsTensorTypeInt = @"int32"; +NSString* const JsTensorTypeLong = @"int64"; +NSString* const JsTensorTypeFloat = @"float32"; +NSString* const JsTensorTypeDouble = @"float64"; +NSString* const JsTensorTypeString = @"string"; /** * It creates an input tensor from a map passed by react native js. * 'data' is blob object and the buffer is stored in RCTBlobManager. It first resolve it and creates a tensor. 
*/ -+ (Ort::Value)createInputTensor:(RCTBlobManager *)blobManager - input:(NSDictionary *)input - ortAllocator:(OrtAllocator *)ortAllocator - allocations:(std::vector &)allocations { ++ (Ort::Value)createInputTensor:(RCTBlobManager*)blobManager + input:(NSDictionary*)input + ortAllocator:(OrtAllocator*)ortAllocator + allocations:(std::vector&)allocations { // shape - NSArray *dimsArray = [input objectForKey:@"dims"]; + NSArray* dimsArray = [input objectForKey:@"dims"]; std::vector dims; dims.reserve(dimsArray.count); - for (NSNumber *dim in dimsArray) { + for (NSNumber* dim in dimsArray) { dims.emplace_back([dim longLongValue]); } @@ -40,17 +40,17 @@ @implementation TensorHelper // data if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - NSArray *values = [input objectForKey:@"data"]; + NSArray* values = [input objectForKey:@"data"]; auto inputTensor = Ort::Value::CreateTensor(ortAllocator, dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); size_t index = 0; - for (NSString *value in values) { + for (NSString* value in values) { inputTensor.FillStringTensorElement([value UTF8String], index++); } return inputTensor; } else { - NSDictionary *data = [input objectForKey:@"data"]; - NSString *blobId = [data objectForKey:@"blobId"]; + NSDictionary* data = [input objectForKey:@"data"]; + NSString* blobId = [data objectForKey:@"blobId"]; long size = [[data objectForKey:@"size"] longValue]; long offset = [[data objectForKey:@"offset"] longValue]; auto buffer = [blobManager resolve:blobId offset:offset size:size]; @@ -68,33 +68,33 @@ @implementation TensorHelper * It creates an output map from an output tensor. * a data array is store in RCTBlobManager. */ -+ (NSDictionary *)createOutputTensor:(RCTBlobManager *)blobManager - outputNames:(const std::vector &)outputNames - values:(const std::vector &)values { ++ (NSDictionary*)createOutputTensor:(RCTBlobManager*)blobManager + outputNames:(const std::vector&)outputNames + values:(const std::vector&)values { if (outputNames.size() != values.size()) { - NSException *exception = [NSException exceptionWithName:@"create output tensor" + NSException* exception = [NSException exceptionWithName:@"create output tensor" reason:@"output name and tensor count mismatched" userInfo:nil]; @throw exception; } - NSMutableDictionary *outputTensorMap = [NSMutableDictionary dictionary]; + NSMutableDictionary* outputTensorMap = [NSMutableDictionary dictionary]; for (size_t i = 0; i < outputNames.size(); ++i) { const auto outputName = outputNames[i]; - const Ort::Value &value = values[i]; + const Ort::Value& value = values[i]; if (!value.IsTensor()) { - NSException *exception = [NSException exceptionWithName:@"create output tensor" + NSException* exception = [NSException exceptionWithName:@"create output tensor" reason:@"only tensor type is supported" userInfo:nil]; @throw exception; } - NSMutableDictionary *outputTensor = [NSMutableDictionary dictionary]; + NSMutableDictionary* outputTensor = [NSMutableDictionary dictionary]; // dims - NSMutableArray *outputDims = [NSMutableArray array]; + NSMutableArray* outputDims = [NSMutableArray array]; auto dims = value.GetTensorTypeAndShapeInfo().GetShape(); for (auto dim : dims) { [outputDims addObject:[NSNumber numberWithLongLong:dim]]; @@ -106,17 +106,17 @@ + (NSDictionary *)createOutputTensor:(RCTBlobManager *)blobManager // data if (value.GetTensorTypeAndShapeInfo().GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - NSMutableArray *buffer = [NSMutableArray array]; + NSMutableArray* buffer = 
[NSMutableArray array]; for (NSInteger i = 0; i < value.GetTensorTypeAndShapeInfo().GetElementCount(); ++i) { size_t elementLength = value.GetStringTensorElementLength(i); std::string element(elementLength, '\0'); - value.GetStringTensorElement(elementLength, i, (void *)element.data()); + value.GetStringTensorElement(elementLength, i, (void*)element.data()); [buffer addObject:[NSString stringWithUTF8String:element.data()]]; } outputTensor[@"data"] = buffer; } else { - NSData *data = [self createOutputTensor:value]; - NSString *blobId = [blobManager store:data]; + NSData* data = [self createOutputTensor:value]; + NSString* blobId = [blobManager store:data]; outputTensor[@"data"] = @{ @"blobId" : blobId, @"offset" : @0, @@ -131,103 +131,104 @@ + (NSDictionary *)createOutputTensor:(RCTBlobManager *)blobManager } template -static Ort::Value createInputTensorT(OrtAllocator *ortAllocator, const std::vector &dims, NSData *buffer, - std::vector &allocations) { - T *dataBuffer = static_cast(ortAllocator->Alloc(ortAllocator, [buffer length])); +static Ort::Value createInputTensorT(OrtAllocator* ortAllocator, const std::vector& dims, NSData* buffer, + std::vector& allocations) { + T* dataBuffer = static_cast(ortAllocator->Alloc(ortAllocator, [buffer length])); allocations.emplace_back(ortAllocator, dataBuffer, [buffer length]); - memcpy(static_cast(dataBuffer), [buffer bytes], [buffer length]); + memcpy(static_cast(dataBuffer), [buffer bytes], [buffer length]); return Ort::Value::CreateTensor(ortAllocator->Info(ortAllocator), dataBuffer, buffer.length / sizeof(T), dims.data(), dims.size()); } + (Ort::Value)createInputTensor:(ONNXTensorElementDataType)tensorType - dims:(const std::vector &)dims - buffer:(NSData *)buffer - ortAllocator:(OrtAllocator *)ortAllocator - allocations:(std::vector &)allocations { + dims:(const std::vector&)dims + buffer:(NSData*)buffer + ortAllocator:(OrtAllocator*)ortAllocator + allocations:(std::vector&)allocations { switch (tensorType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - return createInputTensorT(ortAllocator, dims, buffer, allocations); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - default: { - NSException *exception = [NSException exceptionWithName:@"create input tensor" - reason:@"unsupported tensor type" - userInfo:nil]; - @throw exception; - } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return 
createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return createInputTensorT(ortAllocator, dims, buffer, allocations); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + default: { + NSException* exception = [NSException exceptionWithName:@"create input tensor" + reason:@"unsupported tensor type" + userInfo:nil]; + @throw exception; + } } } -template static NSData *createOutputTensorT(const Ort::Value &tensor) { +template +static NSData* createOutputTensorT(const Ort::Value& tensor) { const auto data = tensor.GetTensorData(); - return [NSData dataWithBytesNoCopy:(void *)data + return [NSData dataWithBytesNoCopy:(void*)data length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T) freeWhenDone:false]; } -+ (NSData *)createOutputTensor:(const Ort::Value &)tensor { ++ (NSData*)createOutputTensor:(const Ort::Value&)tensor { ONNXTensorElementDataType tensorType = tensor.GetTensorTypeAndShapeInfo().GetElementType(); switch (tensorType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - return createOutputTensorT(tensor); - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - default: { - NSException *exception = [NSException exceptionWithName:@"create output tensor" - reason:@"unsupported tensor type" - userInfo:nil]; - @throw exception; - } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return 
createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return createOutputTensorT(tensor); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + default: { + NSException* exception = [NSException exceptionWithName:@"create output tensor" + reason:@"unsupported tensor type" + userInfo:nil]; + @throw exception; + } } } -NSDictionary *JsTensorTypeToOnnxTensorTypeMap; -NSDictionary *OnnxTensorTypeToJsTensorTypeMap; +NSDictionary* JsTensorTypeToOnnxTensorTypeMap; +NSDictionary* OnnxTensorTypeToJsTensorTypeMap; + (void)initialize { JsTensorTypeToOnnxTensorTypeMap = @{ @@ -255,7 +256,7 @@ + (void)initialize { }; } -+ (ONNXTensorElementDataType)getOnnxTensorType:(const NSString *)type { ++ (ONNXTensorElementDataType)getOnnxTensorType:(const NSString*)type { if ([JsTensorTypeToOnnxTensorTypeMap objectForKey:type]) { return (ONNXTensorElementDataType)[JsTensorTypeToOnnxTensorTypeMap[type] intValue]; } else { @@ -263,7 +264,7 @@ + (ONNXTensorElementDataType)getOnnxTensorType:(const NSString *)type { } } -+ (NSString *)getJsTensorType:(ONNXTensorElementDataType)type { ++ (NSString*)getJsTensorType:(ONNXTensorElementDataType)type { if ([OnnxTensorTypeToJsTensorTypeMap objectForKey:@(type)]) { return OnnxTensorTypeToJsTensorTypeMap[@(type)]; } else { diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index 3d3569028e636..854a7ffd9a6ab 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -1,38 +1,52 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {type Backend, InferenceSession, type InferenceSessionHandler, type SessionHandler, Tensor} from 'onnxruntime-common'; -import {Platform} from 'react-native'; +import { + type Backend, + InferenceSession, + type InferenceSessionHandler, + type SessionHandler, + Tensor, +} from 'onnxruntime-common'; +import { Platform } from 'react-native'; -import {binding, type Binding, type JSIBlob, jsiHelper} from './binding'; +import { binding, type Binding, type JSIBlob, jsiHelper } from './binding'; type SupportedTypedArray = Exclude; -const tensorTypeToTypedArray = (type: Tensor.Type):|Float32ArrayConstructor|Int8ArrayConstructor|Int16ArrayConstructor| - Int32ArrayConstructor|BigInt64ArrayConstructor|Float64ArrayConstructor|Uint8ArrayConstructor => { - switch (type) { - case 'float32': - return Float32Array; - case 'int8': - return Int8Array; - case 'uint8': - return Uint8Array; - case 'int16': - return Int16Array; - case 'int32': - return Int32Array; - case 'bool': - return Int8Array; - case 'float64': - return Float64Array; - case 'int64': - /* global BigInt64Array */ - /* eslint no-undef: ["error", { "typeof": true }] */ - return BigInt64Array; - default: - throw new Error(`unsupported type: ${type}`); - } - }; +const tensorTypeToTypedArray = ( + type: Tensor.Type, +): + | Float32ArrayConstructor + | Int8ArrayConstructor + | Int16ArrayConstructor + | Int32ArrayConstructor + | BigInt64ArrayConstructor + | Float64ArrayConstructor + | Uint8ArrayConstructor => { + switch (type) { + case 'float32': + return Float32Array; + case 'int8': + return Int8Array; + case 'uint8': + return Uint8Array; + case 'int16': + return Int16Array; + case 'int32': + return Int32Array; + case 'bool': + return Int8Array; + case 'float64': + return Float64Array; + case 'int64': + /* global BigInt64Array */ + /* eslint no-undef: ["error", { "typeof": true }] */ + return BigInt64Array; + default: + throw new Error(`unsupported type: ${type}`); + } +}; const normalizePath = (path: string): string => { // remove 'file://' prefix in iOS @@ -47,12 +61,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; #key: string; - #pathOrBuffer: string|Uint8Array; + #pathOrBuffer: string | Uint8Array; inputNames: string[]; outputNames: string[]; - constructor(pathOrBuffer: string|Uint8Array) { + constructor(pathOrBuffer: string | Uint8Array) { this.#inferenceSession = binding; this.#pathOrBuffer = pathOrBuffer; this.#key = ''; @@ -96,14 +110,18 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { // TODO: implement profiling } - async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): - Promise { + async run( + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise { const outputNames: Binding.FetchesType = []; for (const name in fetches) { if (Object.prototype.hasOwnProperty.call(fetches, name)) { if (fetches[name]) { throw new Error( - 'Preallocated output is not supported and only names as string array is allowed as parameter'); + 'Preallocated output is not supported and only names as string array is allowed as parameter', + ); } outputNames.push(name); } @@ -114,12 +132,11 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { return output; } - encodeFeedsType(feeds: SessionHandler.FeedsType): Binding.FeedsType { - const returnValue: {[name: string]: Binding.EncodedTensorType} = {}; + const 
returnValue: { [name: string]: Binding.EncodedTensorType } = {}; for (const key in feeds) { if (Object.hasOwnProperty.call(feeds, key)) { - let data: JSIBlob|string[]; + let data: JSIBlob | string[]; if (Array.isArray(feeds[key].data)) { data = feeds[key].data as string[]; @@ -165,8 +182,10 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler( + pathOrBuffer: string | Uint8Array, + options?: InferenceSession.SessionOptions, + ): Promise { const handler = new OnnxruntimeSessionHandler(pathOrBuffer); await handler.loadModel(options || {}); return handler; diff --git a/js/react_native/lib/binding.ts b/js/react_native/lib/binding.ts index 5ecf85dcd25ab..9537b47f58fbe 100644 --- a/js/react_native/lib/binding.ts +++ b/js/react_native/lib/binding.ts @@ -2,8 +2,8 @@ // Licensed under the MIT License. // eslint-disable-next-line @typescript-eslint/no-unused-vars -import type {InferenceSession} from 'onnxruntime-common'; -import {NativeModules} from 'react-native'; +import type { InferenceSession } from 'onnxruntime-common'; +import { NativeModules } from 'react-native'; /** * model loading information @@ -29,7 +29,9 @@ interface ModelLoadInfo { * JSIBlob is a blob object that exchange ArrayBuffer by OnnxruntimeJSIHelper. */ export type JSIBlob = { - blobId: string; offset: number; size: number; + blobId: string; + offset: number; + size: number; }; /** @@ -48,7 +50,7 @@ interface EncodedTensor { * the JSIBlob object of the buffer data of the tensor. * if data is string array, it won't be stored as JSIBlob. */ - readonly data: JSIBlob|string[]; + readonly data: JSIBlob | string[]; } /** @@ -61,13 +63,13 @@ export declare namespace Binding { type SessionOptions = InferenceSession.SessionOptions; type RunOptions = InferenceSession.RunOptions; - type FeedsType = {[name: string]: EncodedTensor}; + type FeedsType = { [name: string]: EncodedTensor }; // SessionHanlder FetchesType is different from native module's one. // It's because Java API doesn't support preallocated output values. 
type FetchesType = string[]; - type ReturnType = {[name: string]: EncodedTensor}; + type ReturnType = { [name: string]: EncodedTensor }; interface InferenceSession { loadModel(modelPath: string, options: SessionOptions): Promise; @@ -78,7 +80,7 @@ export declare namespace Binding { } // export native binding -const {Onnxruntime, OnnxruntimeJSIHelper} = NativeModules; +const { Onnxruntime, OnnxruntimeJSIHelper } = NativeModules; export const binding = Onnxruntime as Binding.InferenceSession; // install JSI helper global functions @@ -86,22 +88,28 @@ OnnxruntimeJSIHelper.install(); declare global { // eslint-disable-next-line no-var - var jsiOnnxruntimeStoreArrayBuffer: ((buffer: ArrayBuffer) => JSIBlob)|undefined; + var jsiOnnxruntimeStoreArrayBuffer: ((buffer: ArrayBuffer) => JSIBlob) | undefined; // eslint-disable-next-line no-var - var jsiOnnxruntimeResolveArrayBuffer: ((blob: JSIBlob) => ArrayBuffer)|undefined; + var jsiOnnxruntimeResolveArrayBuffer: ((blob: JSIBlob) => ArrayBuffer) | undefined; } export const jsiHelper = { - storeArrayBuffer: globalThis.jsiOnnxruntimeStoreArrayBuffer || (() => { - throw new Error( - 'jsiOnnxruntimeStoreArrayBuffer is not found, ' + - 'please make sure OnnxruntimeJSIHelper installation is successful.'); - }), - resolveArrayBuffer: globalThis.jsiOnnxruntimeResolveArrayBuffer || (() => { - throw new Error( - 'jsiOnnxruntimeResolveArrayBuffer is not found, ' + - 'please make sure OnnxruntimeJSIHelper installation is successful.'); - }), + storeArrayBuffer: + globalThis.jsiOnnxruntimeStoreArrayBuffer || + (() => { + throw new Error( + 'jsiOnnxruntimeStoreArrayBuffer is not found, ' + + 'please make sure OnnxruntimeJSIHelper installation is successful.', + ); + }), + resolveArrayBuffer: + globalThis.jsiOnnxruntimeResolveArrayBuffer || + (() => { + throw new Error( + 'jsiOnnxruntimeResolveArrayBuffer is not found, ' + + 'please make sure OnnxruntimeJSIHelper installation is successful.', + ); + }), }; // Remove global functions after installation diff --git a/js/react_native/lib/index.ts b/js/react_native/lib/index.ts index 3bf9da3719e97..65daf2cfe33e6 100644 --- a/js/react_native/lib/index.ts +++ b/js/react_native/lib/index.ts @@ -2,10 +2,10 @@ // Licensed under the MIT License. 
export * from 'onnxruntime-common'; -import {registerBackend, env} from 'onnxruntime-common'; -import {Platform} from 'react-native'; -import {onnxruntimeBackend} from './backend'; -import {version} from './version'; +import { registerBackend, env } from 'onnxruntime-common'; +import { Platform } from 'react-native'; +import { onnxruntimeBackend } from './backend'; +import { version } from './version'; registerBackend('cpu', onnxruntimeBackend, 1); registerBackend('xnnpack', onnxruntimeBackend, 1); @@ -15,4 +15,4 @@ if (Platform.OS === 'android') { registerBackend('coreml', onnxruntimeBackend, 1); } -Object.defineProperty(env.versions, 'react-native', {value: version, enumerable: true}); +Object.defineProperty(env.versions, 'react-native', { value: version, enumerable: true }); diff --git a/js/react_native/scripts/prepack.ts b/js/react_native/scripts/prepack.ts index 2e43294165a83..83ec1d9b45fd8 100644 --- a/js/react_native/scripts/prepack.ts +++ b/js/react_native/scripts/prepack.ts @@ -20,7 +20,7 @@ function updatePackageJson() { const version = packageCommon.version; packageSelf.dependencies['onnxruntime-common'] = `${version}`; } - fs.writeJSONSync(selfPackageJsonPath, packageSelf, {spaces: 2}); + fs.writeJSONSync(selfPackageJsonPath, packageSelf, { spaces: 2 }); console.log('=== finished updating package.json.'); } diff --git a/js/scripts/prepare-onnx-node-tests.ts b/js/scripts/prepare-onnx-node-tests.ts index 64d6eb6648cfd..91aa63e9e6aff 100644 --- a/js/scripts/prepare-onnx-node-tests.ts +++ b/js/scripts/prepare-onnx-node-tests.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {compareSync} from 'dir-compare'; +import { compareSync } from 'dir-compare'; import fs from 'fs-extra'; import jszip from 'jszip'; import log from 'npmlog'; import * as path from 'path'; -import {downloadZip, extractFile} from './utils'; +import { downloadZip, extractFile } from './utils'; const TEST_DATA_OPSET_VERSIONS = [ ['opset19', '1.14.0'], @@ -49,7 +49,7 @@ const main = async () => { const buffer = await downloadZip(resourceUri); const zip = await jszip.loadAsync(buffer); - const entries = zip.filter(relativePath => relativePath.startsWith(folderPrefix)); + const entries = zip.filter((relativePath) => relativePath.startsWith(folderPrefix)); const testCasesFolder = path.join(JS_TEST_DATA_ROOT, 'node', opset); log.info('PrepareTestData', `Preparing folders under ${testCasesFolder}`); @@ -69,7 +69,9 @@ const main = async () => { for (const entry of entries) { if (!entry.dir) { await extractFile( - entry, fs.createWriteStream(path.join(testCasesFolder, path.relative(folderPrefix, entry.name)))); + entry, + fs.createWriteStream(path.join(testCasesFolder, path.relative(folderPrefix, entry.name))), + ); } } } @@ -83,11 +85,11 @@ const main = async () => { // compare each subfolder to its previous version. If they are same, remove the one in current version. 
let count = 0; - fs.readdirSync(currentFolder, {withFileTypes: true}).forEach(dir => { + fs.readdirSync(currentFolder, { withFileTypes: true }).forEach((dir) => { const currentDir = path.join(currentFolder, dir.name); const previousDir = path.join(previousFolder, dir.name); if (dir.isDirectory() && fs.existsSync(previousDir) && fs.statSync(previousDir).isDirectory()) { - if (compareSync(currentDir, previousDir, {compareContent: true}).differences === 0) { + if (compareSync(currentDir, previousDir, { compareContent: true }).differences === 0) { fs.removeSync(currentDir); count++; } diff --git a/js/scripts/utils.ts b/js/scripts/utils.ts index 7ef253397de22..e22eeb1bd9217 100644 --- a/js/scripts/utils.ts +++ b/js/scripts/utils.ts @@ -1,47 +1,51 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {WriteStream} from 'fs'; +import { WriteStream } from 'fs'; import * as https from 'https'; -import {JSZipObject} from 'jszip'; +import { JSZipObject } from 'jszip'; -export const downloadZip = async(url: string): Promise => new Promise((resolve, reject) => { - https.get(url, res => { - const {statusCode} = res; - const contentType = res.headers['content-type']; +export const downloadZip = async (url: string): Promise => + new Promise((resolve, reject) => { + https.get(url, (res) => { + const { statusCode } = res; + const contentType = res.headers['content-type']; - if (statusCode === 301 || statusCode === 302) { - downloadZip(res.headers.location!).then(buffer => resolve(buffer), reason => reject(reason)); - return; - } else if (statusCode !== 200) { - throw new Error(`Failed to download build list. HTTP status code = ${statusCode}`); - } - if (!contentType || !/^application\/zip/.test(contentType)) { - throw new Error(`unexpected content type: ${contentType}`); - } + if (statusCode === 301 || statusCode === 302) { + downloadZip(res.headers.location!).then( + (buffer) => resolve(buffer), + (reason) => reject(reason), + ); + return; + } else if (statusCode !== 200) { + throw new Error(`Failed to download build list. 
HTTP status code = ${statusCode}`); + } + if (!contentType || !/^application\/zip/.test(contentType)) { + throw new Error(`unexpected content type: ${contentType}`); + } - const chunks: Buffer[] = []; - res.on('data', (chunk) => { - chunks.push(chunk); - }); - res.on('end', () => { - resolve(Buffer.concat(chunks)); - }); - res.on('error', err => { - reject(`${err}`); + const chunks: Buffer[] = []; + res.on('data', (chunk) => { + chunks.push(chunk); + }); + res.on('end', () => { + resolve(Buffer.concat(chunks)); + }); + res.on('error', (err) => { + reject(`${err}`); + }); }); }); -}); -export const extractFile = async(entry: JSZipObject, ostream: WriteStream): Promise => - new Promise((resolve, reject) => { - entry.nodeStream() - .pipe(ostream) - .on('finish', - () => { - resolve(); - }) - .on('error', (err) => { - reject(err); - }); - }); +export const extractFile = async (entry: JSZipObject, ostream: WriteStream): Promise => + new Promise((resolve, reject) => { + entry + .nodeStream() + .pipe(ostream) + .on('finish', () => { + resolve(); + }) + .on('error', (err) => { + reject(err); + }); + }); diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index c072ec8be1600..fe1018aab196e 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -4,7 +4,7 @@ 'use strict'; const args = require('minimist')(process.argv, {}); -const bundleMode = args['bundle-mode'] || 'dev'; // 'dev'|'perf' +const bundleMode = args['bundle-mode'] || 'dev'; // 'dev'|'perf' const karmaPlugins = args['karma-plugins'] || undefined; const timeoutMocha = args['timeout-mocha'] || 60000; const forceLocalHost = !!args['force-localhost']; @@ -57,7 +57,7 @@ const hostname = getMachineIpAddress(); // In Node.js v17+, 'localhost' is using IPv6, so need to listen to '::' const listenAddress = Number.parseInt(process.versions.node.split('.')[0]) >= 17 ? 
'::' : '0.0.0.0'; -module.exports = function(config) { +module.exports = function (config) { config.set({ // global config of your BrowserStack account browserStack: { @@ -69,14 +69,14 @@ module.exports = function(config) { }, frameworks: ['mocha'], files: [ - {pattern: ORT_FILE}, - {pattern: TEST_FILE}, - {pattern: 'test/testdata-file-cache-*.json', included: false, watched: false}, - {pattern: 'test/data/**/*', included: false, nocache: true, watched: false}, - {pattern: 'dist/*.*', included: false, watched: false}, + { pattern: ORT_FILE }, + { pattern: TEST_FILE }, + { pattern: 'test/testdata-file-cache-*.json', included: false, watched: false }, + { pattern: 'test/data/**/*', included: false, nocache: true, watched: false }, + { pattern: 'dist/*.*', included: false, watched: false }, ], plugins: karmaPlugins, - client: {captureConsole: true, mocha: {expose: ['body'], timeout: timeoutMocha}}, + client: { captureConsole: true, mocha: { expose: ['body'], timeout: timeoutMocha } }, reporters: ['mocha', 'BrowserStack'], browsers: [], captureTimeout: 120000, @@ -89,10 +89,10 @@ module.exports = function(config) { listenAddress, customLaunchers: { // Chromium-based browsers - EdgeTest: {base: 'Edge', flags: chromiumFlags, edgeDataDir: userDataDir}, - ChromeTest: {base: 'Chrome', flags: chromiumFlags, chromeDataDir: userDataDir}, - ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags, chromeDataDir: userDataDir}, - FirefoxTest: {base: 'Firefox', profile: userDataDir}, + EdgeTest: { base: 'Edge', flags: chromiumFlags, edgeDataDir: userDataDir }, + ChromeTest: { base: 'Chrome', flags: chromiumFlags, chromeDataDir: userDataDir }, + ChromeCanaryTest: { base: 'ChromeCanary', flags: chromiumFlags, chromeDataDir: userDataDir }, + FirefoxTest: { base: 'Firefox', profile: userDataDir }, // // ==== BrowserStack browsers ==== @@ -100,33 +100,73 @@ module.exports = function(config) { // Windows // - BS_WIN_10_Chrome_91: - {base: 'BrowserStack', os: 'Windows', os_version: '10', browser: 'Chrome', browser_version: '91'}, - BS_WIN_10_Edge_91: - {base: 'BrowserStack', os: 'Windows', os_version: '10', browser: 'Edge', browser_version: '91'}, - BS_WIN_10_Firefox_89: - {base: 'BrowserStack', os: 'Windows', os_version: '10', browser: 'Firefox', browser_version: '89'}, + BS_WIN_10_Chrome_91: { + base: 'BrowserStack', + os: 'Windows', + os_version: '10', + browser: 'Chrome', + browser_version: '91', + }, + BS_WIN_10_Edge_91: { + base: 'BrowserStack', + os: 'Windows', + os_version: '10', + browser: 'Edge', + browser_version: '91', + }, + BS_WIN_10_Firefox_89: { + base: 'BrowserStack', + os: 'Windows', + os_version: '10', + browser: 'Firefox', + browser_version: '89', + }, // macOS // - BS_MAC_11_Safari_14: - {base: 'BrowserStack', os: 'OS X', os_version: 'Big Sur', browser: 'Safari', browser_version: '14.0'}, - BS_MAC_11_Chrome_91: - {base: 'BrowserStack', os: 'OS X', os_version: 'Big Sur', browser: 'Chrome', browser_version: '91'}, + BS_MAC_11_Safari_14: { + base: 'BrowserStack', + os: 'OS X', + os_version: 'Big Sur', + browser: 'Safari', + browser_version: '14.0', + }, + BS_MAC_11_Chrome_91: { + base: 'BrowserStack', + os: 'OS X', + os_version: 'Big Sur', + browser: 'Chrome', + browser_version: '91', + }, // iPhone // - BS_IOS_14_iPhoneXS: {base: 'BrowserStack', device: 'iPhone XS', real_mobile: true, os: 'ios', os_version: '14'}, - BS_IOS_13_iPhoneXS: {base: 'BrowserStack', device: 'iPhone XS', real_mobile: true, os: 'ios', os_version: '13'}, + BS_IOS_14_iPhoneXS: { base: 'BrowserStack', device: 'iPhone 
XS', real_mobile: true, os: 'ios', os_version: '14' }, + BS_IOS_13_iPhoneXS: { base: 'BrowserStack', device: 'iPhone XS', real_mobile: true, os: 'ios', os_version: '13' }, // Android // - BS_ANDROID_11_Pixel_5: - {base: 'BrowserStack', device: 'Google Pixel 5', real_mobile: true, os: 'android', os_version: '11.0'}, - BS_ANDROID_11_Galaxy_S_21: - {base: 'BrowserStack', device: 'Samsung Galaxy S21', real_mobile: true, os: 'android', os_version: '11.0'}, - BS_ANDROID_10_Pixel_4: - {base: 'BrowserStack', device: 'Google Pixel 4', real_mobile: true, os: 'android', os_version: '10.0'} - } + BS_ANDROID_11_Pixel_5: { + base: 'BrowserStack', + device: 'Google Pixel 5', + real_mobile: true, + os: 'android', + os_version: '11.0', + }, + BS_ANDROID_11_Galaxy_S_21: { + base: 'BrowserStack', + device: 'Samsung Galaxy S21', + real_mobile: true, + os: 'android', + os_version: '11.0', + }, + BS_ANDROID_10_Pixel_4: { + base: 'BrowserStack', + device: 'Google Pixel 4', + real_mobile: true, + os: 'android', + os_version: '10.0', + }, + }, }); }; diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 7176823c9bf13..5aa799161f4bf 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -2,17 +2,19 @@ // Licensed under the MIT License. /* eslint-disable import/no-internal-modules */ -import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; +import { Backend, InferenceSession, InferenceSessionHandler } from 'onnxruntime-common'; -import {Session} from './onnxjs/session'; -import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference'; +import { Session } from './onnxjs/session'; +import { OnnxjsSessionHandler } from './onnxjs/session-handler-inference'; class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function async init(): Promise {} - async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler( + pathOrBuffer: string | Uint8Array, + options?: InferenceSession.SessionOptions, + ): Promise { // NOTE: Session.Config(from onnx.js) is not compatible with InferenceSession.SessionOptions(from // onnxruntime-common). // In future we should remove Session.Config and use InferenceSession.SessionOptions. diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts index 475a0243ebd3d..7dfe7ee05a1d3 100644 --- a/js/web/lib/backend-wasm-inference.ts +++ b/js/web/lib/backend-wasm-inference.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 09dac3a85311c..7332b3f97eba0 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -1,19 +1,27 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; +import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common'; -import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; -import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training'; +import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; +import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, - evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, - options: InferenceSession.SessionOptions): Promise { + checkpointStateUriOrBuffer: string | Uint8Array, + trainModelUriOrBuffer: string | Uint8Array, + evalModelUriOrBuffer: string | Uint8Array, + optimizerModelUriOrBuffer: string | Uint8Array, + options: InferenceSession.SessionOptions, + ): Promise { const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); await handler.createTrainingSession( - checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + checkpointStateUriOrBuffer, + trainModelUriOrBuffer, + evalModelUriOrBuffer, + optimizerModelUriOrBuffer, + options, + ); return Promise.resolve(handler); } } diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index a3a213392af22..7bef538b26063 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; +import { Backend, env, InferenceSession, InferenceSessionHandler } from 'onnxruntime-common'; -import {initializeOrtEp, initializeWebAssemblyAndOrtRuntime} from './wasm/proxy-wrapper'; -import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; -import {scriptSrc} from './wasm/wasm-utils-import'; +import { initializeOrtEp, initializeWebAssemblyAndOrtRuntime } from './wasm/proxy-wrapper'; +import { OnnxruntimeWebAssemblySessionHandler } from './wasm/session-handler-inference'; +import { scriptSrc } from './wasm/wasm-utils-import'; /** * This function initializes all flags for WebAssembly. @@ -21,8 +21,9 @@ export const initializeFlags = (): void => { if (env.wasm.simd === false) { // eslint-disable-next-line no-console console.warn( - 'Deprecated property "env.wasm.simd" is set to false. ' + - 'non-SIMD build is no longer provided, and this setting will be ignored.'); + 'Deprecated property "env.wasm.simd" is set to false. ' + + 'non-SIMD build is no longer provided, and this setting will be ignored.', + ); } if (typeof env.wasm.proxy !== 'boolean') { @@ -49,7 +50,7 @@ export const initializeFlags = (): void => { env.wasm.numThreads = 1; } else { const numCpuLogicalCores = - typeof navigator === 'undefined' ? require('node:os').cpus().length : navigator.hardwareConcurrency; + typeof navigator === 'undefined' ? 
require('node:os').cpus().length : navigator.hardwareConcurrency; env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2)); } } @@ -81,12 +82,18 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { // performe EP specific initialization await initializeOrtEp(backendName); } - createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): - Promise; - createInferenceSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): - Promise; - async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + createInferenceSessionHandler( + path: string, + options?: InferenceSession.SessionOptions, + ): Promise; + createInferenceSessionHandler( + buffer: Uint8Array, + options?: InferenceSession.SessionOptions, + ): Promise; + async createInferenceSessionHandler( + pathOrBuffer: string | Uint8Array, + options?: InferenceSession.SessionOptions, + ): Promise { const handler = new OnnxruntimeWebAssemblySessionHandler(); await handler.loadModel(pathOrBuffer, options); return Promise.resolve(handler); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 86c05b9a2fa15..321394466b365 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -11,8 +11,8 @@ export * from 'onnxruntime-common'; import * as ort from 'onnxruntime-common'; export default ort; -import {registerBackend, env} from 'onnxruntime-common'; -import {version} from './version'; +import { registerBackend, env } from 'onnxruntime-common'; +import { version } from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { const onnxjsBackend = require('./backend-onnxjs').onnxjsBackend; @@ -20,8 +20,9 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : - require('./backend-wasm-training').wasmBackend; + const wasmBackend = BUILD_DEFS.DISABLE_TRAINING + ? 
require('./backend-wasm-inference').wasmBackend + : require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { registerBackend('webgpu', wasmBackend, 5); registerBackend('webnn', wasmBackend, 5); @@ -30,4 +31,4 @@ if (!BUILD_DEFS.DISABLE_WASM) { registerBackend('wasm', wasmBackend, 10); } -Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); +Object.defineProperty(env.versions, 'web', { value: version, enumerable: true }); diff --git a/js/web/lib/onnxjs/attribute-with-cache-key.ts b/js/web/lib/onnxjs/attribute-with-cache-key.ts index 5d47570f267a6..a5470bb107769 100644 --- a/js/web/lib/onnxjs/attribute-with-cache-key.ts +++ b/js/web/lib/onnxjs/attribute-with-cache-key.ts @@ -9,8 +9,10 @@ class AttributeWithCacheKeyImpl { private key: string; public get cacheKey(): string { if (!this.key) { - this.key = - Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); + this.key = Object.getOwnPropertyNames(this) + .sort() + .map((name) => `${(this as Record)[name]}`) + .join(';'); } return this.key; } @@ -20,5 +22,6 @@ export interface AttributeWithCacheKey { readonly cacheKey: string; } -export const createAttributeWithCacheKey = >(attribute: T): T&AttributeWithCacheKey => - new AttributeWithCacheKeyImpl(attribute) as unknown as T & AttributeWithCacheKey; +export const createAttributeWithCacheKey = >( + attribute: T, +): T & AttributeWithCacheKey => new AttributeWithCacheKeyImpl(attribute) as unknown as T & AttributeWithCacheKey; diff --git a/js/web/lib/onnxjs/attribute.ts b/js/web/lib/onnxjs/attribute.ts index 9abdb2943a552..0f1086ad51e91 100644 --- a/js/web/lib/onnxjs/attribute.ts +++ b/js/web/lib/onnxjs/attribute.ts @@ -3,10 +3,10 @@ import Long from 'long'; -import {onnxruntime} from './ort-schema/flatbuffers/ort-generated'; -import {onnx} from './ort-schema/protobuf/onnx'; -import {Tensor} from './tensor'; -import {decodeUtf8String, LongUtil} from './util'; +import { onnxruntime } from './ort-schema/flatbuffers/ort-generated'; +import { onnx } from './ort-schema/protobuf/onnx'; +import { Tensor } from './tensor'; +import { decodeUtf8String, LongUtil } from './util'; import ortFbs = onnxruntime.experimental.fbs; @@ -30,7 +30,7 @@ type ValueTypes = Attribute.DataTypeMap[Attribute.DataType]; type Value = [ValueTypes, Attribute.DataType]; export class Attribute { - constructor(attributes: onnx.IAttributeProto[]|ortFbs.Attribute[]|null|undefined) { + constructor(attributes: onnx.IAttributeProto[] | ortFbs.Attribute[] | null | undefined) { this._attributes = new Map(); if (attributes !== null && attributes !== undefined) { for (const attr of attributes) { @@ -85,7 +85,10 @@ export class Attribute { } private get( - key: string, type: Attribute.DataType, defaultValue?: V): V { + key: string, + type: Attribute.DataType, + defaultValue?: V, + ): V { const valueAndType = this._attributes.get(key); if (valueAndType === undefined) { if (defaultValue !== undefined) { @@ -99,8 +102,8 @@ export class Attribute { return valueAndType[0] as V; } - private static getType(attr: onnx.IAttributeProto|ortFbs.Attribute): Attribute.DataType { - const type = attr instanceof onnx.AttributeProto ? (attr).type : (attr as ortFbs.Attribute).type(); + private static getType(attr: onnx.IAttributeProto | ortFbs.Attribute): Attribute.DataType { + const type = attr instanceof onnx.AttributeProto ? 
attr.type : (attr as ortFbs.Attribute).type(); switch (type) { case onnx.AttributeProto.AttributeType.FLOAT: return 'float'; @@ -123,7 +126,7 @@ export class Attribute { } } - private static getValue(attr: onnx.IAttributeProto|ortFbs.Attribute) { + private static getValue(attr: onnx.IAttributeProto | ortFbs.Attribute) { const attrType = attr instanceof onnx.AttributeProto ? attr.type : (attr as ortFbs.Attribute).type(); if (attrType === onnx.AttributeProto.AttributeType.GRAPH || attrType === onnx.AttributeProto.AttributeType.GRAPHS) { throw new Error('graph attribute is not supported yet'); @@ -138,7 +141,7 @@ export class Attribute { // cast LONG[] to number[] if (attrType === onnx.AttributeProto.AttributeType.INTS) { - const arr = (value as Array); + const arr = value as Array; const numberValue: number[] = new Array(arr.length); for (let i = 0; i < arr.length; i++) { @@ -151,18 +154,19 @@ export class Attribute { // cast onnx.TensorProto to onnxjs.Tensor if (attrType === onnx.AttributeProto.AttributeType.TENSOR) { - return attr instanceof onnx.AttributeProto ? Tensor.fromProto(value as onnx.ITensorProto) : - Tensor.fromOrtTensor(value as ortFbs.Tensor); + return attr instanceof onnx.AttributeProto + ? Tensor.fromProto(value as onnx.ITensorProto) + : Tensor.fromOrtTensor(value as ortFbs.Tensor); } // cast onnx.TensorProto[] to onnxjs.Tensor[] if (attrType === onnx.AttributeProto.AttributeType.TENSORS) { if (attr instanceof onnx.AttributeProto) { const tensorProtos = value as onnx.ITensorProto[]; - return tensorProtos.map(value => Tensor.fromProto(value)); + return tensorProtos.map((value) => Tensor.fromProto(value)); } else if (attr instanceof ortFbs.Attribute) { const tensorProtos = value as ortFbs.Tensor[]; - return tensorProtos.map(value => Tensor.fromOrtTensor(value)); + return tensorProtos.map((value) => Tensor.fromOrtTensor(value)); } } @@ -189,9 +193,10 @@ export class Attribute { return value as ValueTypes; } - private static getValueNoCheck(attr: onnx.IAttributeProto|ortFbs.Attribute) { - return attr instanceof (onnx.AttributeProto) ? this.getValueNoCheckFromOnnxFormat(attr) : - this.getValueNoCheckFromOrtFormat(attr as ortFbs.Attribute); + private static getValueNoCheck(attr: onnx.IAttributeProto | ortFbs.Attribute) { + return attr instanceof onnx.AttributeProto + ? this.getValueNoCheckFromOnnxFormat(attr) + : this.getValueNoCheckFromOrtFormat(attr as ortFbs.Attribute); } private static getValueNoCheckFromOnnxFormat(attr: onnx.IAttributeProto) { diff --git a/js/web/lib/onnxjs/backend.ts b/js/web/lib/onnxjs/backend.ts index f402b820e76e1..5544a0cc6d2e3 100644 --- a/js/web/lib/onnxjs/backend.ts +++ b/js/web/lib/onnxjs/backend.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {WebGLBackend} from './backends/backend-webgl'; -import {Graph} from './graph'; -import {Operator} from './operators'; -import {OpSet} from './opset'; -import {Session} from './session'; +import { WebGLBackend } from './backends/backend-webgl'; +import { Graph } from './graph'; +import { Operator } from './operators'; +import { OpSet } from './opset'; +import { Session } from './session'; export interface InferenceHandler { /** @@ -61,7 +61,7 @@ export interface Backend { * initialize the backend. 
will be called only once, when the first time the * backend it to be used */ - initialize(): boolean|Promise; + initialize(): boolean | Promise; /** * create an instance of SessionHandler to use in a Session object's lifecycle @@ -77,15 +77,15 @@ export interface Backend { // caches all initialized backend instances const backendsCache: Map = new Map(); -export const backend: {[name: string]: Backend} = { - webgl: new WebGLBackend() +export const backend: { [name: string]: Backend } = { + webgl: new WebGLBackend(), }; /** * Resolve a reference to the backend. If a hint is specified, the corresponding * backend will be used. */ -export async function resolveBackend(hint?: string|readonly string[]): Promise { +export async function resolveBackend(hint?: string | readonly string[]): Promise { if (!hint) { return resolveBackend(['webgl']); } else { @@ -107,7 +107,7 @@ export async function resolveBackend(hint?: string|readonly string[]): Promise { +async function tryLoadBackend(backendHint: string): Promise { const backendObj = backend; if (typeof backendObj[backendHint] !== 'undefined' && isBackend(backendObj[backendHint])) { @@ -131,9 +131,12 @@ function isBackend(obj: unknown) { // check if an object is a Backend instance if ( - 'initialize' in o && typeof o.initialize === 'function' && // initialize() - 'createSessionHandler' in o && typeof o.createSessionHandler === 'function' && // createSessionHandler() - 'dispose' in o && typeof o.dispose === 'function' // dispose() + 'initialize' in o && + typeof o.initialize === 'function' && // initialize() + 'createSessionHandler' in o && + typeof o.createSessionHandler === 'function' && // createSessionHandler() + 'dispose' in o && + typeof o.dispose === 'function' // dispose() ) { return true; } diff --git a/js/web/lib/onnxjs/backends/backend-webgl.ts b/js/web/lib/onnxjs/backends/backend-webgl.ts index 21ed7e38b9f86..a122068eb67bc 100644 --- a/js/web/lib/onnxjs/backends/backend-webgl.ts +++ b/js/web/lib/onnxjs/backends/backend-webgl.ts @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {env} from 'onnxruntime-common'; +import { env } from 'onnxruntime-common'; -import {Backend, SessionHandler} from '../backend'; -import {Logger} from '../instrument'; -import {Session} from '../session'; +import { Backend, SessionHandler } from '../backend'; +import { Logger } from '../instrument'; +import { Session } from '../session'; -import {WebGLSessionHandler} from './webgl/session-handler'; -import {WebGLContext} from './webgl/webgl-context'; -import {createWebGLContext} from './webgl/webgl-context-factory'; +import { WebGLSessionHandler } from './webgl/session-handler'; +import { WebGLContext } from './webgl/webgl-context'; +import { createWebGLContext } from './webgl/webgl-context-factory'; /** * WebGLBackend is the entry point for all WebGL opeartions @@ -19,38 +19,38 @@ import {createWebGLContext} from './webgl/webgl-context-factory'; export class WebGLBackend implements Backend { glContext: WebGLContext; - get contextId(): 'webgl'|'webgl2'|undefined { + get contextId(): 'webgl' | 'webgl2' | undefined { return env.webgl.contextId; } - set contextId(value: 'webgl'|'webgl2'|undefined) { + set contextId(value: 'webgl' | 'webgl2' | undefined) { env.webgl.contextId = value; } - get matmulMaxBatchSize(): number|undefined { + get matmulMaxBatchSize(): number | undefined { return env.webgl.matmulMaxBatchSize; } - set matmulMaxBatchSize(value: number|undefined) { + set matmulMaxBatchSize(value: number | undefined) { env.webgl.matmulMaxBatchSize = value; } - get textureCacheMode(): 'initializerOnly'|'full'|undefined { + get textureCacheMode(): 'initializerOnly' | 'full' | undefined { return env.webgl.textureCacheMode; } - set textureCacheMode(value: 'initializerOnly'|'full'|undefined) { + set textureCacheMode(value: 'initializerOnly' | 'full' | undefined) { env.webgl.textureCacheMode = value; } - get pack(): boolean|undefined { + get pack(): boolean | undefined { return env.webgl.pack; } - set pack(value: boolean|undefined) { + set pack(value: boolean | undefined) { env.webgl.pack = value; } - get async(): boolean|undefined { + get async(): boolean | undefined { return env.webgl.async; } - set async(value: boolean|undefined) { + set async(value: boolean | undefined) { env.webgl.async = value; } @@ -73,14 +73,15 @@ export class WebGLBackend implements Backend { Logger.setWithEnv(env); if (!env.webgl.context) { - Object.defineProperty(env.webgl, 'context', {value: this.glContext.gl}); + Object.defineProperty(env.webgl, 'context', { value: this.glContext.gl }); } Logger.verbose( - 'WebGLBackend', - `Created WebGLContext: ${typeof this.glContext} with matmulMaxBatchSize: ${ - this.matmulMaxBatchSize}; textureCacheMode: ${this.textureCacheMode}; pack: ${this.pack}; async: ${ - this.async}.`); + 'WebGLBackend', + `Created WebGLContext: ${typeof this.glContext} with matmulMaxBatchSize: ${ + this.matmulMaxBatchSize + }; textureCacheMode: ${this.textureCacheMode}; pack: ${this.pack}; async: ${this.async}.`, + ); return true; } catch (e) { Logger.warning('WebGLBackend', `Unable to initialize WebGLBackend. ${e}`); diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-array-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-array-lib.ts index f5c7252f3ce8b..dac6fb7dfc104 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-array-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-array-lib.ts @@ -1,22 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions'; +import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions'; /** * This library produces routines needed for non-constant access to uniform arrays */ export class ArrayGlslLib extends GlslLib { - getFunctions(): {[name: string]: GlslLibRoutine} { + getFunctions(): { [name: string]: GlslLibRoutine } { return this.generate(); } - getCustomTypes(): {[name: string]: string} { + getCustomTypes(): { [name: string]: string } { return {}; } constructor(context: GlslContext) { super(context); } - protected generate(): {[name: string]: GlslLibRoutine} { - const result: {[name: string]: GlslLibRoutine} = {}; + protected generate(): { [name: string]: GlslLibRoutine } { + const result: { [name: string]: GlslLibRoutine } = {}; for (let i = 1; i <= 16; i++) { result[`setItem${i}`] = new GlslLibRoutine(this.generateSetItem(i)); result[`getItem${i}`] = new GlslLibRoutine(this.generateGetItem(i)); diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts index 717233182ed8a..70bd4fb8ab02b 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts @@ -1,13 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {ArrayUtil, BroadcastUtil, ShapeUtil} from '../../util'; - -import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions'; -import {getGlsl} from './glsl-source'; -import {squeezeShape} from './texture-layout-strategy'; -import {TextureLayout} from './types'; -import {generateShaderFuncNameFromInputSamplerName, generateShaderFuncNameFromInputSamplerNameAtOutCoords, getCoordsDataType, getGlChannels, getSqueezedParams, squeezeInputShape} from './utils'; +import { ArrayUtil, BroadcastUtil, ShapeUtil } from '../../util'; + +import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions'; +import { getGlsl } from './glsl-source'; +import { squeezeShape } from './texture-layout-strategy'; +import { TextureLayout } from './types'; +import { + generateShaderFuncNameFromInputSamplerName, + generateShaderFuncNameFromInputSamplerNameAtOutCoords, + getCoordsDataType, + getGlChannels, + getSqueezedParams, + squeezeInputShape, +} from './utils'; /** * GLSL Library responsible for data types and routines for manipulating @@ -19,7 +26,7 @@ export class CoordsGlslLib extends GlslLib { constructor(context: GlslContext) { super(context); } - getFunctions(): {[name: string]: GlslLibRoutine} { + getFunctions(): { [name: string]: GlslLibRoutine } { return { ...this.offsetToCoords(), ...this.coordsToOffset(), @@ -28,7 +35,7 @@ export class CoordsGlslLib extends GlslLib { // TODO return these only when packing is enabled. 
...this.getCommonUtilFuncs(), ...this.getInputsSamplingSnippets(), - ...this.getOutputSamplingSnippet() + ...this.getOutputSamplingSnippet(), }; } getCustomTypes() { @@ -38,7 +45,7 @@ export class CoordsGlslLib extends GlslLib { * Produces a function that can map from * 2D normalzied coordinates (s,t) to a flat offset */ - protected offsetToCoords(): {[name: string]: GlslLibRoutine} { + protected offsetToCoords(): { [name: string]: GlslLibRoutine } { const funcName = 'offsetToCoords'; return { offsetToCoords: new GlslLibRoutine(` @@ -48,7 +55,7 @@ export class CoordsGlslLib extends GlslLib { vec2 coords = (vec2(s,t) + vec2(0.5,0.5)) / vec2(width, height); return coords; } - `) + `), }; } @@ -56,7 +63,7 @@ export class CoordsGlslLib extends GlslLib { * Produces a function that can map from * 2D normalzied coordinates (s,t) to a flat offset */ - protected coordsToOffset(): {[name: string]: GlslLibRoutine} { + protected coordsToOffset(): { [name: string]: GlslLibRoutine } { const funcName = 'coordsToOffset'; return { coordsToOffset: new GlslLibRoutine(` @@ -66,7 +73,7 @@ export class CoordsGlslLib extends GlslLib { int offset = int(t) * width + int(s); return offset; } - `) + `), }; } @@ -74,7 +81,7 @@ export class CoordsGlslLib extends GlslLib { * Generates code for output sampler. */ - protected getOutputSamplingSnippet(): {[name: string]: GlslLibRoutine} { + protected getOutputSamplingSnippet(): { [name: string]: GlslLibRoutine } { const outputLayout = this.context.outputTextureLayout; if (outputLayout.isPacked) { return this.getPackedOutputSamplingSnippet(outputLayout); @@ -86,10 +93,10 @@ export class CoordsGlslLib extends GlslLib { /** * Generates code for packed output sampler. */ - protected getPackedOutputSamplingSnippet(outputLayout: TextureLayout): {[name: string]: GlslLibRoutine} { + protected getPackedOutputSamplingSnippet(outputLayout: TextureLayout): { [name: string]: GlslLibRoutine } { const outShape = outputLayout.unpackedShape; const outTexShape = [outputLayout.width, outputLayout.height]; - const result: {[name: string]: GlslLibRoutine} = {}; + const result: { [name: string]: GlslLibRoutine } = {}; const funcName = 'getOutputCoords'; switch (outShape.length) { case 0: @@ -102,8 +109,10 @@ export class CoordsGlslLib extends GlslLib { result[funcName] = this.getOutputPacked2DCoords(outShape as [number, number], outTexShape as [number, number]); break; case 3: - result[funcName] = - this.getOutputPacked3DCoords(outShape as [number, number, number], outTexShape as [number, number]); + result[funcName] = this.getOutputPacked3DCoords( + outShape as [number, number, number], + outTexShape as [number, number], + ); break; default: result[funcName] = this.getOutputPackedNDCoords(outShape, outTexShape as [number, number]); @@ -124,10 +133,10 @@ export class CoordsGlslLib extends GlslLib { /** * Generates code for unpacked output sampler. 
*/ - protected getUnpackedOutputSamplingSnippet(outputLayout: TextureLayout): {[name: string]: GlslLibRoutine} { + protected getUnpackedOutputSamplingSnippet(outputLayout: TextureLayout): { [name: string]: GlslLibRoutine } { const outShape = outputLayout.unpackedShape; const outTexShape = [outputLayout.width, outputLayout.height]; - const result: {[name: string]: GlslLibRoutine} = {}; + const result: { [name: string]: GlslLibRoutine } = {}; const funcName = 'getOutputCoords'; switch (outShape.length) { case 0: @@ -137,24 +146,34 @@ export class CoordsGlslLib extends GlslLib { result[funcName] = this.getOutputUnpacked1DCoords(outShape as [number], outTexShape as [number, number]); break; case 2: - result[funcName] = - this.getOutputUnpacked2DCoords(outShape as [number, number], outTexShape as [number, number]); + result[funcName] = this.getOutputUnpacked2DCoords( + outShape as [number, number], + outTexShape as [number, number], + ); break; case 3: - result[funcName] = - this.getOutputUnpacked3DCoords(outShape as [number, number, number], outTexShape as [number, number]); + result[funcName] = this.getOutputUnpacked3DCoords( + outShape as [number, number, number], + outTexShape as [number, number], + ); break; case 4: result[funcName] = this.getOutputUnpacked4DCoords( - outShape as [number, number, number, number], outTexShape as [number, number]); + outShape as [number, number, number, number], + outTexShape as [number, number], + ); break; case 5: result[funcName] = this.getOutputUnpacked5DCoords( - outShape as [number, number, number, number, number], outTexShape as [number, number]); + outShape as [number, number, number, number, number], + outTexShape as [number, number], + ); break; case 6: result[funcName] = this.getOutputUnpacked6DCoords( - outShape as [number, number, number, number, number, number], outTexShape as [number, number]); + outShape as [number, number, number, number, number, number], + outTexShape as [number, number], + ); break; default: throw new Error(`Unsupported output dimensionality: ${outShape.length}`); @@ -301,7 +320,8 @@ export class CoordsGlslLib extends GlslLib { for (let b = 2; b < shape.length - 1; b++) { texelsInBatchN *= shape[shape.length - b - 1]; - batches = ` + batches = + ` int b${b} = index / ${texelsInBatchN}; index -= b${b} * ${texelsInBatchN}; ` + batches; @@ -377,16 +397,16 @@ export class CoordsGlslLib extends GlslLib { strides[i] = strides[i + 1] * shape[i + 1]; } const coordsToCompute = ['r', 'c', 'd']; - const coordsFromIndexSnippet = - strides - .map((stride, i) => { - const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; - const line2 = i === strides.length - 1 ? - `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` : - `index -= ${coordsToCompute[i]} * ${stride}`; - return `${line1}; ${line2};`; - }) - .join(''); + const coordsFromIndexSnippet = strides + .map((stride, i) => { + const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; + const line2 = + i === strides.length - 1 + ? `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` + : `index -= ${coordsToCompute[i]} * ${stride}`; + return `${line1}; ${line2};`; + }) + .join(''); source = ` ivec3 getOutputCoords() { @@ -403,8 +423,10 @@ export class CoordsGlslLib extends GlslLib { /** * Unpacked 4D output coordinates. 
*/ - protected getOutputUnpacked4DCoords(shape: [number, number, number, number], texShape: [number, number]): - GlslLibRoutine { + protected getOutputUnpacked4DCoords( + shape: [number, number, number, number], + texShape: [number, number], + ): GlslLibRoutine { let source = ''; const rank = shape.length; @@ -419,16 +441,16 @@ export class CoordsGlslLib extends GlslLib { strides[i] = strides[i + 1] * shape[i + 1]; } const coordsToCompute = ['r', 'c', 'd', 'd2']; - const coordsFromIndexSnippet = - strides - .map((stride, i) => { - const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; - const line2 = i === strides.length - 1 ? - `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` : - `index -= ${coordsToCompute[i]} * ${stride}`; - return `${line1}; ${line2};`; - }) - .join(''); + const coordsFromIndexSnippet = strides + .map((stride, i) => { + const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; + const line2 = + i === strides.length - 1 + ? `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` + : `index -= ${coordsToCompute[i]} * ${stride}`; + return `${line1}; ${line2};`; + }) + .join(''); source = ` ivec4 getOutputCoords() { @@ -445,8 +467,10 @@ export class CoordsGlslLib extends GlslLib { /** * Unpacked 5D output coordinates. */ - protected getOutputUnpacked5DCoords(shape: [number, number, number, number, number], texShape: [number, number]): - GlslLibRoutine { + protected getOutputUnpacked5DCoords( + shape: [number, number, number, number, number], + texShape: [number, number], + ): GlslLibRoutine { let source = ''; const rank = shape.length; @@ -461,16 +485,16 @@ export class CoordsGlslLib extends GlslLib { strides[i] = strides[i + 1] * shape[i + 1]; } const coordsToCompute = ['r', 'c', 'd', 'd2', 'd3']; - const coordsFromIndexSnippet = - strides - .map((stride, i) => { - const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; - const line2 = i === strides.length - 1 ? - `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` : - `index -= ${coordsToCompute[i]} * ${stride}`; - return `${line1}; ${line2};`; - }) - .join(''); + const coordsFromIndexSnippet = strides + .map((stride, i) => { + const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; + const line2 = + i === strides.length - 1 + ? `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` + : `index -= ${coordsToCompute[i]} * ${stride}`; + return `${line1}; ${line2};`; + }) + .join(''); source = ` ivec5 getOutputCoords() { @@ -487,9 +511,10 @@ export class CoordsGlslLib extends GlslLib { /** * Unpacked 6D output coordinates. */ - protected getOutputUnpacked6DCoords(shape: [number, number, number, number, number, number], texShape: [ - number, number - ]): GlslLibRoutine { + protected getOutputUnpacked6DCoords( + shape: [number, number, number, number, number, number], + texShape: [number, number], + ): GlslLibRoutine { let source = ''; const rank = shape.length; @@ -504,16 +529,16 @@ export class CoordsGlslLib extends GlslLib { strides[i] = strides[i + 1] * shape[i + 1]; } const coordsToCompute = ['r', 'c', 'd', 'd2', 'd3', 'd4']; - const coordsFromIndexSnippet = - strides - .map((stride, i) => { - const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; - const line2 = i === strides.length - 1 ? 
- `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` : - `index -= ${coordsToCompute[i]} * ${stride}`; - return `${line1}; ${line2};`; - }) - .join(''); + const coordsFromIndexSnippet = strides + .map((stride, i) => { + const line1 = `int ${coordsToCompute[i]} = index / ${stride}`; + const line2 = + i === strides.length - 1 + ? `int ${coordsToCompute[i + 1]} = index - ${coordsToCompute[i]} * ${stride}` + : `index -= ${coordsToCompute[i]} * ${stride}`; + return `${line1}; ${line2};`; + }) + .join(''); source = ` ivec6 getOutputCoords() { @@ -530,8 +555,8 @@ export class CoordsGlslLib extends GlslLib { /** * Generates code for common UV coords computation utility functions. */ - protected getCommonUtilFuncs(): {[name: string]: GlslLibRoutine} { - const result: {[name: string]: GlslLibRoutine} = {}; + protected getCommonUtilFuncs(): { [name: string]: GlslLibRoutine } { + const result: { [name: string]: GlslLibRoutine } = {}; let funcName = 'uvFromFlat'; result[funcName] = new GlslLibRoutine(` vec2 uvFromFlat(int texNumR, int texNumC, int index) { @@ -583,8 +608,8 @@ export class CoordsGlslLib extends GlslLib { /** * Constructing snippets for inputs */ - protected getInputsSamplingSnippets(): {[name: string]: GlslLibRoutine} { - const result: {[name: string]: GlslLibRoutine} = {}; + protected getInputsSamplingSnippets(): { [name: string]: GlslLibRoutine } { + const result: { [name: string]: GlslLibRoutine } = {}; const outputLayout = this.context.outputTextureLayout; this.context.programInfo.inputNames.forEach((samplerName, i) => { const inputLayout = this.context.inputTextureLayouts[i]; @@ -598,11 +623,19 @@ export class CoordsGlslLib extends GlslLib { const outCoordFuncName = generateShaderFuncNameFromInputSamplerNameAtOutCoords(samplerName); if (inputLayout.unpackedShape.length <= outputLayout.unpackedShape.length) { if (inputLayout.isPacked) { - result[outCoordFuncName] = - this.getPackedSamplerAtOutputCoords(outCoordFuncName, inputLayout, outputLayout, samplerName); + result[outCoordFuncName] = this.getPackedSamplerAtOutputCoords( + outCoordFuncName, + inputLayout, + outputLayout, + samplerName, + ); } else { - result[outCoordFuncName] = - this.getUnpackedSamplerAtOutputCoords(outCoordFuncName, inputLayout, outputLayout, samplerName); + result[outCoordFuncName] = this.getUnpackedSamplerAtOutputCoords( + outCoordFuncName, + inputLayout, + outputLayout, + samplerName, + ); } } }); @@ -614,7 +647,11 @@ export class CoordsGlslLib extends GlslLib { * Constructing snippets for output coordinates of samplers */ protected getPackedSamplerAtOutputCoords( - funcName: string, inputLayout: TextureLayout, outputLayout: TextureLayout, name: string): GlslLibRoutine { + funcName: string, + inputLayout: TextureLayout, + outputLayout: TextureLayout, + name: string, + ): GlslLibRoutine { const inShape = inputLayout.unpackedShape; const outShape = outputLayout.unpackedShape; const texName = name; @@ -635,7 +672,7 @@ export class CoordsGlslLib extends GlslLib { } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { - coordsSnippet = broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`).join('\n'); + coordsSnippet = broadcastDims.map((d) => `coords.${fields[d + rankDiff]} = 0;`).join('\n'); } let unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { @@ -671,8 +708,7 @@ export class CoordsGlslLib extends GlslLib { if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) { output = 'return vec4(outputValue.x);'; } 
else if (broadcastDims.indexOf(rows) > -1) { - output = 'return vec4(outputValue.x, outputValue.y, ' + - 'outputValue.x, outputValue.y);'; + output = 'return vec4(outputValue.x, outputValue.y, ' + 'outputValue.x, outputValue.y);'; } else if (broadcastDims.indexOf(cols) > -1) { output = 'return vec4(outputValue.xx, outputValue.zz);'; } @@ -699,7 +735,11 @@ export class CoordsGlslLib extends GlslLib { * Constructing snippets for unpacked output coordinates of samplers */ protected getUnpackedSamplerAtOutputCoords( - funcName: string, inputLayout: TextureLayout, outputLayout: TextureLayout, name: string): GlslLibRoutine { + funcName: string, + inputLayout: TextureLayout, + outputLayout: TextureLayout, + name: string, + ): GlslLibRoutine { const outTexShape = [outputLayout.width, outputLayout.height]; const inTexShape = [inputLayout.width, inputLayout.height]; const inRank = inputLayout.unpackedShape.length; @@ -728,7 +768,7 @@ export class CoordsGlslLib extends GlslLib { } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { - coordsSnippet = broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`).join('\n'); + coordsSnippet = broadcastDims.map((d) => `coords.${fields[d + rankDiff]} = 0;`).join('\n'); } let unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { @@ -939,8 +979,11 @@ export class CoordsGlslLib extends GlslLib { return sampleTexture(${name}, uv); } `; - return new GlslLibRoutine( - source, ['coordinates.uvFromFlat', 'coordinates.sampleTexture', 'coordinates.coordsToOffset']); + return new GlslLibRoutine(source, [ + 'coordinates.uvFromFlat', + 'coordinates.sampleTexture', + 'coordinates.coordsToOffset', + ]); } /** @@ -1008,7 +1051,7 @@ export class CoordsGlslLib extends GlslLib { return new GlslLibRoutine(source, ['coordinates.sampleTexture']); } - const {newShape, keptDims} = squeezeShape(shape as number[]); + const { newShape, keptDims } = squeezeShape(shape as number[]); const squeezedShape = newShape; if (squeezedShape.length < shape.length) { const newInputShape = squeezeInputShape(shape, squeezedShape); @@ -1059,8 +1102,11 @@ export class CoordsGlslLib extends GlslLib { return sampleTexture(${name}, uv); } `; - return new GlslLibRoutine( - source, ['coordinates.uvFromFlat', 'coordinates.sampleTexture', 'coordinates.coordsToOffset']); + return new GlslLibRoutine(source, [ + 'coordinates.uvFromFlat', + 'coordinates.sampleTexture', + 'coordinates.coordsToOffset', + ]); } /** @@ -1072,7 +1118,7 @@ export class CoordsGlslLib extends GlslLib { const stride0 = shape[1] * shape[2]; const stride1 = shape[2]; - const {newShape, keptDims} = squeezeShape(shape as number[]); + const { newShape, keptDims } = squeezeShape(shape as number[]); const squeezedShape = newShape; if (squeezedShape.length < shape.length) { const newInputShape = squeezeInputShape(shape, squeezedShape); @@ -1102,8 +1148,11 @@ export class CoordsGlslLib extends GlslLib { return sampleTexture(${name}, uv); } `; - return new GlslLibRoutine( - source, ['coordinates.uvFromFlat', 'coordinates.sampleTexture', 'coordinates.coordsToOffset']); + return new GlslLibRoutine(source, [ + 'coordinates.uvFromFlat', + 'coordinates.sampleTexture', + 'coordinates.coordsToOffset', + ]); } /** @@ -1159,7 +1208,7 @@ export class CoordsGlslLib extends GlslLib { const stride1 = shape[2] * stride2; const stride0 = shape[1] * stride1; - const {newShape, keptDims} = squeezeShape(shape as number[]); + const { newShape, keptDims } = squeezeShape(shape as number[]); if (newShape.length < 
shape.length) { const newInputShape = squeezeInputShape(shape, newShape); const params = ['row', 'col', 'depth', 'depth2', 'depth3']; @@ -1200,7 +1249,7 @@ export class CoordsGlslLib extends GlslLib { const stride1 = shape[2] * stride2; const stride0 = shape[1] * stride1; - const {newShape, keptDims} = squeezeShape(shape as number[]); + const { newShape, keptDims } = squeezeShape(shape as number[]); if (newShape.length < shape.length) { const newInputShape = squeezeInputShape(shape, newShape); const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4']; @@ -1229,8 +1278,11 @@ export class CoordsGlslLib extends GlslLib { return sampleTexture(${name}, uv); } `; - return new GlslLibRoutine( - source, ['coordinates.uvFromFlat', 'coordinates.sampleTexture', 'coordinates.coordsToOffset']); + return new GlslLibRoutine(source, [ + 'coordinates.uvFromFlat', + 'coordinates.sampleTexture', + 'coordinates.coordsToOffset', + ]); } /** @@ -1239,7 +1291,7 @@ export class CoordsGlslLib extends GlslLib { * There will only be one single variation of this * Also see coordsToOffset and offsetToIndices for input-specific versions */ - protected toVec(): {[name: string]: GlslLibRoutine} { + protected toVec(): { [name: string]: GlslLibRoutine } { const output = this.context.outputTextureLayout; const rank = output.shape.length; const strides = output.strides; @@ -1264,7 +1316,7 @@ export class CoordsGlslLib extends GlslLib { ${stridesBlock.join('')} } `; - return {toVec: new GlslLibRoutine(body, ['coordinates.coordsToOffset'])}; + return { toVec: new GlslLibRoutine(body, ['coordinates.coordsToOffset']) }; } /** * These are value getter functions generated for each input @@ -1272,20 +1324,24 @@ export class CoordsGlslLib extends GlslLib { * An '_T' variation is also produced which accesses values as if the * input was transposed */ - protected valueFrom(): {[name: string]: GlslLibRoutine} { - const result: {[name: string]: GlslLibRoutine} = {}; + protected valueFrom(): { [name: string]: GlslLibRoutine } { + const result: { [name: string]: GlslLibRoutine } = {}; this.context.programInfo.inputNames.forEach((name, i) => { const layout = this.context.inputTextureLayouts[i]; const shape = layout.unpackedShape.length > 0 ? 
layout.unpackedShape : layout.shape; const rank = shape.length; let funcName = `_${name}`; - result[funcName] = new GlslLibRoutine( - this.getValueFromSingle(name, rank, layout.width, layout.height, false), - [`shapeUtils.indicesToOffset${funcName}`, 'coordinates.offsetToCoords', 'fragcolor.getColorAsFloat']); + result[funcName] = new GlslLibRoutine(this.getValueFromSingle(name, rank, layout.width, layout.height, false), [ + `shapeUtils.indicesToOffset${funcName}`, + 'coordinates.offsetToCoords', + 'fragcolor.getColorAsFloat', + ]); funcName = funcName + '_T'; - result[funcName] = new GlslLibRoutine( - this.getValueFromSingle(name, rank, layout.width, layout.height, true), - [`shapeUtils.indicesToOffset${funcName}`, 'coordinates.offsetToCoords', 'fragcolor.getColorAsFloat']); + result[funcName] = new GlslLibRoutine(this.getValueFromSingle(name, rank, layout.width, layout.height, true), [ + `shapeUtils.indicesToOffset${funcName}`, + 'coordinates.offsetToCoords', + 'fragcolor.getColorAsFloat', + ]); }); return result; } @@ -1296,8 +1352,13 @@ export class CoordsGlslLib extends GlslLib { * @param rank rank of the input * @param transpose whether or not should generate a transpose variation */ - protected getValueFromSingle(varName: string, rank: number, width: number, height: number, transpose: boolean): - string { + protected getValueFromSingle( + varName: string, + rank: number, + width: number, + height: number, + transpose: boolean, + ): string { let name = `_${varName}`; if (transpose) { name = name + '_T'; @@ -1320,8 +1381,13 @@ export class CoordsGlslLib extends GlslLib { * @param rank rank of the input * @param transpose whether or not should generate a transpose variation */ - protected getPackedValueFrom(varName: string, rank: number, width: number, height: number, transpose: boolean): - string { + protected getPackedValueFrom( + varName: string, + rank: number, + width: number, + height: number, + transpose: boolean, + ): string { let name = `_${varName}_Pack`; if (transpose) { name = name + '_T'; diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-definitions.ts b/js/web/lib/onnxjs/backends/webgl/glsl-definitions.ts index 304508328408b..7632260909955 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-definitions.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-definitions.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
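
// Illustrative sketch (not part of this change): the getOutputUnpacked*DCoords routines
// reformatted above emit GLSL of the form "int r = index / stride; index -= r * stride; ..."
// to turn a flat texel index into N-D coordinates. The same decomposition on the CPU side,
// with illustrative names only:
function coordsFromIndex(index: number, shape: readonly number[]): number[] {
  // Row-major strides: strides[i] is the number of elements spanned by one step in dim i.
  const strides = new Array<number>(shape.length).fill(1);
  for (let i = shape.length - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  const coords: number[] = [];
  let remaining = index;
  for (let i = 0; i < shape.length; i++) {
    const c = Math.floor(remaining / strides[i]);
    coords.push(c);
    remaining -= c * strides[i];
  }
  return coords;
}
// Example: coordsFromIndex(5, [2, 3]) === [1, 2].
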
-import {ProgramInfo, TextureLayout} from './types'; -import {WebGLContext} from './webgl-context'; +import { ProgramInfo, TextureLayout } from './types'; +import { WebGLContext } from './webgl-context'; /* eslint-disable @typescript-eslint/naming-convention */ export enum FunctionType { ValueBased, - Positional + Positional, } export interface GlslFunction { body: string; @@ -22,18 +22,24 @@ export interface GlslPositionalFunction extends GlslFunction, alreadyTraversed: Set, - result: GlslLibRoutineNode[]) { + graphNodes: GlslLibRoutineNode[], + cycleCheck: Set, + alreadyTraversed: Set, + result: GlslLibRoutineNode[], + ) { for (let i = 0; i < graphNodes.length; ++i) { this.dfsTraverse(graphNodes[i], cycleCheck, alreadyTraversed, result); } } private static dfsTraverse( - root: GlslLibRoutineNode, cycleCheck: Set, alreadyTraversed: Set, result: GlslLibRoutineNode[]) { + root: GlslLibRoutineNode, + cycleCheck: Set, + alreadyTraversed: Set, + result: GlslLibRoutineNode[], + ) { // if this root has already been traversed return if (!root || alreadyTraversed.has(root.name)) { return; @@ -95,7 +112,7 @@ export class TopologicalSortGlslRoutines { // cyclic dependency has been detected if (cycleCheck.has(root.name)) { - throw new Error('Cyclic dependency detected. Can\'t topologically sort routines needed for shader.'); + throw new Error("Cyclic dependency detected. Can't topologically sort routines needed for shader."); } // hold this node to detect cycles if any diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-encoding-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-encoding-lib.ts index 9d0656051c011..fe6673604e8c5 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-encoding-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-encoding-lib.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions'; +import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions'; /** * This GLSL library handles routines converting @@ -11,33 +11,33 @@ export class EncodingGlslLib extends GlslLib { constructor(context: GlslContext) { super(context); } - getFunctions(): {[name: string]: GlslLibRoutine} { - return {...this.encodeFloat32(), ...this.decodeFloat32()}; + getFunctions(): { [name: string]: GlslLibRoutine } { + return { ...this.encodeFloat32(), ...this.decodeFloat32() }; } - getCustomTypes(): {[name: string]: string} { + getCustomTypes(): { [name: string]: string } { return {}; } - protected encodeFloat32(): {[name: string]: GlslLibRoutine} { + protected encodeFloat32(): { [name: string]: GlslLibRoutine } { return { encode: new GlslLibRoutine(`highp vec4 encode(highp float f) { return vec4(f, 0.0, 0.0, 0.0); } - `) + `), }; } - protected decodeFloat32(): {[name: string]: GlslLibRoutine} { + protected decodeFloat32(): { [name: string]: GlslLibRoutine } { return { decode: new GlslLibRoutine(`highp float decode(highp vec4 rgba) { return rgba.r; } - `) + `), }; } /** * returns the routine to encode encode a 32bit float to a vec4 (of unsigned bytes) * @credit: https://stackoverflow.com/questions/7059962/how-do-i-convert-a-vec4-rgba-value-to-a-float */ - protected encodeUint8(): {[name: string]: GlslLibRoutine} { + protected encodeUint8(): { [name: string]: GlslLibRoutine } { const endianness = EncodingGlslLib.isLittleEndian() ? 
'rgba.rgba=rgba.abgr;' : ''; return { encode: new GlslLibRoutine(` @@ -56,14 +56,14 @@ export class EncodingGlslLib extends GlslLib { rgba = rgba / 255.0; // values need to be normalized to [0,1] return rgba; } - `) + `), }; } /** * returns the routine to encode a vec4 of unsigned bytes to float32 * @credit: https://stackoverflow.com/questions/7059962/how-do-i-convert-a-vec4-rgba-value-to-a-float */ - protected decodeUint8(): {[name: string]: GlslLibRoutine} { + protected decodeUint8(): { [name: string]: GlslLibRoutine } { const endianness = EncodingGlslLib.isLittleEndian() ? 'rgba.rgba=rgba.abgr;' : ''; return { decode: new GlslLibRoutine(` @@ -76,7 +76,7 @@ export class EncodingGlslLib extends GlslLib { highp float Result = Sign * exp2(Exponent) * (Mantissa * exp2(-23.0 )); return Result; } - `) + `), }; } /** diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-fragcolor-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-fragcolor-lib.ts index 03954714f8adb..2bfe92421f277 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-fragcolor-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-fragcolor-lib.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions'; -import {getGlsl} from './glsl-source'; +import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions'; +import { getGlsl } from './glsl-source'; /** * This GLSL library handles routines around reading a texlet and writing to it @@ -13,33 +13,35 @@ export class FragColorGlslLib extends GlslLib { constructor(context: GlslContext) { super(context); } - getFunctions(): {[name: string]: GlslLibRoutine} { - return {...this.setFragColor(), ...this.getColorAsFloat()}; + getFunctions(): { [name: string]: GlslLibRoutine } { + return { ...this.setFragColor(), ...this.getColorAsFloat() }; } - getCustomTypes(): {[name: string]: string} { + getCustomTypes(): { [name: string]: string } { return {}; } - protected setFragColor(): {[name: string]: GlslLibRoutine} { + protected setFragColor(): { [name: string]: GlslLibRoutine } { const glsl = getGlsl(this.context.glContext.version); return { setFragColor: new GlslLibRoutine( - ` + ` void setFragColor(float value) { ${glsl.output} = encode(value); } `, - ['encoding.encode']) + ['encoding.encode'], + ), }; } - protected getColorAsFloat(): {[name: string]: GlslLibRoutine} { + protected getColorAsFloat(): { [name: string]: GlslLibRoutine } { return { getColorAsFloat: new GlslLibRoutine( - ` + ` float getColorAsFloat(vec4 color) { return decode(color); } `, - ['encoding.decode']) + ['encoding.decode'], + ), }; } } diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-function-inliner.ts b/js/web/lib/onnxjs/backends/webgl/glsl-function-inliner.ts index 7e371700e4303..20ace4fbe515c 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-function-inliner.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-function-inliner.ts @@ -7,20 +7,20 @@ const FUNC_CALL_REGEX = '(\\w+)?\\s+([_0-9a-zA-Z]+)\\s+=\\s+__FUNC__\\((.*)\\)\\ * GLSL preprocessor responsible for resolving @inline directives */ export function replaceInlines(script: string): string { - const inlineDefs: {[name: string]: {params: Array<{type: string; name: string}|null>; body: string}} = {}; + const inlineDefs: { [name: string]: { params: Array<{ type: string; name: string } | null>; body: string } } = {}; let match; while ((match = INLINE_FUNC_DEF_REGEX.exec(script)) !== null) { const params = match[3] - .split(',') - .map(s => { - 
const tokens = s.trim().split(' '); - if (tokens && tokens.length === 2) { - return {type: tokens[0], name: tokens[1]}; - } - return null; - }) - .filter(v => v !== null); - inlineDefs[match[2]] = {params, body: match[4]}; + .split(',') + .map((s) => { + const tokens = s.trim().split(' '); + if (tokens && tokens.length === 2) { + return { type: tokens[0], name: tokens[1] }; + } + return null; + }) + .filter((v) => v !== null); + inlineDefs[match[2]] = { params, body: match[4] }; } for (const name in inlineDefs) { const regexString = FUNC_CALL_REGEX.replace('__FUNC__', name); @@ -29,7 +29,7 @@ export function replaceInlines(script: string): string { const type = match[1]; const variable = match[2]; const params = match[3].split(','); - const declLine = (type) ? `${type} ${variable};` : ''; + const declLine = type ? `${type} ${variable};` : ''; let newBody: string = inlineDefs[name].body; let paramRedecLine = ''; inlineDefs[name].params.forEach((v, i) => { diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-preprocessor.ts b/js/web/lib/onnxjs/backends/webgl/glsl-preprocessor.ts index c65118bb57df7..1fa390350d2a2 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-preprocessor.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-preprocessor.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {GlslContext, GlslLib, GlslLibRoutineNode, TopologicalSortGlslRoutines} from './glsl-definitions'; -import {replaceInlines} from './glsl-function-inliner'; -import {glslRegistry} from './glsl-registered-libs'; -import {getDefaultFragShaderMain, getFragShaderPreamble} from './glsl-source'; -import {ProgramInfo, TextureLayout, VariableInfo} from './types'; -import {WebGLContext} from './webgl-context'; +import { GlslContext, GlslLib, GlslLibRoutineNode, TopologicalSortGlslRoutines } from './glsl-definitions'; +import { replaceInlines } from './glsl-function-inliner'; +import { glslRegistry } from './glsl-registered-libs'; +import { getDefaultFragShaderMain, getFragShaderPreamble } from './glsl-source'; +import { ProgramInfo, TextureLayout, VariableInfo } from './types'; +import { WebGLContext } from './webgl-context'; /** * Preprocessor for the additions to the GLSL language @@ -18,12 +18,15 @@ import {WebGLContext} from './webgl-context'; */ export class GlslPreprocessor { readonly context: GlslContext; - readonly libs: {[name: string]: GlslLib} = {}; - readonly glslLibRoutineDependencyGraph: {[routineName: string]: GlslLibRoutineNode} = {}; + readonly libs: { [name: string]: GlslLib } = {}; + readonly glslLibRoutineDependencyGraph: { [routineName: string]: GlslLibRoutineNode } = {}; constructor( - glContext: WebGLContext, programInfo: ProgramInfo, inputTextureLayouts: TextureLayout[], - outputTextureLayout: TextureLayout) { + glContext: WebGLContext, + programInfo: ProgramInfo, + inputTextureLayouts: TextureLayout[], + outputTextureLayout: TextureLayout, + ) { this.context = new GlslContext(glContext, programInfo, inputTextureLayouts, outputTextureLayout); // construct GlslLibs @@ -103,7 +106,7 @@ export class GlslPreprocessor { private selectGlslLibRoutinesToBeIncluded(script: string): GlslLibRoutineNode[] { const nodes: GlslLibRoutineNode[] = []; - Object.keys(this.glslLibRoutineDependencyGraph).forEach(classAndRoutine => { + Object.keys(this.glslLibRoutineDependencyGraph).forEach((classAndRoutine) => { const routine = classAndRoutine.split('.')[1]; if (script.indexOf(routine) !== -1) { 
nodes.push(this.glslLibRoutineDependencyGraph[classAndRoutine]); @@ -123,7 +126,8 @@ export class GlslPreprocessor { if (variables) { for (const variable of variables) { uniformLines.push( - `uniform ${variable.type} ${variable.name}${variable.arrayLength ? `[${variable.arrayLength}]` : ''};`); + `uniform ${variable.type} ${variable.name}${variable.arrayLength ? `[${variable.arrayLength}]` : ''};`, + ); } } return uniformLines.join('\n'); diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-registered-libs.ts b/js/web/lib/onnxjs/backends/webgl/glsl-registered-libs.ts index 5556a9a58d6ab..e58aaaf112624 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-registered-libs.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-registered-libs.ts @@ -1,18 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {CoordsGlslLib} from './glsl-coordinate-lib'; -import {GlslContext, GlslLib} from './glsl-definitions'; -import {EncodingGlslLib} from './glsl-encoding-lib'; -import {FragColorGlslLib} from './glsl-fragcolor-lib'; -import {ShapeUtilsGlslLib} from './glsl-shape-utils-lib'; -import {VecGlslLib} from './glsl-vec-lib'; +import { CoordsGlslLib } from './glsl-coordinate-lib'; +import { GlslContext, GlslLib } from './glsl-definitions'; +import { EncodingGlslLib } from './glsl-encoding-lib'; +import { FragColorGlslLib } from './glsl-fragcolor-lib'; +import { ShapeUtilsGlslLib } from './glsl-shape-utils-lib'; +import { VecGlslLib } from './glsl-vec-lib'; -export const glslRegistry: {[name: string]: new (context: GlslContext) => GlslLib} = { - 'encoding': EncodingGlslLib, - 'fragcolor': FragColorGlslLib, - 'vec': VecGlslLib, - 'shapeUtils': ShapeUtilsGlslLib, - 'coordinates': CoordsGlslLib, +export const glslRegistry: { [name: string]: new (context: GlslContext) => GlslLib } = { + encoding: EncodingGlslLib, + fragcolor: FragColorGlslLib, + vec: VecGlslLib, + shapeUtils: ShapeUtilsGlslLib, + coordinates: CoordsGlslLib, // 'arrays': ArrayGlslSLib }; diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-shape-utils-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-shape-utils-lib.ts index 779ab64de6ee9..05fe49e13009e 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-shape-utils-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-shape-utils-lib.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
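
// Illustrative sketch (not part of this change): the encodeUint8/decodeUint8 routines in
// glsl-encoding-lib.ts above pack a float32 into four bytes (one per RGBA channel) and back,
// with an optional endianness swizzle. A CPU-side analogue using typed arrays, for reference:
function floatToBytes(value: number): Uint8Array {
  const buffer = new ArrayBuffer(4);
  new DataView(buffer).setFloat32(0, value, true); // little-endian byte order
  return new Uint8Array(buffer);
}

function bytesToFloat(bytes: Uint8Array): number {
  const buffer = new ArrayBuffer(4);
  new Uint8Array(buffer).set(bytes.subarray(0, 4));
  return new DataView(buffer).getFloat32(0, true);
}
// bytesToFloat(floatToBytes(x)) returns x rounded to float32 precision.
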
-import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions'; +import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions'; /** * GLSL Library responsible for data types and routines for manipulating @@ -11,21 +11,21 @@ export class ShapeUtilsGlslLib extends GlslLib { constructor(context: GlslContext) { super(context); } - getFunctions(): {[name: string]: GlslLibRoutine} { + getFunctions(): { [name: string]: GlslLibRoutine } { return { ...this.bcastIndex(), ...this.bcastMatmulIndex(), ...this.offsetToIndices(), ...this.indicesToOffset(), - ...this.incrementIndices() + ...this.incrementIndices(), }; } getCustomTypes() { return {}; } - protected bcastIndex(): {[name: string]: GlslLibRoutine} { + protected bcastIndex(): { [name: string]: GlslLibRoutine } { const outputRank = this.context.outputTextureLayout.shape.length; - const result: {[name: string]: GlslLibRoutine} = {}; + const result: { [name: string]: GlslLibRoutine } = {}; this.context.programInfo.inputNames.forEach((name, i) => { const shape = this.context.inputTextureLayouts[i].unpackedShape; if (shape.length <= outputRank) { @@ -48,9 +48,9 @@ export class ShapeUtilsGlslLib extends GlslLib { }); return result; } - protected bcastMatmulIndex(): {[name: string]: GlslLibRoutine} { + protected bcastMatmulIndex(): { [name: string]: GlslLibRoutine } { const outputRank = this.context.outputTextureLayout.shape.length; - const result: {[name: string]: GlslLibRoutine} = {}; + const result: { [name: string]: GlslLibRoutine } = {}; this.context.programInfo.inputNames.forEach((name, i) => { const shape = this.context.inputTextureLayouts[i].shape; if (!(shape.length < 2 || shape.length > outputRank)) { @@ -75,8 +75,8 @@ export class ShapeUtilsGlslLib extends GlslLib { }); return result; } - protected indicesToOffset(): {[name: string]: GlslLibRoutine} { - const result: {[name: string]: GlslLibRoutine} = {}; + protected indicesToOffset(): { [name: string]: GlslLibRoutine } { + const result: { [name: string]: GlslLibRoutine } = {}; this.context.programInfo.inputNames.forEach((name, i) => { const shape = this.context.inputTextureLayouts[i].shape; const strides = this.context.inputTextureLayouts[i].strides; @@ -84,8 +84,9 @@ export class ShapeUtilsGlslLib extends GlslLib { let funcName = `indicesToOffset_${name}`; result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides)); funcName = `indicesToOffset_${name}_T`; - result[funcName] = - new GlslLibRoutine(ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides.slice().reverse())); + result[funcName] = new GlslLibRoutine( + ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides.slice().reverse()), + ); }); return result; } @@ -104,8 +105,8 @@ export class ShapeUtilsGlslLib extends GlslLib { } `; } - protected offsetToIndices(): {[name: string]: GlslLibRoutine} { - const result: {[name: string]: GlslLibRoutine} = {}; + protected offsetToIndices(): { [name: string]: GlslLibRoutine } { + const result: { [name: string]: GlslLibRoutine } = {}; this.context.programInfo.inputNames.forEach((name, i) => { const shape = this.context.inputTextureLayouts[i].shape; const strides = this.context.inputTextureLayouts[i].strides; @@ -113,8 +114,9 @@ export class ShapeUtilsGlslLib extends GlslLib { let funcName = `offsetToIndices_${name}`; result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides)); funcName = `offsetToIndices_${name}_T`; - result[funcName] = - new 
GlslLibRoutine(ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides.slice().reverse())); + result[funcName] = new GlslLibRoutine( + ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides.slice().reverse()), + ); }); return result; } @@ -134,8 +136,8 @@ export class ShapeUtilsGlslLib extends GlslLib { } `; } - protected incrementIndices(): {[name: string]: GlslLibRoutine} { - const result: {[name: string]: GlslLibRoutine} = {}; + protected incrementIndices(): { [name: string]: GlslLibRoutine } { + const result: { [name: string]: GlslLibRoutine } = {}; this.context.programInfo.inputNames.forEach((name, i) => { const shape = this.context.inputTextureLayouts[i].shape; const rank = shape.length; diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-source.ts b/js/web/lib/onnxjs/backends/webgl/glsl-source.ts index a6cb2e503dc05..6759f39fa7f07 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-source.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-source.ts @@ -33,11 +33,11 @@ const GLSL_ES_3_0: Glsl = { outputDeclaration: 'out vec4 outputColor;', }; -export function getGlsl(version: 1|2) { +export function getGlsl(version: 1 | 2) { return version === 1 ? GLSL_ES_2_0 : GLSL_ES_3_0; } -export function getVertexShaderSource(version: 1|2): string { +export function getVertexShaderSource(version: 1 | 2): string { const glsl = getGlsl(version); return `${glsl.version} precision highp float; @@ -53,7 +53,7 @@ export function getVertexShaderSource(version: 1|2): string { }`; } -export function getFragShaderPreamble(version: 1|2): string { +export function getFragShaderPreamble(version: 1 | 2): string { const glsl = getGlsl(version); return `${glsl.version} precision highp float; @@ -90,7 +90,7 @@ export function getFragShaderPreamble(version: 1|2): string { `; } -export function getDefaultFragShaderMain(version: 1|2, outputShapeLength: number): string { +export function getDefaultFragShaderMain(version: 1 | 2, outputShapeLength: number): string { const glsl = getGlsl(version); return ` void main() { diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts index eb7c1c080ee9b..7b1ba915e7c10 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
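
// Illustrative sketch (not part of this change): the indicesToOffset_* routines generated in
// glsl-shape-utils-lib.ts above compute a flat offset from an N-D index using the input's
// strides; the "_T" variants use the strides reversed. A CPU-side counterpart:
function indicesToOffset(indices: readonly number[], strides: readonly number[]): number {
  let offset = 0;
  for (let i = 0; i < indices.length; i++) {
    offset += indices[i] * strides[i];
  }
  return offset;
}
// Example: indices [1, 2] with strides [3, 1] (shape [2, 3]) give offset 1 * 3 + 2 * 1 = 5.
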
-import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions'; +import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions'; /** * GLSL Library responsible for vec routines @@ -12,17 +12,17 @@ export class VecGlslLib extends GlslLib { constructor(context: GlslContext) { super(context); } - getCustomTypes(): {[name: string]: string} { + getCustomTypes(): { [name: string]: string } { return {}; } - getFunctions(): {[name: string]: GlslLibRoutine} { - return {...this.binaryVecFunctions(), ...this.copyVec(), ...this.setVecItem(), ...this.getVecItem()}; + getFunctions(): { [name: string]: GlslLibRoutine } { + return { ...this.binaryVecFunctions(), ...this.copyVec(), ...this.setVecItem(), ...this.getVecItem() }; } - protected binaryVecFunctions(): {[name: string]: GlslLibRoutine} { + protected binaryVecFunctions(): { [name: string]: GlslLibRoutine } { const outputLayout = this.context.outputTextureLayout; const rank = outputLayout.shape.length; - const nameOp: {[name: string]: string} = {add: '+=', sub: '-=', mul: '*=', div: '/='}; - const result: {[name: string]: GlslLibRoutine} = {}; + const nameOp: { [name: string]: string } = { add: '+=', sub: '-=', mul: '*=', div: '/=' }; + const result: { [name: string]: GlslLibRoutine } = {}; for (const name in nameOp) { const fname = `${name}Vec`; let assignmentBlock = ''; @@ -41,7 +41,7 @@ export class VecGlslLib extends GlslLib { return result; } - protected copyVec(): {[name: string]: GlslLibRoutine} { + protected copyVec(): { [name: string]: GlslLibRoutine } { const outputLayout = this.context.outputTextureLayout; const rank = outputLayout.shape.length; let assignmentBlock = ''; @@ -55,10 +55,10 @@ export class VecGlslLib extends GlslLib { ${assignmentBlock} } `; - return {copyVec: new GlslLibRoutine(body)}; + return { copyVec: new GlslLibRoutine(body) }; } - protected setVecItem(): {[name: string]: GlslLibRoutine} { + protected setVecItem(): { [name: string]: GlslLibRoutine } { const outputLayout = this.context.outputTextureLayout; const rank = outputLayout.shape.length; let block = ` @@ -82,9 +82,9 @@ export class VecGlslLib extends GlslLib { ${block} } `; - return {setVecItem: new GlslLibRoutine(body)}; + return { setVecItem: new GlslLibRoutine(body) }; } - protected getVecItem(): {[name: string]: GlslLibRoutine} { + protected getVecItem(): { [name: string]: GlslLibRoutine } { const outputLayout = this.context.outputTextureLayout; const rank = outputLayout.shape.length; let block = ` @@ -108,6 +108,6 @@ export class VecGlslLib extends GlslLib { ${block} } `; - return {getVecItem: new GlslLibRoutine(body)}; + return { getVecItem: new GlslLibRoutine(body) }; } } diff --git a/js/web/lib/onnxjs/backends/webgl/inference-handler.ts b/js/web/lib/onnxjs/backends/webgl/inference-handler.ts index 0a51ff7c4029e..678ffa19275e9 100644 --- a/js/web/lib/onnxjs/backends/webgl/inference-handler.ts +++ b/js/web/lib/onnxjs/backends/webgl/inference-handler.ts @@ -1,32 +1,38 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
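
// Illustrative sketch (not part of this change): getProgramInfoUniqueKey below builds the
// cache key for compiled shader artifacts from the program name, an optional cacheHint, and
// each input's unpacked shape plus texture dimensions. Type and example values here are
// illustrative, not taken from this diff:
interface TextureDataLike {
  unpackedShape: readonly number[];
  width: number;
  height: number;
}

function programCacheKey(name: string, cacheHint: string | undefined, inputs: readonly TextureDataLike[]): string {
  const inputPart = inputs.map((t) => `${t.unpackedShape.join(',')};${t.width}x${t.height}`).join('_');
  return `${name}${cacheHint ? `[${cacheHint}]` : ''}:${inputPart}`;
}
// programCacheKey('Add', undefined, [{ unpackedShape: [2, 3], width: 3, height: 2 }]) === 'Add:2,3;3x2'
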
-import {InferenceHandler} from '../../backend'; -import {Logger} from '../../instrument'; -import {Tensor} from '../../tensor'; -import {ShapeUtil} from '../../util'; - -import {createPackProgramInfoLoader} from './ops/pack'; -import {createPackedReshape3DProgramInfoLoader, isReshapeCheap, processDims3D} from './ops/reshape-packed'; -import {encodeAsUint8} from './ops/uint8-encode'; -import {createUnpackProgramInfoLoader} from './ops/unpack'; -import {WebGLSessionHandler} from './session-handler'; -import {EncoderUsage} from './texture-data-encoder'; -import {calculateTextureWidthAndHeight, createTextureLayoutFromShape, createTextureLayoutFromTextureType} from './texture-layout'; -import {Artifact, ProgramInfo, ProgramInfoLoader, TextureData, TextureLayout, TextureType} from './types'; - -const getProgramInfoUniqueKey = - (programInfo: ProgramInfo|ProgramInfoLoader, inputTextureDatas: TextureData[]): string => { - const inputs = - inputTextureDatas.map(texture => `${texture.unpackedShape.join(',')};${texture.width}x${texture.height}`) - .join('_'); - let key = programInfo.name; - if (programInfo.cacheHint) { - key += '[' + programInfo.cacheHint + ']'; - } - key += ':' + inputs; - return key; - }; +import { InferenceHandler } from '../../backend'; +import { Logger } from '../../instrument'; +import { Tensor } from '../../tensor'; +import { ShapeUtil } from '../../util'; + +import { createPackProgramInfoLoader } from './ops/pack'; +import { createPackedReshape3DProgramInfoLoader, isReshapeCheap, processDims3D } from './ops/reshape-packed'; +import { encodeAsUint8 } from './ops/uint8-encode'; +import { createUnpackProgramInfoLoader } from './ops/unpack'; +import { WebGLSessionHandler } from './session-handler'; +import { EncoderUsage } from './texture-data-encoder'; +import { + calculateTextureWidthAndHeight, + createTextureLayoutFromShape, + createTextureLayoutFromTextureType, +} from './texture-layout'; +import { Artifact, ProgramInfo, ProgramInfoLoader, TextureData, TextureLayout, TextureType } from './types'; + +const getProgramInfoUniqueKey = ( + programInfo: ProgramInfo | ProgramInfoLoader, + inputTextureDatas: TextureData[], +): string => { + const inputs = inputTextureDatas + .map((texture) => `${texture.unpackedShape.join(',')};${texture.width}x${texture.height}`) + .join('_'); + let key = programInfo.name; + if (programInfo.cacheHint) { + key += '[' + programInfo.cacheHint + ']'; + } + key += ':' + inputs; + return key; +}; export class WebGLInferenceHandler implements InferenceHandler { private packedTextureDataCache: Map; @@ -43,7 +49,7 @@ export class WebGLInferenceHandler implements InferenceHandler { return calculateTextureWidthAndHeight(this.session.layoutStrategy, shape, textureType); } - executeProgram(program: ProgramInfo|ProgramInfoLoader, inputs: readonly Tensor[]): TextureData { + executeProgram(program: ProgramInfo | ProgramInfoLoader, inputs: readonly Tensor[]): TextureData { if (inputs.length < program.inputNames.length) { throw new Error(`Input size mustn't be less than ${program.inputNames.length}.`); } @@ -59,14 +65,18 @@ export class WebGLInferenceHandler implements InferenceHandler { const key = getProgramInfoUniqueKey(program, inputTextureDatas); let artifact = this.session.programManager.getArtifact(key); - const programInfo = artifact ? - artifact.programInfo : - (typeof (program as ProgramInfoLoader).get === 'function' ? (program as ProgramInfoLoader).get() : - (program as ProgramInfo)); + const programInfo = artifact + ? 
artifact.programInfo + : typeof (program as ProgramInfoLoader).get === 'function' + ? (program as ProgramInfoLoader).get() + : (program as ProgramInfo); // create texture info for output const outputTextureLayout = createTextureLayoutFromTextureType( - this.session.layoutStrategy, programInfo.output.dims, programInfo.output.textureType); + this.session.layoutStrategy, + programInfo.output.dims, + programInfo.output.textureType, + ); const outputTextureData = this.createTextureData(outputTextureLayout, programInfo.output.type); if (!artifact) { @@ -141,18 +151,21 @@ export class WebGLInferenceHandler implements InferenceHandler { // 3. run the program before dotProduct. // const adjustedKernelShape = [shape[0], Math.ceil((shape[1] * shape[2] * shape[3]) / channels)]; - const adjustedLayout = - createTextureLayoutFromTextureType(this.session.layoutStrategy, adjustedKernelShape, textureType); + const adjustedLayout = createTextureLayoutFromTextureType( + this.session.layoutStrategy, + adjustedKernelShape, + textureType, + ); let buffer = tensor.numberData; - if (shape[1] * shape[2] * shape[3] % channels !== 0) { + if ((shape[1] * shape[2] * shape[3]) % channels !== 0) { const numFeatureMaps = shape[0]; const oldRowSize = shape[1] * shape[2] * shape[3]; - const newRowSize = Math.ceil(oldRowSize * group / channels) * channels; + const newRowSize = Math.ceil((oldRowSize * group) / channels) * channels; const newSize = numFeatureMaps * newRowSize; buffer = new Float32Array(newSize); for (let f = 0; f < numFeatureMaps; ++f) { const oldOffset = f * oldRowSize; - const newOffset = f * newRowSize + f % group * oldRowSize; + const newOffset = f * newRowSize + (f % group) * oldRowSize; buffer.set(tensor.numberData.subarray(oldOffset, oldOffset + oldRowSize), newOffset); } } @@ -161,10 +174,16 @@ export class WebGLInferenceHandler implements InferenceHandler { } if (textureType === TextureType.packed) { - const unpackedTextureLayout = - createTextureLayoutFromShape(this.session.layoutStrategy, tensor.dims, 1, [], {reverseWH: true}); + const unpackedTextureLayout = createTextureLayoutFromShape(this.session.layoutStrategy, tensor.dims, 1, [], { + reverseWH: true, + }); const unpackedTextureData = this.createTextureData( - unpackedTextureLayout, tensor.type, tensor.numberData, tensor, EncoderUsage.UploadOnly); + unpackedTextureLayout, + tensor.type, + tensor.numberData, + tensor, + EncoderUsage.UploadOnly, + ); td = this.pack(unpackedTextureData); } else { td = this.createTextureData(layout, tensor.type, tensor.numberData, tensor, EncoderUsage.UploadOnly); @@ -183,13 +202,21 @@ export class WebGLInferenceHandler implements InferenceHandler { * @param tensor the tensor to bind. tensor's data is ignored. 
*/ createTextureDataFromLayoutBindTensor( - layout: TextureLayout, dataType: Tensor.DataType, data: Tensor.NumberType, tensor: Tensor): TextureData { + layout: TextureLayout, + dataType: Tensor.DataType, + data: Tensor.NumberType, + tensor: Tensor, + ): TextureData { return this.createTextureData(layout, dataType, data, tensor, EncoderUsage.UploadOnly); } private createTextureData( - layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType, tensor?: Tensor, - usage?: EncoderUsage): TextureData { + layout: TextureLayout, + dataType: Tensor.DataType, + data?: Tensor.NumberType, + tensor?: Tensor, + usage?: EncoderUsage, + ): TextureData { Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`); const texture = this.session.textureManager.createTextureFromLayout(dataType, layout, data, usage); return this.createTextureDataFromTexture(layout, dataType, texture, tensor); @@ -223,7 +250,7 @@ export class WebGLInferenceHandler implements InferenceHandler { shape: reshapedDims.length !== 0 ? reshapedDims : [1], strides: ShapeUtil.computeStrides(reshapedDims), unpackedShape: reshapedDims, - isPacked: true + isPacked: true, }; const newTextureData = this.createTextureDataFromTexture(newTextureLayout, input.type, inputTD.texture); return newTextureData.tensor; @@ -234,7 +261,9 @@ export class WebGLInferenceHandler implements InferenceHandler { const squeezedInputTensor = this.reshapePacked(input, squeezedInputShape); const squeezedOutputTensor = this.run( - createPackedReshape3DProgramInfoLoader(this, squeezedInputTensor, squeezedOutputShape), [squeezedInputTensor]); + createPackedReshape3DProgramInfoLoader(this, squeezedInputTensor, squeezedOutputShape), + [squeezedInputTensor], + ); const outputTensor = this.reshapePacked(squeezedOutputTensor, reshapedDims); return outputTensor; } @@ -246,23 +275,36 @@ export class WebGLInferenceHandler implements InferenceHandler { } private createTextureDataFromTexture( - layout: TextureLayout, dataType: Tensor.DataType, texture: WebGLTexture, tensor?: Tensor, tensorId?: Tensor.Id) { + layout: TextureLayout, + dataType: Tensor.DataType, + texture: WebGLTexture, + tensor?: Tensor, + tensorId?: Tensor.Id, + ) { const textureData: TextureData = { ...layout, - tensor: tensor || - new Tensor( - layout.unpackedShape, dataType, (_id: Tensor.Id) => this.readTexture(textureData), - async (_id: Tensor.Id) => this.readTextureAsync(textureData), undefined, tensorId), - texture + tensor: + tensor || + new Tensor( + layout.unpackedShape, + dataType, + (_id: Tensor.Id) => this.readTexture(textureData), + async (_id: Tensor.Id) => this.readTextureAsync(textureData), + undefined, + tensorId, + ), + texture, }; this.setTextureData(textureData.tensor.dataId, textureData, layout.isPacked); return textureData; } - private getTextureData(tensorId: Tensor.Id, isPacked = false): TextureData|undefined { - return this.session.isInitializer(tensorId) ? this.session.getTextureData(tensorId, isPacked) : - isPacked ? this.packedTextureDataCache.get(tensorId) : - this.unpackedTextureDataCache.get(tensorId); + private getTextureData(tensorId: Tensor.Id, isPacked = false): TextureData | undefined { + return this.session.isInitializer(tensorId) + ? this.session.getTextureData(tensorId, isPacked) + : isPacked + ? 
this.packedTextureDataCache.get(tensorId) + : this.unpackedTextureDataCache.get(tensorId); } setTextureData(tensorId: Tensor.Id, td: TextureData, isPacked = false): void { if (this.session.isInitializer(tensorId)) { @@ -277,9 +319,9 @@ export class WebGLInferenceHandler implements InferenceHandler { dispose(): void { this.session.textureManager.clearActiveTextures(); - this.packedTextureDataCache.forEach(td => this.session.textureManager.releaseTexture(td)); + this.packedTextureDataCache.forEach((td) => this.session.textureManager.releaseTexture(td)); this.packedTextureDataCache = new Map(); - this.unpackedTextureDataCache.forEach(td => this.session.textureManager.releaseTexture(td)); + this.unpackedTextureDataCache.forEach((td) => this.session.textureManager.releaseTexture(td)); this.unpackedTextureDataCache = new Map(); } diff --git a/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts index ec2a0ccc43b07..6872e2800508e 100644 --- a/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts @@ -1,38 +1,55 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {OpSet} from '../../opset'; +import { OpSet } from '../../opset'; -import {batchNormalization, parseBatchNormalizationAttributes} from './ops/batch-normalization'; +import { batchNormalization, parseBatchNormalizationAttributes } from './ops/batch-normalization'; import * as binaryOps from './ops/binary-op'; -import {cast, parseCastAttributes} from './ops/cast'; -import {concat, parseConcatAttributes} from './ops/concat'; -import {conv, parseConvAttributes} from './ops/conv'; -import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; -import {depthToSpace, parseDepthToSpaceAttributes} from './ops/depth-to-space'; -import {flatten, parseFlattenAttributes} from './ops/flatten'; -import {gather, parseGatherAttributes} from './ops/gather'; -import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm'; -import {imageScaler, parseImageScalerAttributes} from './ops/image-scaler'; -import {instanceNormalization, parseInstanceNormalizationAttributes} from './ops/instance-normalization'; -import {lrn, parseLrnAttributes} from './ops/lrn'; -import {matMul, parseMatMulAttributes} from './ops/matmul'; -import {padV11, padV2, parsePadAttributesV11, parsePadAttributesV2} from './ops/pad'; -import {averagePool, globalAveragePool, globalMaxPool, maxPool, parseAveragePoolAttributes, parseGlobalAveragePoolAttributes, parseMaxPoolAttributes} from './ops/pool'; -import {parseReduceAttributes, reduceLogSum, reduceLogSumSquare, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum} from './ops/reduce'; -import {reshape} from './ops/reshape'; -import {parseResizeAttributesV10, parseResizeAttributesV11, resize} from './ops/resize-packed'; -import {shape} from './ops/shape'; -import {parseSliceAttributes, slice, sliceV10} from './ops/slice'; -import {parseSoftmaxAttributes, parseSoftmaxAttributesV13, softmax, softmaxV13} from './ops/softmax'; -import {parseSplitAttributes, split} from './ops/split'; -import {parseSqueezeAttributes, squeeze, squeezeV13} from './ops/squeeze'; -import {sum} from './ops/sum'; -import {tile} from './ops/tile'; -import {parseTransposeAttributes, transpose} from './ops/transpose'; +import { cast, parseCastAttributes } from './ops/cast'; +import { concat, parseConcatAttributes } from './ops/concat'; +import { conv, 
parseConvAttributes } from './ops/conv'; +import { convTranspose, parseConvTransposeAttributes } from './ops/conv-transpose'; +import { depthToSpace, parseDepthToSpaceAttributes } from './ops/depth-to-space'; +import { flatten, parseFlattenAttributes } from './ops/flatten'; +import { gather, parseGatherAttributes } from './ops/gather'; +import { gemm, parseGemmAttributesV11, parseGemmAttributesV7 } from './ops/gemm'; +import { imageScaler, parseImageScalerAttributes } from './ops/image-scaler'; +import { instanceNormalization, parseInstanceNormalizationAttributes } from './ops/instance-normalization'; +import { lrn, parseLrnAttributes } from './ops/lrn'; +import { matMul, parseMatMulAttributes } from './ops/matmul'; +import { padV11, padV2, parsePadAttributesV11, parsePadAttributesV2 } from './ops/pad'; +import { + averagePool, + globalAveragePool, + globalMaxPool, + maxPool, + parseAveragePoolAttributes, + parseGlobalAveragePoolAttributes, + parseMaxPoolAttributes, +} from './ops/pool'; +import { + parseReduceAttributes, + reduceLogSum, + reduceLogSumSquare, + reduceMax, + reduceMean, + reduceMin, + reduceProd, + reduceSum, +} from './ops/reduce'; +import { reshape } from './ops/reshape'; +import { parseResizeAttributesV10, parseResizeAttributesV11, resize } from './ops/resize-packed'; +import { shape } from './ops/shape'; +import { parseSliceAttributes, slice, sliceV10 } from './ops/slice'; +import { parseSoftmaxAttributes, parseSoftmaxAttributesV13, softmax, softmaxV13 } from './ops/softmax'; +import { parseSplitAttributes, split } from './ops/split'; +import { parseSqueezeAttributes, squeeze, squeezeV13 } from './ops/squeeze'; +import { sum } from './ops/sum'; +import { tile } from './ops/tile'; +import { parseTransposeAttributes, transpose } from './ops/transpose'; import * as unaryOps from './ops/unary-op'; -import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze'; -import {parseUpsampleAttributesV7, parseUpsampleAttributesV9, upsample} from './ops/upsample'; +import { parseUnsqueezeAttributes, unsqueeze, unsqueezeV13 } from './ops/unsqueeze'; +import { parseUpsampleAttributesV7, parseUpsampleAttributesV9, upsample } from './ops/upsample'; export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Abs', '', '6+', unaryOps.abs], @@ -99,7 +116,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Shape', '', '1+', shape], ['Sigmoid', '', '6+', unaryOps.sigmoid], ['Sin', '', '7+', unaryOps.sin], - ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10 + ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10 ['Slice', '', '1-9', slice, parseSliceAttributes], // The "semantic" meaning of axis has changed in opset-13. ['Softmax', '', '1-12', softmax, parseSoftmaxAttributes], diff --git a/js/web/lib/onnxjs/backends/webgl/ops/batch-normalization.ts b/js/web/lib/onnxjs/backends/webgl/ops/batch-normalization.ts index a2013dba27e27..ee7b04920d4e0 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/batch-normalization.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/batch-normalization.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
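
// Illustrative sketch (not part of this change): each tuple in WEBGL_OP_RESOLVE_RULES above
// pairs an op type and domain with an opset version range such as '6+', '10+' or '1-9' and the
// kernel that implements it. The actual range matching lives in the OpSet code, which this
// diff does not touch; the sketch below only shows how such range strings can be interpreted:
function versionMatches(range: string, opsetVersion: number): boolean {
  if (range.endsWith('+')) {
    return opsetVersion >= Number(range.slice(0, -1)); // e.g. '6+' means opset 6 or later
  }
  const parts = range.split('-').map(Number); // e.g. '1-9' means opsets 1 through 9
  return parts.length === 2 ? opsetVersion >= parts[0] && opsetVersion <= parts[1] : opsetVersion === parts[0];
}
// versionMatches('6+', 13) === true; versionMatches('1-9', 13) === false
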
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; export interface BatchNormalizationAttributes extends AttributeWithCacheKey { epsilon: number; @@ -18,39 +18,53 @@ export interface BatchNormalizationAttributes extends AttributeWithCacheKey { const batchNormalizationProgramMetadata = { name: 'BatchNormalization', inputNames: ['A', 'Scale', 'B', 'Mean', 'Variance'], - inputTypes: - [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] + inputTypes: [ + TextureType.unpacked, + TextureType.unpacked, + TextureType.unpacked, + TextureType.unpacked, + TextureType.unpacked, + ], }; -export const batchNormalization: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: BatchNormalizationAttributes): Tensor[] => { - validateInputs(inputs); - const output = inferenceHandler.run( - { - ...batchNormalizationProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createBatchNormalizationProgramInfo(inferenceHandler, inputs, attributes) - }, - inputs); - return [output]; - }; +export const batchNormalization: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: BatchNormalizationAttributes, +): Tensor[] => { + validateInputs(inputs); + const output = inferenceHandler.run( + { + ...batchNormalizationProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createBatchNormalizationProgramInfo(inferenceHandler, inputs, attributes), + }, + inputs, + ); + return [output]; +}; -export const parseBatchNormalizationAttributes: OperatorInitialization = - (node: Graph.Node): BatchNormalizationAttributes => { - const epsilon = node.attributes.getFloat('epsilon', 1e-5); - const momentum = node.attributes.getFloat('momentum', 0.9); - const spatial = node.attributes.getInt('spatial', 1); - return createAttributeWithCacheKey({epsilon, momentum, spatial}); - }; +export const parseBatchNormalizationAttributes: OperatorInitialization = ( + node: Graph.Node, +): BatchNormalizationAttributes => { + const epsilon = node.attributes.getFloat('epsilon', 1e-5); + const momentum = node.attributes.getFloat('momentum', 0.9); + const spatial = node.attributes.getInt('spatial', 1); + return createAttributeWithCacheKey({ epsilon, momentum, spatial }); +}; -const createBatchNormalizationProgramInfo = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: BatchNormalizationAttributes): - ProgramInfo => { - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const rank = inputs[0].dims.length; - const [scaleWidth, scaleHeight] = - inferenceHandler.calculateTextureWidthAndHeight(inputs[1].dims, TextureType.unpacked); - const shaderSource = ` 
+const createBatchNormalizationProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: BatchNormalizationAttributes, +): ProgramInfo => { + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const rank = inputs[0].dims.length; + const [scaleWidth, scaleHeight] = inferenceHandler.calculateTextureWidthAndHeight( + inputs[1].dims, + TextureType.unpacked, + ); + const shaderSource = ` float process(int[${rank}] indices) { vec2 position = offsetToCoords(indices[1], ${scaleWidth}, ${scaleHeight}); float scale = getColorAsFloat(${glsl.texture2D}(Scale, position)); @@ -60,12 +74,12 @@ const createBatchNormalizationProgramInfo = return scale * ( (_A(indices) - mean) / sqrt(variance + float(${attributes.epsilon})) ) + b; }`; - return { - ...batchNormalizationProgramMetadata, - output: {dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...batchNormalizationProgramMetadata, + output: { dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 5) { @@ -80,17 +94,30 @@ const validateInputs = (inputs: Tensor[]): void => { // input should atleast have three dimensions - N,C,dim1,...,dimn // other inputs can have only one dimensions - if (X.dims.length < 3 || scale.dims.length !== 1 || B.dims.length !== 1 || mean.dims.length !== 1 || - var_.dims.length !== 1) { + if ( + X.dims.length < 3 || + scale.dims.length !== 1 || + B.dims.length !== 1 || + mean.dims.length !== 1 || + var_.dims.length !== 1 + ) { throw new Error('invalid input shape.'); } - if (scale.dims[0] !== X.dims[1] || B.dims[0] !== X.dims[1] || mean.dims[0] !== X.dims[1] || - var_.dims[0] !== X.dims[1]) { + if ( + scale.dims[0] !== X.dims[1] || + B.dims[0] !== X.dims[1] || + mean.dims[0] !== X.dims[1] || + var_.dims[0] !== X.dims[1] + ) { throw new Error('invalid input shape.'); } - if ((X.type !== 'float32' && X.type !== 'float64') || (scale.type !== 'float32' && scale.type !== 'float64') || - (B.type !== 'float32' && B.type !== 'float64') || (mean.type !== 'float32' && mean.type !== 'float64') || - (var_.type !== 'float32' && var_.type !== 'float64')) { + if ( + (X.type !== 'float32' && X.type !== 'float64') || + (scale.type !== 'float32' && scale.type !== 'float64') || + (B.type !== 'float32' && B.type !== 'float64') || + (mean.type !== 'float32' && mean.type !== 'float64') || + (var_.type !== 'float32' && var_.type !== 'float64') + ) { throw new Error('invalid input tensor types.'); } }; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/binary-op.ts b/js/web/lib/onnxjs/backends/webgl/ops/binary-op.ts index 4aa9bf3c9e164..84fe5ad046dc6 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/binary-op.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/binary-op.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
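
// Illustrative sketch (not part of this change): the BatchNormalization shader above computes,
// per element, y = scale * (x - mean) / sqrt(variance + epsilon) + b, with epsilon defaulting
// to 1e-5. A one-element CPU reference:
function batchNormElement(x: number, scale: number, b: number, mean: number, variance: number, epsilon = 1e-5): number {
  return scale * ((x - mean) / Math.sqrt(variance + epsilon)) + b;
}
// batchNormElement(2, 1, 0, 1, 1) ≈ 1 / Math.sqrt(1 + 1e-5) ≈ 0.999995
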
-import {Tensor} from '../../../tensor'; -import {BroadcastUtil, ShapeUtil} from '../../../util'; -import {FunctionType, GlslValueFunction} from '../glsl-definitions'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, TextureType} from '../types'; +import { Tensor } from '../../../tensor'; +import { BroadcastUtil, ShapeUtil } from '../../../util'; +import { FunctionType, GlslValueFunction } from '../glsl-definitions'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types'; export function glslAdd(): GlslValueFunction { const name = 'add_'; @@ -18,7 +18,7 @@ export function glslAdd(): GlslValueFunction { return v1 + v2; } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslDiv(): GlslValueFunction { const name = 'div_'; @@ -30,7 +30,7 @@ export function glslDiv(): GlslValueFunction { return v1 / v2; } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslMul(): GlslValueFunction { const name = 'mul_'; @@ -42,7 +42,7 @@ export function glslMul(): GlslValueFunction { return v1 * v2; } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslSub(): GlslValueFunction { const name = 'sub_'; @@ -54,7 +54,7 @@ export function glslSub(): GlslValueFunction { return v1 - v2; } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslEqual(): GlslValueFunction { const name = 'equal_'; @@ -66,7 +66,7 @@ export function glslEqual(): GlslValueFunction { return vec4(equal(v1, v2)); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslGreater(): GlslValueFunction { const name = 'greater_'; @@ -81,7 +81,7 @@ export function glslGreater(): GlslValueFunction { v1.a > v2.a ); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslLess(): GlslValueFunction { const name = 'less_'; @@ -96,7 +96,7 @@ export function glslLess(): GlslValueFunction { v1.a < v2.a ); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslAnd(): GlslValueFunction { const name = 'and_'; @@ -113,7 +113,7 @@ export function glslAnd(): GlslValueFunction { b1.a && b2.a ); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslOr(): GlslValueFunction { const name = 'or_'; @@ -130,7 +130,7 @@ export function glslOr(): GlslValueFunction { b1.a || b2.a ); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslXor(): GlslValueFunction { const name = 'xor_'; @@ -147,7 +147,7 @@ export function glslXor(): GlslValueFunction { b1.a ^^ b2.a ); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslPow(): GlslValueFunction { return glslBuiltinBinary('pow'); @@ -167,7 +167,7 @@ export function glslPRelu(): GlslValueFunction { ); } `; - return 
{body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } function glslBuiltinBinary(fname: string): GlslValueFunction { @@ -180,53 +180,61 @@ function glslBuiltinBinary(fname: string): GlslValueFunction { return ${fname}(v1, v2); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } -const createBinaryProgramInfoLoader = - (handler: WebGLInferenceHandler, inputs: Tensor[], glslFunc: GlslValueFunction, - outputTensorType: Tensor.DataType = inputs[0].type, cacheKey?: string): ProgramInfoLoader => { - const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked; - return { - name: glslFunc.name, - inputNames: ['A', 'B'], - inputTypes: [textureType, textureType], - cacheHint: cacheKey, - get: () => createBinaryProgramInfo(handler, inputs, glslFunc, outputTensorType) - }; - }; +const createBinaryProgramInfoLoader = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + glslFunc: GlslValueFunction, + outputTensorType: Tensor.DataType = inputs[0].type, + cacheKey?: string, +): ProgramInfoLoader => { + const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked; + return { + name: glslFunc.name, + inputNames: ['A', 'B'], + inputTypes: [textureType, textureType], + cacheHint: cacheKey, + get: () => createBinaryProgramInfo(handler, inputs, glslFunc, outputTensorType), + }; +}; -const createBinaryProgramInfo = - (handler: WebGLInferenceHandler, inputs: Tensor[], glslFunc: GlslValueFunction, - outputTensorType: Tensor.DataType = inputs[0].type): ProgramInfo => { - const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked; - const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims); - let outputShape = inputs[0].dims; +const createBinaryProgramInfo = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + glslFunc: GlslValueFunction, + outputTensorType: Tensor.DataType = inputs[0].type, +): ProgramInfo => { + const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked; + const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims); + let outputShape = inputs[0].dims; - const usePackedTexture = handler.session.pack; + const usePackedTexture = handler.session.pack; - if (isBroadcast) { - const calculatedShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false); - if (!calculatedShape) { - throw new Error('Can\'t perform binary op on the given tensors'); - } - outputShape = calculatedShape; - const outputRank = outputShape.length; - const aRank = inputs[0].dims.length !== 0 ? inputs[0].dims.length : 1; - const bRank = inputs[1].dims.length !== 0 ? inputs[1].dims.length : 1; - const aBcast = inputs[0].dims.length !== 0 ? 'bcastIndices_A(indices, aindices);' : 'aindices[0] = 0;'; - const bBcast = inputs[1].dims.length !== 0 ? 'bcastIndices_B(indices, bindices);' : 'bindices[0] = 0;'; + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false); + if (!calculatedShape) { + throw new Error("Can't perform binary op on the given tensors"); + } + outputShape = calculatedShape; + const outputRank = outputShape.length; + const aRank = inputs[0].dims.length !== 0 ? inputs[0].dims.length : 1; + const bRank = inputs[1].dims.length !== 0 ? inputs[1].dims.length : 1; + const aBcast = inputs[0].dims.length !== 0 ? 'bcastIndices_A(indices, aindices);' : 'aindices[0] = 0;'; + const bBcast = inputs[1].dims.length !== 0 ? 
'bcastIndices_B(indices, bindices);' : 'bindices[0] = 0;'; - const glsl = getGlsl(handler.session.backend.glContext.version); - const shaderSource = usePackedTexture ? ` + const glsl = getGlsl(handler.session.backend.glContext.version); + const shaderSource = usePackedTexture + ? ` ${glslFunc.body} void main() { vec4 a = getAAtOutCoords(); vec4 b = getBAtOutCoords(); vec4 result = ${glslFunc.name}(a, b); ${glsl.output} = result; - }` : - ` + }` + : ` ${glslFunc.body} float process(int indices[${outputRank}]) { int aindices[${aRank}]; @@ -236,17 +244,17 @@ const createBinaryProgramInfo = return ${glslFunc.name}(_A(aindices), _B(bindices)); }`; - return { - name: glslFunc.name, - inputNames: ['A', 'B'], - inputTypes: [textureType, textureType], - output: {dims: outputShape, type: outputTensorType, textureType}, - shaderSource, - hasMain: usePackedTexture - }; - } - const glsl = getGlsl(handler.session.backend.glContext.version); - const shaderSource = ` + return { + name: glslFunc.name, + inputNames: ['A', 'B'], + inputTypes: [textureType, textureType], + output: { dims: outputShape, type: outputTensorType, textureType }, + shaderSource, + hasMain: usePackedTexture, + }; + } + const glsl = getGlsl(handler.session.backend.glContext.version); + const shaderSource = ` ${glslFunc.body} void main() { vec4 v1 = ${glsl.texture2D}(A, TexCoords); @@ -256,48 +264,60 @@ const createBinaryProgramInfo = } `; - return { - name: glslFunc.name, - inputNames: ['A', 'B'], - inputTypes: [textureType, textureType], - output: {dims: inputs[0].dims, type: outputTensorType, textureType}, - shaderSource, - hasMain: true - }; - }; + return { + name: glslFunc.name, + inputNames: ['A', 'B'], + inputTypes: [textureType, textureType], + output: { dims: inputs[0].dims, type: outputTensorType, textureType }, + shaderSource, + hasMain: true, + }; +}; -export const add = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAdd()), inputs)]; +export const add = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAdd()), inputs), +]; -export const and = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAnd(), 'bool'), inputs)]; +export const and = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslAnd(), 'bool'), inputs), +]; -export const div = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslDiv()), inputs)]; +export const div = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslDiv()), inputs), +]; -export const equal = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslEqual(), 'bool'), inputs)]; +export const equal = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslEqual(), 'bool'), inputs), +]; -export const greater = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslGreater(), 'bool'), inputs)]; +export const greater = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + 
handler.run(createBinaryProgramInfoLoader(handler, inputs, glslGreater(), 'bool'), inputs), +]; -export const less = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslLess(), 'bool'), inputs)]; +export const less = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslLess(), 'bool'), inputs), +]; -export const mul = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslMul()), inputs)]; +export const mul = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslMul()), inputs), +]; -export const or = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslOr(), 'bool'), inputs)]; +export const or = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslOr(), 'bool'), inputs), +]; -export const pow = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPow()), inputs)]; +export const pow = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPow()), inputs), +]; -export const pRelu = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPRelu()), inputs)]; +export const pRelu = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslPRelu()), inputs), +]; -export const sub = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslSub()), inputs)]; +export const sub = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslSub()), inputs), +]; -export const xor = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createBinaryProgramInfoLoader(handler, inputs, glslXor(), 'bool'), inputs)]; +export const xor = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createBinaryProgramInfoLoader(handler, inputs, glslXor(), 'bool'), inputs), +]; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/cast.ts b/js/web/lib/onnxjs/backends/webgl/ops/cast.ts index 18d65136ab179..0f5455aa743b9 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/cast.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/cast.ts @@ -1,20 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ProtoUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ProtoUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; -export const cast: OperatorImplementation = - (handler: WebGLInferenceHandler, inputs: Tensor[], to: Tensor.DataType): Tensor[] => { - validateInputs(inputs); - return [handler.cast(inputs[0], to)]; - }; +export const cast: OperatorImplementation = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + to: Tensor.DataType, +): Tensor[] => { + validateInputs(inputs); + return [handler.cast(inputs[0], to)]; +}; export const parseCastAttributes: OperatorInitialization = (node: Graph.Node): Tensor.DataType => - ProtoUtil.tensorDataTypeFromProto(node.attributes.getInt('to')); + ProtoUtil.tensorDataTypeFromProto(node.attributes.getInt('to')); const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { @@ -24,4 +27,4 @@ const validateInputs = (inputs: Tensor[]): void => { if (inputs[0].type === 'string') { throw new Error('Invalid input type.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts index d0e589a428825..3f5a1a20aa5f8 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts @@ -1,91 +1,95 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; -import {getCoordsDataType, getGlChannels} from '../utils'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; +import { getCoordsDataType, getGlChannels } from '../utils'; -import {ConcatAttributes} from './concat'; -import {getChannels, unpackFromChannel} from './packing-utils'; +import { ConcatAttributes } from './concat'; +import { getChannels, unpackFromChannel } from './packing-utils'; const createPackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat (packed)', - inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), + inputNames: Array.from({ length: inputCount }, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.packed), - cacheHint + cacheHint, }); -const createPackedConcatProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); +const createPackedConcatProgramInfo = ( + handler: WebGLInferenceHandler, + metadata: ProgramMetadata, + inputs: Tensor[], + axis: number, +): ProgramInfo => { + const inputShape = inputs[0].dims.slice(); + if (axis >= inputShape.length || axis < -1 * inputShape.length) { + throw new Error("axis specified for concat doesn't match input dimensionality"); + } + if (axis < 0) { + axis = inputShape.length + axis; + } + // ensure all of the non-concatenated axes match each other + // calculate the shape of the output tensor while we do that + const outputShape = inputShape.slice(0); + for (let i = 1; i < inputs.length; i++) { + const dataNShape = inputs[i].dims.slice(); + for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { + // add to the placeholder for computing output shape + if (axisIndex === axis) { + outputShape[axis] += dataNShape[axisIndex]; } - if (axis < 0) { - axis = inputShape.length + axis; - } - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === axis) { - outputShape[axis] += dataNShape[axisIndex]; - } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { - throw new Error('non concat dimensions must match'); - } - } + // ensure all non-cancatenated axes match each other + else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { + throw new Error('non concat dimensions must match'); } + } + } - const rank = outputShape.length; - const coords = getChannels('coords', rank); - const dtype = getCoordsDataType(rank); - const unpackChannel = unpackFromChannel(); + const rank = outputShape.length; + const coords = getChannels('coords', rank); + const dtype = getCoordsDataType(rank); + 
const unpackChannel = unpackFromChannel(); - const shapes = inputs.map(i => i.dims); - const channels = getGlChannels(rank); - const offsets: number[] = new Array(shapes.length - 1); + const shapes = inputs.map((i) => i.dims); + const channels = getGlChannels(rank); + const offsets: number[] = new Array(shapes.length - 1); - offsets[0] = shapes[0][axis]; - for (let i = 1; i < offsets.length; i++) { - offsets[i] = offsets[i - 1] + shapes[i][axis]; - } + offsets[0] = shapes[0][axis]; + for (let i = 1; i < offsets.length; i++) { + offsets[i] = offsets[i - 1] + shapes[i][axis]; + } - const channel = channels[axis]; - const lastChannels = channels.slice(-2); - const allChannels = channels.join(); + const channel = channels[axis]; + const lastChannels = channels.slice(-2); + const allChannels = channels.join(); - let getValueSnippet = `if (${channel} < ${offsets[0]}) { + let getValueSnippet = `if (${channel} < ${offsets[0]}) { return getChannel( getX0(${allChannels}), vec2(${lastChannels.join()})); }`; - for (let i = 1; i < offsets.length; i++) { - const shift = offsets[i - 1]; - getValueSnippet += ` + for (let i = 1; i < offsets.length; i++) { + const shift = offsets[i - 1]; + getValueSnippet += ` if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) { return getChannel( getX${i}(${getShiftedChannelsSnippet(channels, channel, shift)}), vec2(${getShiftedChannelsSnippet(lastChannels, channel, shift)})); }`; - } - const lastIndex = offsets.length; - const shift = offsets[offsets.length - 1]; - getValueSnippet += ` + } + const lastIndex = offsets.length; + const shift = offsets[offsets.length - 1]; + getValueSnippet += ` return getChannel( getX${lastIndex}(${getShiftedChannelsSnippet(channels, channel, shift)}), vec2(${getShiftedChannelsSnippet(lastChannels, channel, shift)}));`; - const glsl = getGlsl(handler.session.backend.glContext.version); + const glsl = getGlsl(handler.session.backend.glContext.version); - const shaderSource = ` + const shaderSource = ` ${unpackChannel} - float getValue(${channels.map(x => 'int ' + x)}) { + float getValue(${channels.map((x) => 'int ' + x)}) { ${getValueSnippet} } @@ -116,19 +120,22 @@ const createPackedConcatProgramInfo = } `; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.packed}, - shaderSource, - hasMain: true, - }; - }; - -export const createPackedConcatProgramInfoLoader = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ConcatAttributes): ProgramInfoLoader => { - const metadata = createPackedConcatProgramMetadata(inputs.length, attributes.cacheKey); - return {...metadata, get: () => createPackedConcatProgramInfo(handler, metadata, inputs, attributes.axis)}; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed }, + shaderSource, + hasMain: true, + }; +}; + +export const createPackedConcatProgramInfoLoader = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ConcatAttributes, +): ProgramInfoLoader => { + const metadata = createPackedConcatProgramMetadata(inputs.length, attributes.cacheKey); + return { ...metadata, get: () => createPackedConcatProgramInfo(handler, metadata, inputs, attributes.axis) }; +}; const getShiftedChannelsSnippet = (channels: string[], channel: string, shift: number): string => { const channelIdx = channels.indexOf(channel); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts index f85f4032feae1..8270892920cff 
100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts @@ -1,86 +1,97 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; -import {createPackedConcatProgramInfoLoader} from './concat-packed'; +import { createPackedConcatProgramInfoLoader } from './concat-packed'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -export const concat: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ConcatAttributes): Tensor[] => { - validateInputs(inputs); - if (inferenceHandler.session.pack && inputs[0].dims.length > 1) { - const output = - inferenceHandler.run(createPackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes), inputs); - return [output]; - } else { - const output = - inferenceHandler.run(createUnpackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes), inputs); - return [output]; - } - }; +export const concat: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ConcatAttributes, +): Tensor[] => { + validateInputs(inputs); + if (inferenceHandler.session.pack && inputs[0].dims.length > 1) { + const output = inferenceHandler.run( + createPackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes), + inputs, + ); + return [output]; + } else { + const output = inferenceHandler.run( + createUnpackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes), + inputs, + ); + return [output]; + } +}; const createUnpackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat', - inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), + inputNames: Array.from({ length: inputCount }, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.unpacked), - cacheHint + cacheHint, }); -const createUnpackedConcatProgramInfo = - (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); - } - if (axis < 0) { - axis = inputShape.length + axis; +const createUnpackedConcatProgramInfo = ( + _handler: WebGLInferenceHandler, + metadata: ProgramMetadata, + inputs: Tensor[], + axis: number, +): ProgramInfo => { + const inputShape = inputs[0].dims.slice(); + if (axis >= inputShape.length || axis < -1 * inputShape.length) { + throw new Error("axis specified for concat doesn't match input dimensionality"); + } + if (axis < 
0) { + axis = inputShape.length + axis; + } + // ensure all of the non-concatenated axes match each other + // calculate the shape of the output tensor while we do that + const outputShape = inputShape.slice(0); + for (let i = 1; i < inputs.length; i++) { + const dataNShape = inputs[i].dims.slice(); + for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { + // add to the placeholder for computing output shape + if (axisIndex === axis) { + outputShape[axis] += dataNShape[axisIndex]; } - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === axis) { - outputShape[axis] += dataNShape[axisIndex]; - } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { - throw new Error('non concat dimensions must match'); - } - } + // ensure all non-cancatenated axes match each other + else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { + throw new Error('non concat dimensions must match'); } + } + } - const rank = outputShape.length; + const rank = outputShape.length; - const sizeInConcatAxis = new Array(inputs.length); - let previousSum = 0; - for (let i = 0; i < sizeInConcatAxis.length; ++i) { - previousSum += inputs[i].dims[axis]; - sizeInConcatAxis[i] = previousSum; - } + const sizeInConcatAxis = new Array(inputs.length); + let previousSum = 0; + for (let i = 0; i < sizeInConcatAxis.length; ++i) { + previousSum += inputs[i].dims[axis]; + sizeInConcatAxis[i] = previousSum; + } - let getTextureIndexWhereDataResidesMethod = ''; - // in most cases linear search is sufficient, as in most scenarios, only 2 tensors are concatenated - if (inputs.length < 5) { - getTextureIndexWhereDataResidesMethod = getTextureIndexWhereDataResidesLinearSearch(sizeInConcatAxis); - } else { - getTextureIndexWhereDataResidesMethod = getTextureIndexWhereDataResidesBinarySearch(sizeInConcatAxis); - } + let getTextureIndexWhereDataResidesMethod = ''; + // in most cases linear search is sufficient, as in most scenarios, only 2 tensors are concatenated + if (inputs.length < 5) { + getTextureIndexWhereDataResidesMethod = getTextureIndexWhereDataResidesLinearSearch(sizeInConcatAxis); + } else { + getTextureIndexWhereDataResidesMethod = getTextureIndexWhereDataResidesBinarySearch(sizeInConcatAxis); + } - const fetchDataFromCorrectTextureMethod = getFetchDataFromCorrectTextureMethod(inputs.length, rank); - const getSizeInConcatAxisValueFromIndexMethod = getGetSizeInConcatAxisValueFromIndexMethod(sizeInConcatAxis); - const shaderSource = ` + const fetchDataFromCorrectTextureMethod = getFetchDataFromCorrectTextureMethod(inputs.length, rank); + const getSizeInConcatAxisValueFromIndexMethod = getGetSizeInConcatAxisValueFromIndexMethod(sizeInConcatAxis); + const shaderSource = ` ${fetchDataFromCorrectTextureMethod} ${getSizeInConcatAxisValueFromIndexMethod} ${getTextureIndexWhereDataResidesMethod} @@ -93,22 +104,27 @@ const createUnpackedConcatProgramInfo = return fetchDataFromCorrectTexture(textureIndex, indices); }`; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource, - }; - }; - -const createUnpackedConcatProgramInfoLoader = - (handler: 
WebGLInferenceHandler, inputs: Tensor[], attributes: ConcatAttributes): ProgramInfoLoader => { - const metadata = createUnpackedConcatProgramMetadata(inputs.length, attributes.cacheKey); - return {...metadata, get: () => createUnpackedConcatProgramInfo(handler, metadata, inputs, attributes.axis)}; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; + +const createUnpackedConcatProgramInfoLoader = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ConcatAttributes, +): ProgramInfoLoader => { + const metadata = createUnpackedConcatProgramMetadata(inputs.length, attributes.cacheKey); + return { ...metadata, get: () => createUnpackedConcatProgramInfo(handler, metadata, inputs, attributes.axis) }; +}; const getTextureIndexWhereDataResidesLinearSearch = (sizeInConcatAxis: number[]): string => { - const searchAxis = sizeInConcatAxis.map((size, i) => `if(index<${size}) {return ${i};} -`); + const searchAxis = sizeInConcatAxis.map( + (size, i) => `if(index<${size}) {return ${i};} +`, + ); return `int getTextureWhereDataResides(int index) { ${searchAxis.join('')} }`; @@ -116,28 +132,20 @@ const getTextureIndexWhereDataResidesLinearSearch = (sizeInConcatAxis: number[]) // TODO: Implement BinarySearch in GLSL const getTextureIndexWhereDataResidesBinarySearch = (sizeInConcatAxis: number[]): string => - getTextureIndexWhereDataResidesLinearSearch(sizeInConcatAxis); + getTextureIndexWhereDataResidesLinearSearch(sizeInConcatAxis); const getFetchDataFromCorrectTextureMethod = (numberOfTensors: number, tensorRank: number) => { const codeLines: string[] = [`float fetchDataFromCorrectTexture(int textureIndex, int indices[${tensorRank}]) {`]; for (let i = 0; i < numberOfTensors; ++i) { if (i === 0) { - codeLines.push( - '\t' + - `if (textureIndex == ${i}) { return _X${i}(indices); }`); + codeLines.push('\t' + `if (textureIndex == ${i}) { return _X${i}(indices); }`); } else if (i === numberOfTensors - 1) { - codeLines.push( - '\t' + - `else { return _X${i}(indices); }`); + codeLines.push('\t' + `else { return _X${i}(indices); }`); } else { - codeLines.push( - '\t' + - `else if (textureIndex == ${i}) { return _X${i}(indices); }`); + codeLines.push('\t' + `else if (textureIndex == ${i}) { return _X${i}(indices); }`); } } - codeLines.push( - '\t' + - '}'); + codeLines.push('\t' + '}'); return codeLines.join('\n'); }; @@ -145,28 +153,20 @@ const getGetSizeInConcatAxisValueFromIndexMethod = (sizeInConcatAxis: number[]): const codeLines: string[] = ['int getSizeInConcatAxisValueFromIndex(int index) {']; for (let i = 0; i < sizeInConcatAxis.length; ++i) { if (i === 0) { - codeLines.push( - '\t' + - `if (index == ${i}) { return ${sizeInConcatAxis[i]}; }`); + codeLines.push('\t' + `if (index == ${i}) { return ${sizeInConcatAxis[i]}; }`); } else if (i === sizeInConcatAxis.length - 1) { - codeLines.push( - '\t' + - `else { return ${sizeInConcatAxis[i]}; }`); + codeLines.push('\t' + `else { return ${sizeInConcatAxis[i]}; }`); } else { - codeLines.push( - '\t' + - `else if (index == ${i}) { return ${sizeInConcatAxis[i]}; }`); + codeLines.push('\t' + `else if (index == ${i}) { return ${sizeInConcatAxis[i]}; }`); } } - codeLines.push( - '\t' + - '}'); + codeLines.push('\t' + '}'); return codeLines.join('\n'); }; export const parseConcatAttributes: OperatorInitialization = (node: Graph.Node): ConcatAttributes => - createAttributeWithCacheKey({axis: node.attributes.getInt('axis')}); + createAttributeWithCacheKey({ 
axis: node.attributes.getInt('axis') }); const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length < 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-grouped.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-grouped.ts index 1d3a7173f590e..3d39ad2892ddc 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-grouped.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-grouped.ts @@ -1,41 +1,46 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Logger} from '../../../instrument'; -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { Logger } from '../../../instrument'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; -import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActivationSnippet} from './fuse-utils'; +import { calculateOutputShape, ConvAttributes } from './conv'; +import { getActivationSnippet } from './fuse-utils'; const createUnpackedGroupedConvProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ name: 'GroupedConv', inputNames: hasBias ? ['X', 'W', 'Bias'] : ['X', 'W'], - inputTypes: hasBias ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] : - [TextureType.unpacked, TextureType.unpacked], - cacheHint + inputTypes: hasBias + ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] + : [TextureType.unpacked, TextureType.unpacked], + cacheHint, }); -const createUnpackedGroupedConvProgramInfo = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], metadata: ProgramMetadata, - attributes: ConvAttributes): ProgramInfo => { - const hasBias = inputs.length > 2; - const processBias = hasBias ? 'value += getBias(output_channel);' : ''; - const xShape = inputs[0].dims.slice(); - const wShape = inputs[1].dims.slice(); - const outputChannelsPerGroup = wShape[0] / attributes.group; - Logger.verbose( - 'GroupedConv', - `autpPad:${attributes.autoPad}, dilations:${attributes.dilations}, group:${attributes.group}, kernelShape:${ - attributes.kernelShape}, pads:${attributes.pads}, strides:${attributes.strides}`); - const outputShape = - calculateOutputShape(xShape, wShape, attributes.dilations, attributes.pads, attributes.strides); - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const {activationFunction, applyActivation} = getActivationSnippet(attributes); +const createUnpackedGroupedConvProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + metadata: ProgramMetadata, + attributes: ConvAttributes, +): ProgramInfo => { + const hasBias = inputs.length > 2; + const processBias = hasBias ? 
'value += getBias(output_channel);' : ''; + const xShape = inputs[0].dims.slice(); + const wShape = inputs[1].dims.slice(); + const outputChannelsPerGroup = wShape[0] / attributes.group; + Logger.verbose( + 'GroupedConv', + `autpPad:${attributes.autoPad}, dilations:${attributes.dilations}, group:${attributes.group}, kernelShape:${ + attributes.kernelShape + }, pads:${attributes.pads}, strides:${attributes.strides}`, + ); + const outputShape = calculateOutputShape(xShape, wShape, attributes.dilations, attributes.pads, attributes.strides); + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const { activationFunction, applyActivation } = getActivationSnippet(attributes); - const shaderSource = ` + const shaderSource = ` const ivec2 strides = ivec2(${attributes.strides[0]}, ${attributes.strides[1]}); const ivec2 pads = ivec2(${attributes.pads[0]}, ${attributes.pads[1]}); ${activationFunction} @@ -73,20 +78,22 @@ const createUnpackedGroupedConvProgramInfo = ${glsl.output} = vec4(value, .0, .0, .0); } `; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource, - hasMain: true, - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + hasMain: true, + }; +}; -export const createUnpackedGroupedConvProgramInfoLoader = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes): - ProgramInfoLoader => { - const metadata = createUnpackedGroupedConvProgramMetadata(inputs.length > 2, attributes.cacheKey); - return { - ...metadata, - get: () => createUnpackedGroupedConvProgramInfo(inferenceHandler, inputs, metadata, attributes) - }; - }; +export const createUnpackedGroupedConvProgramInfoLoader = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + attributes: ConvAttributes, +): ProgramInfoLoader => { + const metadata = createUnpackedGroupedConvProgramMetadata(inputs.length > 2, attributes.cacheKey); + return { + ...metadata, + get: () => createUnpackedGroupedConvProgramInfo(inferenceHandler, inputs, metadata, attributes), + }; +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts index 3fade9890e06a..e5d71affd2e29 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts @@ -1,50 +1,58 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; - -import {calculateOutputShape, ConvAttributes} from './conv'; -import {createPackedIm2ColProgramInfoLoader} from './im2col-pack'; -import {createPackedMatmulProgramInfoLoader} from './matmul-pack'; - -export const conv2DPackedPointwise = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes): Tensor => { - const xshape = inputs[0].dims; - const kshape = inputs[1].dims; - const outputShape = - calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); - const reshapedX = inferenceHandler.reshapePacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]); - const reshapedK = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1]]); - - const matmulInputs = inputs.length > 2 ? 
[reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX]; - const matmulOutput = inferenceHandler.run( - createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes), matmulInputs); - return inferenceHandler.reshapePacked(matmulOutput, outputShape); - }; - -export const conv2DPacked = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes): Tensor => { - const xshape = inputs[0].dims; - const kshape = inputs[1].dims; - const outputShape = - calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); - - // run im2col - const im2colOutput = inferenceHandler.run( - createPackedIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes), - [inputs[0]]); - - // reshape kernel - const kernelReshaped = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1] * kshape[2] * kshape[3]]); - - // run matmul - const matmulInputs = - (inputs.length === 3) ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]; - const matmulOutput = inferenceHandler.run( - createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes), matmulInputs); - - // reshape output - const outputReshaped = inferenceHandler.reshapePacked(matmulOutput, outputShape); - return outputReshaped; - }; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; + +import { calculateOutputShape, ConvAttributes } from './conv'; +import { createPackedIm2ColProgramInfoLoader } from './im2col-pack'; +import { createPackedMatmulProgramInfoLoader } from './matmul-pack'; + +export const conv2DPackedPointwise = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + attributes: ConvAttributes, +): Tensor => { + const xshape = inputs[0].dims; + const kshape = inputs[1].dims; + const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); + const reshapedX = inferenceHandler.reshapePacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]); + const reshapedK = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1]]); + + const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX]; + const matmulOutput = inferenceHandler.run( + createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes), + matmulInputs, + ); + return inferenceHandler.reshapePacked(matmulOutput, outputShape); +}; + +export const conv2DPacked = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + attributes: ConvAttributes, +): Tensor => { + const xshape = inputs[0].dims; + const kshape = inputs[1].dims; + const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); + + // run im2col + const im2colOutput = inferenceHandler.run( + createPackedIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes), + [inputs[0]], + ); + + // reshape kernel + const kernelReshaped = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1] * kshape[2] * kshape[3]]); + + // run matmul + const matmulInputs = inputs.length === 3 ? 
[kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]; + const matmulOutput = inferenceHandler.run( + createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes), + matmulInputs, + ); + + // reshape output + const outputReshaped = inferenceHandler.reshapePacked(matmulOutput, outputShape); + return outputReshaped; +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-transpose.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-transpose.ts index 0da1d64871314..345842ce8c928 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-transpose.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-transpose.ts @@ -1,21 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {InferenceHandler} from '../../../backend'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; - -import {ConvAttributes} from './conv'; -import {getActivationSnippet, parseInternalActivationAttributes} from './fuse-utils'; - -const computeTotalPad = - (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => - (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize; +import { createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { InferenceHandler } from '../../../backend'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; + +import { ConvAttributes } from './conv'; +import { getActivationSnippet, parseInternalActivationAttributes } from './fuse-utils'; + +const computeTotalPad = ( + inDim: number, + stride: number, + adj: number, + kernel: number, + dilation: number, + outSize: number, +) => (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize; const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => { const smallPad = Math.floor(totalPad / 2); @@ -28,62 +33,84 @@ const distributePadding = (totalPad: number, autoPad: string, pads: number[], he } }; -const calculateOutputShapeAndPads = - (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], autoPad: string, - pads: number[], strides: readonly number[], outputPadding: readonly number[], outputShape: number[]) => { - const spatialRank = inputShape.length - 2; - const updateShape = outputShape.length === 0; - for (let i = 0; i < spatialRank; ++i) { - const outSize = updateShape ? 
inputShape[i + 2] * strides[i] : outputShape[i]; - const totalPad = computeTotalPad(inputShape[i + 2], strides[i], pads[i], kernelShape[i], dilations[i], outSize); - distributePadding(totalPad, autoPad, pads, i, i + spatialRank); - if (updateShape) { - outputShape.push( - strides[i] * (inputShape[i + 2] - 1) + outputPadding[i] + (kernelShape[i] - 1) * dilations[i] + 1 - - pads[i] - pads[i + spatialRank]); - } - } - }; +const calculateOutputShapeAndPads = ( + inputShape: readonly number[], + kernelShape: readonly number[], + dilations: readonly number[], + autoPad: string, + pads: number[], + strides: readonly number[], + outputPadding: readonly number[], + outputShape: number[], +) => { + const spatialRank = inputShape.length - 2; + const updateShape = outputShape.length === 0; + for (let i = 0; i < spatialRank; ++i) { + const outSize = updateShape ? inputShape[i + 2] * strides[i] : outputShape[i]; + const totalPad = computeTotalPad(inputShape[i + 2], strides[i], pads[i], kernelShape[i], dilations[i], outSize); + distributePadding(totalPad, autoPad, pads, i, i + spatialRank); + if (updateShape) { + outputShape.push( + strides[i] * (inputShape[i + 2] - 1) + + outputPadding[i] + + (kernelShape[i] - 1) * dilations[i] + + 1 - + pads[i] - + pads[i + spatialRank], + ); + } + } +}; export interface ConvTransposeAttributes extends ConvAttributes { readonly outputPadding: readonly number[]; readonly outputShape: readonly number[]; } -export const convTranspose: OperatorImplementation = - (inferenceHandler: InferenceHandler, inputs: Tensor[], attributes: ConvTransposeAttributes): Tensor[] => { - validateInputs(inputs, attributes); // currently will fail if not convTranspose2D - return convTranspose2d(inferenceHandler, inputs, attributes); - }; +export const convTranspose: OperatorImplementation = ( + inferenceHandler: InferenceHandler, + inputs: Tensor[], + attributes: ConvTransposeAttributes, +): Tensor[] => { + validateInputs(inputs, attributes); // currently will fail if not convTranspose2D + return convTranspose2d(inferenceHandler, inputs, attributes); +}; -const convTranspose2d: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ConvTransposeAttributes): Tensor[] => { - const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); - return [convTranspose2DUnpacked(inferenceHandler, inputs, adjustedAttributes)]; - }; +const convTranspose2d: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ConvTransposeAttributes, +): Tensor[] => { + const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + return [convTranspose2DUnpacked(inferenceHandler, inputs, adjustedAttributes)]; +}; const createConvTransposeProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'ConvTranspose', inputNames: hasBias ? ['X', 'W', 'B'] : ['X', 'W'], - inputTypes: hasBias ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] : - [TextureType.unpacked, TextureType.unpacked], - cacheHint + inputTypes: hasBias + ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] + : [TextureType.unpacked, TextureType.unpacked], + cacheHint, }); -const createUnpackedConvTransposeProgramInfo = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], metadata: ProgramMetadata, - attributes: ConvTransposeAttributes): ProgramInfo => { - const hasBias = inputs.length > 2; - const valueInit = hasBias ? 
'getB(output_channel)' : '0.0'; - const xShape = inputs[0].dims; - const wShape = inputs[1].dims; - const outputChannelsPerGroup = wShape[1]; - const inputChannelsPerGroup = wShape[0] / attributes.group; - const outputShape = [inputs[0].dims[0], inputs[1].dims[1] * attributes.group, ...attributes.outputShape]; - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const {activationFunction, applyActivation} = getActivationSnippet(attributes); - - const shaderSource = ` +const createUnpackedConvTransposeProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + metadata: ProgramMetadata, + attributes: ConvTransposeAttributes, +): ProgramInfo => { + const hasBias = inputs.length > 2; + const valueInit = hasBias ? 'getB(output_channel)' : '0.0'; + const xShape = inputs[0].dims; + const wShape = inputs[1].dims; + const outputChannelsPerGroup = wShape[1]; + const inputChannelsPerGroup = wShape[0] / attributes.group; + const outputShape = [inputs[0].dims[0], inputs[1].dims[1] * attributes.group, ...attributes.outputShape]; + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const { activationFunction, applyActivation } = getActivationSnippet(attributes); + + const shaderSource = ` const ivec2 strides = ivec2(${attributes.strides[0]}, ${attributes.strides[1]}); const ivec2 pads = ivec2(${attributes.pads[0]}, ${attributes.pads[1]}); ${activationFunction} @@ -121,32 +148,37 @@ const createUnpackedConvTransposeProgramInfo = ${glsl.output} = vec4(value, .0, .0, .0); } `; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource, - hasMain: true, - }; - }; - -const createUnpackedConvTransposeProgramInfoLoader = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvTransposeAttributes): - ProgramInfoLoader => { - const metadata = createConvTransposeProgramMetadata(inputs.length > 2, attributes.cacheKey); - return { - ...metadata, - get: () => createUnpackedConvTransposeProgramInfo(inferenceHandler, inputs, metadata, attributes) - }; - }; - - -const convTranspose2DUnpacked = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvTransposeAttributes): - Tensor => { - const result = inferenceHandler.run( - createUnpackedConvTransposeProgramInfoLoader(inferenceHandler, inputs, attributes), inputs); - return result; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + hasMain: true, + }; +}; + +const createUnpackedConvTransposeProgramInfoLoader = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + attributes: ConvTransposeAttributes, +): ProgramInfoLoader => { + const metadata = createConvTransposeProgramMetadata(inputs.length > 2, attributes.cacheKey); + return { + ...metadata, + get: () => createUnpackedConvTransposeProgramInfo(inferenceHandler, inputs, metadata, attributes), + }; +}; + +const convTranspose2DUnpacked = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + attributes: ConvTransposeAttributes, +): Tensor => { + const result = inferenceHandler.run( + createUnpackedConvTransposeProgramInfoLoader(inferenceHandler, inputs, attributes), + inputs, + ); + return result; +}; const getAdjustedConvTransposeAttributes = (attributes: T, inputs: Tensor[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -163,32 +195,49 @@ const 
getAdjustedConvTransposeAttributes = (a // If outputShape is not specified in the attributes of this op, infer it from the parameters // Similarly, automatically infer pads if not specified calculateOutputShapeAndPads( - inputShape, kernelShape, attributes.dilations, attributes.autoPad, pads, attributes.strides, - attributes.outputPadding, outputShape); + inputShape, + kernelShape, + attributes.dilations, + attributes.autoPad, + pads, + attributes.strides, + attributes.outputPadding, + outputShape, + ); // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, outputShape, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, { kernelShape, pads, outputShape, cacheKey: attributes.cacheKey }); return newAttributes; }; -export const parseConvTransposeAttributes: OperatorInitialization = - (node: Graph.Node): ConvTransposeAttributes => { - const attributes = node.attributes; - const activationAttributes = parseInternalActivationAttributes(attributes); - // TODO : Make this generic enough to compute default attributes for multi-dimensional conv - const autoPad = attributes.getString('auto_pad', 'NOTSET'); - const dilations = attributes.getInts('dilations', [1, 1]); - const group = attributes.getInt('group', 1); - const kernelShape = attributes.getInts('kernel_shape', []); - const outputPadding = attributes.getInts('output_padding', [0, 0]); - const outputShape = attributes.getInts('output_shape', []); - const pads = attributes.getInts('pads', [0, 0, 0, 0]); - const strides = attributes.getInts('strides', [1, 1]); - - return createAttributeWithCacheKey( - {autoPad, dilations, group, kernelShape, outputPadding, outputShape, pads, strides, ...activationAttributes}); - }; +export const parseConvTransposeAttributes: OperatorInitialization = ( + node: Graph.Node, +): ConvTransposeAttributes => { + const attributes = node.attributes; + const activationAttributes = parseInternalActivationAttributes(attributes); + // TODO : Make this generic enough to compute default attributes for multi-dimensional conv + const autoPad = attributes.getString('auto_pad', 'NOTSET'); + const dilations = attributes.getInts('dilations', [1, 1]); + const group = attributes.getInt('group', 1); + const kernelShape = attributes.getInts('kernel_shape', []); + const outputPadding = attributes.getInts('output_padding', [0, 0]); + const outputShape = attributes.getInts('output_shape', []); + const pads = attributes.getInts('pads', [0, 0, 0, 0]); + const strides = attributes.getInts('strides', [1, 1]); + + return createAttributeWithCacheKey({ + autoPad, + dilations, + group, + kernelShape, + outputPadding, + outputShape, + pads, + strides, + ...activationAttributes, + }); +}; const validateInputs = (inputs: Tensor[], attributes: ConvTransposeAttributes): void => { // Refer to the below link for all input checks diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts index ea623f5c4dbbc..3cba1439049a4 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts @@ -1,37 +1,41 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {InferenceHandler} from '../../../backend'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {PoolConvUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; - -import {createUnpackedGroupedConvProgramInfoLoader} from './conv-grouped'; -import {conv2DPacked} from './conv-pack'; -import {createDotProductProgramInfoLoader} from './dot-product'; -import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; -import {createIm2ColProgramInfoLoader} from './im2col'; -import {createMatmulProgramInfoLoader} from './matmul'; - - -export const calculateOutputShape = - (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], - adjustPads: readonly number[], strides: readonly number[]): number[] => { - const batchSize = inputShape[0]; - const inputSpatialShape = inputShape.slice(2); - const spatialRank = inputSpatialShape.length; - const outChannels = kernelShape[0]; - const kernelSpatialShape = kernelShape.slice(2); - const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); - const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); - const outputSpatialShape = - inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i])); - const outputShape = [batchSize, outChannels].concat(...outputSpatialShape); - return outputShape; - }; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { InferenceHandler } from '../../../backend'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { PoolConvUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; + +import { createUnpackedGroupedConvProgramInfoLoader } from './conv-grouped'; +import { conv2DPacked } from './conv-pack'; +import { createDotProductProgramInfoLoader } from './dot-product'; +import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils'; +import { createIm2ColProgramInfoLoader } from './im2col'; +import { createMatmulProgramInfoLoader } from './matmul'; + +export const calculateOutputShape = ( + inputShape: readonly number[], + kernelShape: readonly number[], + dilations: readonly number[], + adjustPads: readonly number[], + strides: readonly number[], +): number[] => { + const batchSize = inputShape[0]; + const inputSpatialShape = inputShape.slice(2); + const spatialRank = inputSpatialShape.length; + const outChannels = kernelShape[0]; + const kernelSpatialShape = kernelShape.slice(2); + const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); + const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); + const outputSpatialShape = inputSpatialShapeWithPad.map((v, i) => + Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]), + ); + const outputShape = [batchSize, outChannels].concat(...outputSpatialShape); + return outputShape; +}; export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey { 
readonly autoPad: string; @@ -42,58 +46,74 @@ export interface ConvAttributes extends InternalActivationAttributes, AttributeW readonly strides: readonly number[]; } -export const conv: OperatorImplementation = - (inferenceHandler: InferenceHandler, inputs: Tensor[], attributes: ConvAttributes): Tensor[] => { - validateInputs(inputs, attributes); // currently will fail if not conv2D - return conv2d(inferenceHandler, inputs, attributes); - }; - -const conv2d: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ConvAttributes): Tensor[] => { - const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs); - const packMode = inferenceHandler.session.pack; - const isPointwise = adjustedAttributes.kernelShape[0] === 1 && adjustedAttributes.kernelShape[1] === 1; - if (adjustedAttributes.group > 1) { - const result = inferenceHandler.run( - createUnpackedGroupedConvProgramInfoLoader(inferenceHandler, inputs, adjustedAttributes), inputs); - return [result]; - } else if (isPointwise && packMode) { - return [conv2DUnpackedPointwise(inferenceHandler, inputs, adjustedAttributes)]; - } else if (packMode && inputs[0].dims.length === 4 && inputs[0].dims[0] === 1 && !isPointwise) { - return [conv2DPacked(inferenceHandler, inputs, adjustedAttributes)]; - } else { - return [conv2DUnpacked(inferenceHandler, inputs, adjustedAttributes)]; - } - }; - -const conv2DUnpackedPointwise = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes): Tensor => { - const xshape = inputs[0].dims; - const kshape = inputs[1].dims; - const outputShape = - calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); - const reshapedX = inferenceHandler.reshapeUnpacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]); - const reshapedK = inferenceHandler.reshapeUnpacked(inputs[1], [kshape[0], kshape[1]]); - - const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX]; - const matmulOutput = inferenceHandler.run(createMatmulProgramInfoLoader(matmulInputs, attributes), matmulInputs); - return inferenceHandler.reshapeUnpacked(matmulOutput, outputShape); - }; - -const conv2DUnpacked = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes): Tensor => { - const xshape = inputs[0].dims; - const kshape = inputs[1].dims; - const outputShape = - calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); - const xIm2Col = inferenceHandler.run( - createIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes), [inputs[0]]); - - const dotProductInputs = inputs.length === 3 ? 
[xIm2Col, inputs[1], inputs[2]] : [xIm2Col, inputs[1]]; - const output = inferenceHandler.run( - createDotProductProgramInfoLoader(inferenceHandler, inputs, outputShape, attributes), dotProductInputs); - return output; - }; +export const conv: OperatorImplementation = ( + inferenceHandler: InferenceHandler, + inputs: Tensor[], + attributes: ConvAttributes, +): Tensor[] => { + validateInputs(inputs, attributes); // currently will fail if not conv2D + return conv2d(inferenceHandler, inputs, attributes); +}; + +const conv2d: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ConvAttributes, +): Tensor[] => { + const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs); + const packMode = inferenceHandler.session.pack; + const isPointwise = adjustedAttributes.kernelShape[0] === 1 && adjustedAttributes.kernelShape[1] === 1; + if (adjustedAttributes.group > 1) { + const result = inferenceHandler.run( + createUnpackedGroupedConvProgramInfoLoader(inferenceHandler, inputs, adjustedAttributes), + inputs, + ); + return [result]; + } else if (isPointwise && packMode) { + return [conv2DUnpackedPointwise(inferenceHandler, inputs, adjustedAttributes)]; + } else if (packMode && inputs[0].dims.length === 4 && inputs[0].dims[0] === 1 && !isPointwise) { + return [conv2DPacked(inferenceHandler, inputs, adjustedAttributes)]; + } else { + return [conv2DUnpacked(inferenceHandler, inputs, adjustedAttributes)]; + } +}; + +const conv2DUnpackedPointwise = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + attributes: ConvAttributes, +): Tensor => { + const xshape = inputs[0].dims; + const kshape = inputs[1].dims; + const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); + const reshapedX = inferenceHandler.reshapeUnpacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]); + const reshapedK = inferenceHandler.reshapeUnpacked(inputs[1], [kshape[0], kshape[1]]); + + const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX]; + const matmulOutput = inferenceHandler.run(createMatmulProgramInfoLoader(matmulInputs, attributes), matmulInputs); + return inferenceHandler.reshapeUnpacked(matmulOutput, outputShape); +}; + +const conv2DUnpacked = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + attributes: ConvAttributes, +): Tensor => { + const xshape = inputs[0].dims; + const kshape = inputs[1].dims; + const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides); + const xIm2Col = inferenceHandler.run( + createIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes), + [inputs[0]], + ); + + const dotProductInputs = inputs.length === 3 ? 
[xIm2Col, inputs[1], inputs[2]] : [xIm2Col, inputs[1]]; + const output = inferenceHandler.run( + createDotProductProgramInfoLoader(inferenceHandler, inputs, outputShape, attributes), + dotProductInputs, + ); + return output; +}; const getAdjustedConvAttributes = (attributes: T, inputs: Tensor[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -105,11 +125,17 @@ const getAdjustedConvAttributes = (attributes: T, inpu } const pads = attributes.pads.slice(); PoolConvUtil.adjustPadsBasedOnAutoPad( - inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.autoPad); + inputs[0].dims, + attributes.strides, + attributes.dilations, + kernelShape, + pads, + attributes.autoPad, + ); // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, { kernelShape, pads, cacheKey: attributes.cacheKey }); return newAttributes; }; @@ -124,7 +150,15 @@ export const parseConvAttributes: OperatorInitialization = (node const pads = attributes.getInts('pads', [0, 0, 0, 0]); const strides = attributes.getInts('strides', [1, 1]); - return createAttributeWithCacheKey({autoPad, dilations, group, kernelShape, pads, strides, ...activationAttributes}); + return createAttributeWithCacheKey({ + autoPad, + dilations, + group, + kernelShape, + pads, + strides, + ...activationAttributes, + }); }; const validateInputs = (inputs: Tensor[], attributes: ConvAttributes): void => { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/depth-to-space.ts b/js/web/lib/onnxjs/backends/webgl/ops/depth-to-space.ts index 3073fef3f2c60..4d0a3532514bc 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/depth-to-space.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/depth-to-space.ts @@ -1,68 +1,83 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; -import {transpose, TransposeAttributes} from './transpose'; +import { transpose, TransposeAttributes } from './transpose'; export interface DepthToSpaceAttributes { - mode: 'DCR'|'CRD'; + mode: 'DCR' | 'CRD'; blocksize: number; } -export const depthToSpace: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: DepthToSpaceAttributes): Tensor[] => { - validateInputs(inputs); - const blocksize = attributes.blocksize; - const blocksizeSqr = blocksize * blocksize; - const transposePerm = attributes.mode === 'DCR' ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3]; - const firstReshapeShape = attributes.mode === 'DCR' ? 
- [ - inputs[0].dims[0], blocksize, blocksize, inputs[0].dims[1] / blocksizeSqr, inputs[0].dims[2], - inputs[0].dims[3] - ] : - [ - inputs[0].dims[0], inputs[0].dims[1] / blocksizeSqr, blocksize, blocksize, inputs[0].dims[2], - inputs[0].dims[3] - ]; +export const depthToSpace: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: DepthToSpaceAttributes, +): Tensor[] => { + validateInputs(inputs); + const blocksize = attributes.blocksize; + const blocksizeSqr = blocksize * blocksize; + const transposePerm = attributes.mode === 'DCR' ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3]; + const firstReshapeShape = + attributes.mode === 'DCR' + ? [ + inputs[0].dims[0], + blocksize, + blocksize, + inputs[0].dims[1] / blocksizeSqr, + inputs[0].dims[2], + inputs[0].dims[3], + ] + : [ + inputs[0].dims[0], + inputs[0].dims[1] / blocksizeSqr, + blocksize, + blocksize, + inputs[0].dims[2], + inputs[0].dims[3], + ]; - // const transpose = new WebGLTranspose(); - // const attributes = new Attribute(undefined); - // attributes.set('perm', 'ints', transposePerm); - // transpose.initialize(attributes); + // const transpose = new WebGLTranspose(); + // const attributes = new Attribute(undefined); + // attributes.set('perm', 'ints', transposePerm); + // transpose.initialize(attributes); - // First reshape - const firstReshapedTensor = inferenceHandler.reshapeUnpacked(inputs[0], firstReshapeShape); + // First reshape + const firstReshapedTensor = inferenceHandler.reshapeUnpacked(inputs[0], firstReshapeShape); - // transpose - const transposeAttributes: TransposeAttributes = {perm: transposePerm, cacheKey: `${transposePerm}`}; - const [transposeOutput] = transpose(inferenceHandler, [firstReshapedTensor], transposeAttributes); + // transpose + const transposeAttributes: TransposeAttributes = { perm: transposePerm, cacheKey: `${transposePerm}` }; + const [transposeOutput] = transpose(inferenceHandler, [firstReshapedTensor], transposeAttributes); - // Second reshape - const secondReshapeShape = [ - inputs[0].dims[0], inputs[0].dims[1] / blocksizeSqr, inputs[0].dims[2] * blocksize, - inputs[0].dims[3] * blocksize - ]; - const result = inferenceHandler.reshapeUnpacked(transposeOutput, secondReshapeShape); - return [result]; - }; + // Second reshape + const secondReshapeShape = [ + inputs[0].dims[0], + inputs[0].dims[1] / blocksizeSqr, + inputs[0].dims[2] * blocksize, + inputs[0].dims[3] * blocksize, + ]; + const result = inferenceHandler.reshapeUnpacked(transposeOutput, secondReshapeShape); + return [result]; +}; -export const parseDepthToSpaceAttributes: OperatorInitialization = - (node: Graph.Node): DepthToSpaceAttributes => { - // processing node attributes - const blocksize = node.attributes.getInt('blocksize'); - if (blocksize < 1) { - throw new Error(`blocksize must be >= 1, but got : ${blocksize} for DepthToSpace`); - } - const mode = node.attributes.getString('mode', 'DCR'); - if (mode !== 'DCR' && mode !== 'CRD') { - throw new Error(`unrecognized mode: ${mode} for DepthToSpace`); - } - return {mode, blocksize}; - }; +export const parseDepthToSpaceAttributes: OperatorInitialization = ( + node: Graph.Node, +): DepthToSpaceAttributes => { + // processing node attributes + const blocksize = node.attributes.getInt('blocksize'); + if (blocksize < 1) { + throw new Error(`blocksize must be >= 1, but got : ${blocksize} for DepthToSpace`); + } + const mode = node.attributes.getString('mode', 'DCR'); + if (mode !== 'DCR' && mode !== 'CRD') { + throw new Error(`unrecognized 
mode: ${mode} for DepthToSpace`); + } + return { mode, blocksize }; +}; const validateInputs = (inputs: Tensor[]): void => { if (inputs.length !== 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/dot-product.ts b/js/web/lib/onnxjs/backends/webgl/ops/dot-product.ts index 612c77c34a605..ddbb52fef7b38 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/dot-product.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/dot-product.ts @@ -1,43 +1,52 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; -import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; -import {calculateIm2ColDims} from './im2col'; +import { getActivationSnippet, InternalActivationAttributes } from './fuse-utils'; +import { calculateIm2ColDims } from './im2col'; const createDotProductProgramMetadata = (hasBias: boolean, attributes: InternalActivationAttributes) => ({ name: 'ConvDotProduct', inputNames: hasBias ? ['Im2Col', 'K', 'B'] : ['Im2Col', 'K'], - inputTypes: hasBias ? [TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked] : - [TextureType.unpacked, TextureType.packedLastDimension], - cacheKey: attributes.activationCacheKey + inputTypes: hasBias + ? 
[TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked] + : [TextureType.unpacked, TextureType.packedLastDimension], + cacheKey: attributes.activationCacheKey, }); -const createDotProductProgramInfo = - (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: readonly Tensor[], - outputShape: number[], attributes: InternalActivationAttributes): ProgramInfo => { - const xshape = inputs[0].dims; - const kshape = inputs[1].dims; - const adjustedKernelShape = [kshape[0], Math.ceil((xshape[1] * kshape[2] * kshape[3]) / 4)]; - const im2colShape = calculateIm2ColDims(xshape, kshape, outputShape); - const [kWidth, kHeight] = - inferenceHandler.calculateTextureWidthAndHeight(adjustedKernelShape, TextureType.packedLastDimension); +const createDotProductProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + metadata: ProgramMetadata, + inputs: readonly Tensor[], + outputShape: number[], + attributes: InternalActivationAttributes, +): ProgramInfo => { + const xshape = inputs[0].dims; + const kshape = inputs[1].dims; + const adjustedKernelShape = [kshape[0], Math.ceil((xshape[1] * kshape[2] * kshape[3]) / 4)]; + const im2colShape = calculateIm2ColDims(xshape, kshape, outputShape); + const [kWidth, kHeight] = inferenceHandler.calculateTextureWidthAndHeight( + adjustedKernelShape, + TextureType.packedLastDimension, + ); - const im2colStrides = ShapeUtil.computeStrides(im2colShape); - const [im2colWidth, im2colHeight] = - inferenceHandler.calculateTextureWidthAndHeight(im2colShape, TextureType.packedLastDimension); - const rank = outputShape.length; + const im2colStrides = ShapeUtil.computeStrides(im2colShape); + const [im2colWidth, im2colHeight] = inferenceHandler.calculateTextureWidthAndHeight( + im2colShape, + TextureType.packedLastDimension, + ); + const rank = outputShape.length; - const initValue = (inputs.length < 3) ? '0.0' : '_B(b)'; - const sharedDim = Math.ceil(xshape[1] * kshape[2] * kshape[3] / 4); - const {activationFunction, applyActivation} = getActivationSnippet(attributes); - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const shaderSource = ` + const initValue = inputs.length < 3 ? 
'0.0' : '_B(b)'; + const sharedDim = Math.ceil((xshape[1] * kshape[2] * kshape[3]) / 4); + const { activationFunction, applyActivation } = getActivationSnippet(attributes); + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const shaderSource = ` ${activationFunction} float process(int indices[${rank}]) { int b[1]; @@ -47,7 +56,8 @@ float process(int indices[${rank}]) { im2col[1] = indices[2]; im2col[2] = indices[3]; int im2colOffset = im2col[0] * ${im2colStrides[0]} + im2col[1] * ${im2colStrides[1]} + im2col[2] * ${ - im2colStrides[2]}; + im2colStrides[2] + }; int kernelOffset = indices[1] * ${adjustedKernelShape[1]}; float value = ${initValue}; for (int i = 0; i < ${sharedDim}; ++i) { @@ -60,19 +70,22 @@ float process(int indices[${rank}]) { ${applyActivation} return value; }`; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; -export const createDotProductProgramInfoLoader = - (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], outputShape: number[], - attributes: InternalActivationAttributes): ProgramInfoLoader => { - const metadata = createDotProductProgramMetadata(inputs.length > 2, attributes); - return { - ...metadata, - get: () => createDotProductProgramInfo(inferenceHandler, metadata, inputs, outputShape, attributes) - }; - }; +export const createDotProductProgramInfoLoader = ( + inferenceHandler: WebGLInferenceHandler, + inputs: readonly Tensor[], + outputShape: number[], + attributes: InternalActivationAttributes, +): ProgramInfoLoader => { + const metadata = createDotProductProgramMetadata(inputs.length > 2, attributes); + return { + ...metadata, + get: () => createDotProductProgramInfo(inferenceHandler, metadata, inputs, outputShape, attributes), + }; +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/flatten.ts b/js/web/lib/onnxjs/backends/webgl/ops/flatten.ts index ffce3bdaea5e5..b88bb43a337fa 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/flatten.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/flatten.ts @@ -1,22 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Graph} from '../../../graph';
-import {OperatorImplementation, OperatorInitialization} from '../../../operators';
-import {Tensor} from '../../../tensor';
-import {ShapeUtil} from '../../../util';
-import {WebGLInferenceHandler} from '../inference-handler';
-
-export const flatten: OperatorImplementation<number> =
-    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number): Tensor[] => {
-      validateInputs(inputs, axis);
-
-      const outputDims = ShapeUtil.flattenShape(inputs[0].dims, axis);
-      return [inferenceHandler.reshapeUnpacked(inputs[0], outputDims)];
-    };
+import { Graph } from '../../../graph';
+import { OperatorImplementation, OperatorInitialization } from '../../../operators';
+import { Tensor } from '../../../tensor';
+import { ShapeUtil } from '../../../util';
+import { WebGLInferenceHandler } from '../inference-handler';
+
+export const flatten: OperatorImplementation<number> = (
+  inferenceHandler: WebGLInferenceHandler,
+  inputs: Tensor[],
+  axis: number,
+): Tensor[] => {
+  validateInputs(inputs, axis);
+
+  const outputDims = ShapeUtil.flattenShape(inputs[0].dims, axis);
+  return [inferenceHandler.reshapeUnpacked(inputs[0], outputDims)];
+};
 
 export const parseFlattenAttributes: OperatorInitialization<number> = (node: Graph.Node): number =>
-    node.attributes.getInt('axis', 1);  // default axis is 1
+  node.attributes.getInt('axis', 1); // default axis is 1
 
 const validateInputs = (inputs: Tensor[], axis: number): void => {
   if (!inputs || inputs.length !== 1) {
@@ -36,4 +39,4 @@ const validateInputs = (inputs: Tensor[], axis: number): void => {
   if (inputs[0].type === 'string') {
     throw new Error('string tensor is not supported.');
   }
-};
\ No newline at end of file
+};
diff --git a/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts b/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts
index 9497bb9f6967f..605362fda7122 100644
--- a/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts
+++ b/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts
@@ -1,11 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {Attribute} from '../../../attribute';
-import {MAX_CLIP, MIN_CLIP} from '../../../util';
-import {GlslValueFunction} from '../glsl-definitions';
+import { Attribute } from '../../../attribute';
+import { MAX_CLIP, MIN_CLIP } from '../../../util';
+import { GlslValueFunction } from '../glsl-definitions';
 
-import {glslClip, glslRelu, glslSigmoid} from './unary-op';
+import { glslClip, glslRelu, glslSigmoid } from './unary-op';
 
 export interface InternalActivationAttributes {
   readonly activation: string;
@@ -28,13 +28,13 @@ export function getActivationSnippet(attributes: InternalActivationAttributes) {
       break;
       // TODO: adding other activations that can be fused.
     default:
-      return {activationFunction: '', applyActivation: ''};
+      return { activationFunction: '', applyActivation: '' };
   }
 
   const activationName = func.name;
   const activationFunction = func.body;
   const applyActivation = `value = ${activationName}_(value);`;
-  return {activationFunction, applyActivation};
+  return { activationFunction, applyActivation };
 }
 
 export const parseInternalActivationAttributes = (attributes: Attribute): InternalActivationAttributes => {
@@ -42,7 +42,7 @@ export const parseInternalActivationAttributes = (attributes: Attribute): Intern
 
   if (activation === 'Clip') {
     const [clipMin, clipMax] = attributes.getFloats('activation_params', [MIN_CLIP, MAX_CLIP]);
-    return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`};
+    return { activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}` };
   }
-  return {activation, activationCacheKey: activation};
+  return { activation, activationCacheKey: activation };
 };
diff --git a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts
index bb44a20d75f34..09d91992cc13e 100644
--- a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts
+++ b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts
@@ -1,27 +1,30 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
-import {Graph} from '../../../graph';
-import {NUMBER_TYPES, OperatorImplementation, OperatorInitialization} from '../../../operators';
-import {Tensor} from '../../../tensor';
-import {ShapeUtil} from '../../../util';
-import {WebGLInferenceHandler} from '../inference-handler';
-import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';
+import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
+import { Graph } from '../../../graph';
+import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
+import { Tensor } from '../../../tensor';
+import { ShapeUtil } from '../../../util';
+import { WebGLInferenceHandler } from '../inference-handler';
+import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
 
 interface GatherAttributes extends AttributeWithCacheKey {
   readonly axis: number;
 }
 
-export const gather: OperatorImplementation<GatherAttributes> =
-    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: GatherAttributes): Tensor[] => {
-      validateInputs(inputs, attributes.axis);
-      const output = inferenceHandler.run(createGatherProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
-      return [output];
-    };
+export const gather: OperatorImplementation<GatherAttributes> = (
+  inferenceHandler: WebGLInferenceHandler,
+  inputs: Tensor[],
+  attributes: GatherAttributes,
+): Tensor[] => {
+  validateInputs(inputs, attributes.axis);
+  const output = inferenceHandler.run(createGatherProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
+  return [output];
+};
 
 export const parseGatherAttributes: OperatorInitialization<GatherAttributes> = (node: Graph.Node): GatherAttributes =>
-    createAttributeWithCacheKey({axis: node.attributes.getInt('axis', 0)});
+  createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', 0) });
 
 const gatherProgramMetadata = {
   name: 'Gather',
@@ -29,38 +32,45 @@ const gatherProgramMetadata = {
   inputTypes: [TextureType.unpacked, TextureType.unpacked],
 };
 
-const createGatherProgramInfo =
-
(_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - const indexDataShape = inputs[1].dims.slice(); - const outputShape = new Array(inputShape.length + indexDataShape.length - 1); +const createGatherProgramInfo = ( + _handler: WebGLInferenceHandler, + metadata: ProgramMetadata, + inputs: Tensor[], + axis: number, +): ProgramInfo => { + const inputShape = inputs[0].dims.slice(); + const indexDataShape = inputs[1].dims.slice(); + const outputShape = new Array(inputShape.length + indexDataShape.length - 1); - axis = ShapeUtil.normalizeAxis(axis, inputShape.length); - const indexCopyOps: string[] = []; - for (let i = 0; i < outputShape.length; i++) { - // outputShape is divided into three parts: A, B, C - // |0 axis| axis + indexDataShape.length | end| - // | A | B | C | - // - // inputIdx: [A, inputs[1][B], C] - if (i < axis) { // A - outputShape[i] = inputShape[i]; - indexCopyOps.push(`inputIdx[${i}] = outputIdx[${i}];`); - } else { - if (i < axis + indexDataShape.length) { // B - outputShape[i] = indexDataShape[i - axis]; - indexCopyOps.push(`indexDataIdx[${i - axis}] = outputIdx[${i}];`); - } else { // C - outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for axis - indexCopyOps.push(`inputIdx[${i - indexDataShape.length + 1}] = outputIdx[${i}];`); - } - } + axis = ShapeUtil.normalizeAxis(axis, inputShape.length); + const indexCopyOps: string[] = []; + for (let i = 0; i < outputShape.length; i++) { + // outputShape is divided into three parts: A, B, C + // |0 axis| axis + indexDataShape.length | end| + // | A | B | C | + // + // inputIdx: [A, inputs[1][B], C] + if (i < axis) { + // A + outputShape[i] = inputShape[i]; + indexCopyOps.push(`inputIdx[${i}] = outputIdx[${i}];`); + } else { + if (i < axis + indexDataShape.length) { + // B + outputShape[i] = indexDataShape[i - axis]; + indexCopyOps.push(`indexDataIdx[${i - axis}] = outputIdx[${i}];`); + } else { + // C + outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for axis + indexCopyOps.push(`inputIdx[${i - indexDataShape.length + 1}] = outputIdx[${i}];`); } + } + } - const orank = outputShape.length || 1; - const irank = inputShape.length; - const iDrank = indexDataShape.length || 1; - const shaderSource = ` + const orank = outputShape.length || 1; + const irank = inputShape.length; + const iDrank = indexDataShape.length || 1; + const shaderSource = ` float process(int outputIdx[${orank}]) { int inputIdx[${irank}]; int indexDataIdx[${iDrank}]; @@ -70,18 +80,21 @@ const createGatherProgramInfo = inputIdx[${axis}] = idx < 0 ? 
idx + ${inputShape[axis]} : idx; return _A(inputIdx); }`; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; -const createGatherProgramInfoLoader = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: GatherAttributes): ProgramInfoLoader => { - const metadata = {...gatherProgramMetadata, cacheHint: attributes.cacheKey}; - return {...metadata, get: () => createGatherProgramInfo(handler, metadata, inputs, attributes.axis)}; - }; +const createGatherProgramInfoLoader = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: GatherAttributes, +): ProgramInfoLoader => { + const metadata = { ...gatherProgramMetadata, cacheHint: attributes.cacheKey }; + return { ...metadata, get: () => createGatherProgramInfo(handler, metadata, inputs, attributes.axis) }; +}; const validateInputs = (inputs: Tensor[], axis: number): void => { if (!inputs || inputs.length !== 2) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/gemm.ts b/js/web/lib/onnxjs/backends/webgl/ops/gemm.ts index 3f5c56b51bdc0..01f23863ecec5 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/gemm.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/gemm.ts @@ -1,84 +1,97 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {GemmUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { GemmUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; export interface GemmAttributes extends AttributeWithCacheKey { transA: boolean; transB: boolean; alpha: number; beta: number; - isOptionalC: boolean; // in opset 11, C becomes optional + isOptionalC: boolean; // in opset 11, C becomes optional } -export const gemm: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: GemmAttributes): Tensor[] => { - validateInputs(inputs, attributes); - const output = inferenceHandler.run(createGemmProgramInfoLoader(inputs, attributes), inputs); - return [output]; - }; +export const gemm: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: GemmAttributes, +): Tensor[] => { + validateInputs(inputs, attributes); + const output = inferenceHandler.run(createGemmProgramInfoLoader(inputs, attributes), inputs); + return [output]; +}; const parseGemmAttributes = (node: Graph.Node, isOptionalC: boolean): GemmAttributes => { const transA = node.attributes.getInt('transA', 0) !== 0; const transB = node.attributes.getInt('transB', 0) !== 0; const alpha = node.attributes.getFloat('alpha', 1.0); const beta 
= node.attributes.getFloat('beta', 1.0); - return createAttributeWithCacheKey({transA, transB, alpha, beta, isOptionalC}); + return createAttributeWithCacheKey({ transA, transB, alpha, beta, isOptionalC }); }; export const parseGemmAttributesV7: OperatorInitialization = (node: Graph.Node): GemmAttributes => - parseGemmAttributes(node, false); + parseGemmAttributes(node, false); export const parseGemmAttributesV11: OperatorInitialization = (node: Graph.Node): GemmAttributes => - parseGemmAttributes(node, true); + parseGemmAttributes(node, true); const createGemmProgramInfoLoader = (inputs: Tensor[], attributes: GemmAttributes): ProgramInfoLoader => { const metadata = { name: 'Gemm', inputNames: inputs.length === 3 ? ['A', 'B', 'C'] : ['A', 'B'], - inputTypes: inputs.length === 3 ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] : - [TextureType.unpacked, TextureType.unpacked], - key: attributes.cacheKey + inputTypes: + inputs.length === 3 + ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] + : [TextureType.unpacked, TextureType.unpacked], + key: attributes.cacheKey, }; - return {...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes)}; + return { ...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes) }; }; -const createGemmProgramInfo = - (metadata: ProgramMetadata, inputs: Tensor[], attributes: GemmAttributes): ProgramInfo => { - const aShape = inputs[0].dims.slice(); - const bShape = inputs[1].dims.slice(); - const [M, N] = GemmUtil.getShapeOfGemmResult( - aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined); - const outputShape = [M, N]; - if (!outputShape) { - throw new Error('Can\'t use gemm on the given tensors'); - } - let sharedDim = aShape[aShape.length - 1]; - let line = ''; - if (attributes.transA) { - sharedDim = aShape[0]; - } - if (attributes.transA && attributes.transB) { - line = 'value += _A_T(a) * _B_T(b);'; - } else if (attributes.transA && !attributes.transB) { - line = 'value += _A_T(a) * _B(b);'; - } else if (!attributes.transA && attributes.transB) { - line = 'value += _A(a) * _B_T(b);'; - } else if (!attributes.transA && !attributes.transB) { - line = 'value += _A(a) * _B(b);'; - } - const rank = outputShape.length; - const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : ''; - const broadcastC = inputs.length === 3 ? 'bcastIndices_C(indices, c);' : ''; - const calculateC = inputs.length === 3 ? 'value += beta * _C(c);' : ''; - const shaderSource = ` +const createGemmProgramInfo = ( + metadata: ProgramMetadata, + inputs: Tensor[], + attributes: GemmAttributes, +): ProgramInfo => { + const aShape = inputs[0].dims.slice(); + const bShape = inputs[1].dims.slice(); + const [M, N] = GemmUtil.getShapeOfGemmResult( + aShape, + attributes.transA, + bShape, + attributes.transB, + inputs.length === 3 ? 
inputs[2].dims : undefined, + ); + const outputShape = [M, N]; + if (!outputShape) { + throw new Error("Can't use gemm on the given tensors"); + } + let sharedDim = aShape[aShape.length - 1]; + let line = ''; + if (attributes.transA) { + sharedDim = aShape[0]; + } + if (attributes.transA && attributes.transB) { + line = 'value += _A_T(a) * _B_T(b);'; + } else if (attributes.transA && !attributes.transB) { + line = 'value += _A_T(a) * _B(b);'; + } else if (!attributes.transA && attributes.transB) { + line = 'value += _A(a) * _B_T(b);'; + } else if (!attributes.transA && !attributes.transB) { + line = 'value += _A(a) * _B(b);'; + } + const rank = outputShape.length; + const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : ''; + const broadcastC = inputs.length === 3 ? 'bcastIndices_C(indices, c);' : ''; + const calculateC = inputs.length === 3 ? 'value += beta * _C(c);' : ''; + const shaderSource = ` float process(int indices[${rank}]) { int a[${rank}]; int b[${rank}]; @@ -99,15 +112,16 @@ const createGemmProgramInfo = ${calculateC} return value; }`; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - variables: [ - {name: 'alpha', type: 'float', data: attributes.alpha}, {name: 'beta', type: 'float', data: attributes.beta} - ], - shaderSource - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + variables: [ + { name: 'alpha', type: 'float', data: attributes.alpha }, + { name: 'beta', type: 'float', data: attributes.beta }, + ], + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[], attributes: GemmAttributes): void => { if (!inputs) { @@ -125,13 +139,15 @@ const validateInputs = (inputs: Tensor[], attributes: GemmAttributes): void => { throw new Error('Invalid input shape of C'); } - if ((inputs[0].type !== 'float32' && inputs[0].type !== 'float64') || - (inputs[1].type !== 'float32' && inputs[1].type !== 'float64') || - (inputs.length === 3 && inputs[2].type !== 'float32' && inputs[2].type !== 'float64')) { + if ( + (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') || + (inputs[1].type !== 'float32' && inputs[1].type !== 'float64') || + (inputs.length === 3 && inputs[2].type !== 'float32' && inputs[2].type !== 'float64') + ) { throw new Error('Invalid input type.'); } - if ((inputs[0].type !== inputs[1].type) || (inputs.length === 3 && inputs[0].type !== inputs[2].type)) { + if (inputs[0].type !== inputs[1].type || (inputs.length === 3 && inputs[0].type !== inputs[2].type)) { throw new Error('Input types are mismatched'); } }; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts index f1dd968b40891..90495dfa3ee46 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/im2col-pack.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; -import {ConvAttributes} from './conv'; -import {unpackFromChannel} from './packing-utils'; +import { ConvAttributes } from './conv'; +import { unpackFromChannel } from './packing-utils'; const createPackedIm2ColProgramMetadata = (cacheHint: string) => ({ name: 'Im2Col (packed)', @@ -16,23 +16,28 @@ const createPackedIm2ColProgramMetadata = (cacheHint: string) => ({ cacheHint, }); -const createPackedIm2ColProgramInfo = - (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, - outputShape: readonly number[], attributes: ConvAttributes): ProgramInfo => { - const xshape = x.dims; - const wshape = w.dims; - const rowDim = 2; - const colDim = 3; - const rank = outputShape.length; - const im2colShape = [wshape[1] * wshape[2] * wshape[3], outputShape[2] * outputShape[3]]; - const kernelSize = wshape[2] * wshape[3]; - const unpackChannel = unpackFromChannel(); - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - let unrolled = ''; +const createPackedIm2ColProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + metadata: ProgramMetadata, + x: Tensor, + w: Tensor, + outputShape: readonly number[], + attributes: ConvAttributes, +): ProgramInfo => { + const xshape = x.dims; + const wshape = w.dims; + const rowDim = 2; + const colDim = 3; + const rank = outputShape.length; + const im2colShape = [wshape[1] * wshape[2] * wshape[3], outputShape[2] * outputShape[3]]; + const kernelSize = wshape[2] * wshape[3]; + const unpackChannel = unpackFromChannel(); + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + let unrolled = ''; - for (let row = 0; row <= 1; row++) { - for (let col = 0; col <= 1; col++) { - unrolled += ` + for (let row = 0; row <= 1; row++) { + for (let col = 0; col <= 1; col++) { + unrolled += ` blockIndex = rc.x + ${col}; pos = rc.y + ${row}; @@ -58,10 +63,10 @@ const createPackedIm2ColProgramInfo = } `; - } - } + } + } - const shaderSource = ` + const shaderSource = ` ${unpackChannel} void main() { @@ -73,20 +78,24 @@ const createPackedIm2ColProgramInfo = ${glsl.output} = result; } `; - return { - ...metadata, - output: {dims: im2colShape, type: x.type, textureType: TextureType.packed}, - shaderSource, - hasMain: true - }; - }; + return { + ...metadata, + output: { dims: im2colShape, type: x.type, textureType: TextureType.packed }, + shaderSource, + hasMain: true, + }; +}; -export const createPackedIm2ColProgramInfoLoader = - (inferenceHandler: WebGLInferenceHandler, x: Tensor, w: Tensor, outputShape: readonly number[], - attributes: ConvAttributes): ProgramInfoLoader => { - const metadata = createPackedIm2ColProgramMetadata(attributes.cacheKey); - return { - ...metadata, - get: () => createPackedIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes) - }; - }; +export const createPackedIm2ColProgramInfoLoader = ( + inferenceHandler: WebGLInferenceHandler, + x: Tensor, + w: Tensor, + outputShape: readonly number[], + attributes: ConvAttributes, +): ProgramInfoLoader => { + const metadata = 
createPackedIm2ColProgramMetadata(attributes.cacheKey); + return { + ...metadata, + get: () => createPackedIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes), + }; +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts index a1da13ec48d70..81854a44c8fbb 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; -import {ConvAttributes} from './conv'; +import { ConvAttributes } from './conv'; const createIm2ColProgramMetadata = (cacheHint: string) => ({ name: 'Im2Col', @@ -14,16 +14,21 @@ const createIm2ColProgramMetadata = (cacheHint: string) => ({ cacheHint, }); -const createIm2ColProgramInfo = - (_inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, - outputShape: readonly number[], attributes: ConvAttributes): ProgramInfo => { - const xshape = x.dims; - const wshape = w.dims; +const createIm2ColProgramInfo = ( + _inferenceHandler: WebGLInferenceHandler, + metadata: ProgramMetadata, + x: Tensor, + w: Tensor, + outputShape: readonly number[], + attributes: ConvAttributes, +): ProgramInfo => { + const xshape = x.dims; + const wshape = w.dims; - const rank = outputShape.length; - const im2colDims = calculateIm2ColDims(xshape, wshape, outputShape, 4); + const rank = outputShape.length; + const im2colDims = calculateIm2ColDims(xshape, wshape, outputShape, 4); - const shaderSource = ` + const shaderSource = ` const int XC = ${xshape[1]}; const int XH = ${xshape[2]}; const int XW = ${xshape[3]}; @@ -68,26 +73,35 @@ const createIm2ColProgramInfo = return value; } `; - return { - ...metadata, - output: {dims: im2colDims, type: x.type, textureType: TextureType.packedLastDimension}, - shaderSource - }; - }; + return { + ...metadata, + output: { dims: im2colDims, type: x.type, textureType: TextureType.packedLastDimension }, + shaderSource, + }; +}; -export const createIm2ColProgramInfoLoader = - (inferenceHandler: WebGLInferenceHandler, x: Tensor, w: Tensor, outputShape: readonly number[], - attributes: ConvAttributes): ProgramInfoLoader => { - const metadata = createIm2ColProgramMetadata(attributes.cacheKey); - return { - ...metadata, - get: () => createIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes) - }; - }; +export const createIm2ColProgramInfoLoader = ( + inferenceHandler: WebGLInferenceHandler, + x: Tensor, + w: Tensor, + outputShape: readonly number[], + attributes: ConvAttributes, +): ProgramInfoLoader => { + const metadata = createIm2ColProgramMetadata(attributes.cacheKey); + return { + ...metadata, + get: () => createIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes), + }; +}; - -export const calculateIm2ColDims = - (inputShape: readonly number[], kernelShape: readonly number[], outputShape: readonly number[], channels = 4): - number[] => - [outputShape[0], outputShape[2], outputShape[3], - Math.ceil(inputShape[1] * kernelShape[2] * kernelShape[3] / channels)]; +export 
const calculateIm2ColDims = ( + inputShape: readonly number[], + kernelShape: readonly number[], + outputShape: readonly number[], + channels = 4, +): number[] => [ + outputShape[0], + outputShape[2], + outputShape[3], + Math.ceil((inputShape[1] * kernelShape[2] * kernelShape[3]) / channels), +]; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts index efc79f686c960..c70a86c8cca03 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts @@ -1,32 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; export interface ImageScalerAttributes extends AttributeWithCacheKey { scale: number; bias: number[]; } -export const imageScaler: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ImageScalerAttributes): Tensor[] => { - validateInputs(inputs); - const output = - inferenceHandler.run(createImageScalerProgramInfoLoader(inferenceHandler, inputs, attributes), inputs); - return [output]; - }; +export const imageScaler: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ImageScalerAttributes, +): Tensor[] => { + validateInputs(inputs); + const output = inferenceHandler.run(createImageScalerProgramInfoLoader(inferenceHandler, inputs, attributes), inputs); + return [output]; +}; -export const parseImageScalerAttributes: OperatorInitialization = - (node: Graph.Node): ImageScalerAttributes => { - const scale = node.attributes.getFloat('scale'); - const bias = node.attributes.getFloats('bias'); - return createAttributeWithCacheKey({scale, bias}); - }; +export const parseImageScalerAttributes: OperatorInitialization = ( + node: Graph.Node, +): ImageScalerAttributes => { + const scale = node.attributes.getFloat('scale'); + const bias = node.attributes.getFloats('bias'); + return createAttributeWithCacheKey({ scale, bias }); +}; const imageScalerProgramMetadata = { name: 'ImageScaler', @@ -34,54 +37,52 @@ const imageScalerProgramMetadata = { inputTypes: [TextureType.unpacked], }; -const createImageScalerProgramInfo = - (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes): - ProgramInfo => { - const outputShape = inputs[0].dims.slice(); - const rank = outputShape.length; - const getBiasMethod = createGetBiasMethod(attributes.bias.length); - const shaderSource = ` +const createImageScalerProgramInfo = ( + _handler: WebGLInferenceHandler, + metadata: ProgramMetadata, + inputs: Tensor[], + attributes: ImageScalerAttributes, +): ProgramInfo 
=> { + const outputShape = inputs[0].dims.slice(); + const rank = outputShape.length; + const getBiasMethod = createGetBiasMethod(attributes.bias.length); + const shaderSource = ` ${getBiasMethod} float process(int indices[${rank}]) { return _X(indices) * scale + getBias(bias, indices[1]); }`; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - variables: [ - {name: 'bias', type: 'float', arrayLength: attributes.bias.length, data: attributes.bias}, - {name: 'scale', type: 'float', data: attributes.scale} - ], - shaderSource - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + variables: [ + { name: 'bias', type: 'float', arrayLength: attributes.bias.length, data: attributes.bias }, + { name: 'scale', type: 'float', data: attributes.scale }, + ], + shaderSource, + }; +}; -const createImageScalerProgramInfoLoader = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ImageScalerAttributes): ProgramInfoLoader => { - const metadata = {...imageScalerProgramMetadata, cacheHint: attributes.cacheKey}; - return {...metadata, get: () => createImageScalerProgramInfo(handler, metadata, inputs, attributes)}; - }; +const createImageScalerProgramInfoLoader = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ImageScalerAttributes, +): ProgramInfoLoader => { + const metadata = { ...imageScalerProgramMetadata, cacheHint: attributes.cacheKey }; + return { ...metadata, get: () => createImageScalerProgramInfo(handler, metadata, inputs, attributes) }; +}; const createGetBiasMethod = (numChannels: number): string => { const codeLines: string[] = [`float getBias(float bias[${numChannels}], int channel) {`]; for (let i = 0; i < numChannels; ++i) { if (i === 0) { - codeLines.push( - '\t' + - `if (channel == ${i}) { return bias[${i}]; }`); + codeLines.push('\t' + `if (channel == ${i}) { return bias[${i}]; }`); } else if (i === numChannels - 1) { - codeLines.push( - '\t' + - `else { return bias[${i}]; }`); + codeLines.push('\t' + `else { return bias[${i}]; }`); } else { - codeLines.push( - '\t' + - `else if (channel == ${i}) { return bias[${i}]; }`); + codeLines.push('\t' + `else if (channel == ${i}) { return bias[${i}]; }`); } } - codeLines.push( - '\t' + - '}'); + codeLines.push('\t' + '}'); return codeLines.join('\n'); }; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/instance-normalization.ts b/js/web/lib/onnxjs/backends/webgl/ops/instance-normalization.ts index 51a3ba835ca25..693b72211add9 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/instance-normalization.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/instance-normalization.ts @@ -1,26 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; - -export const instanceNormalization: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], epsilon: number): Tensor[] => { - validateInputs(inputs); - - const meanAndVariance = inferenceHandler.run(createMeanAndVarianceProgramInfoLoader(inputs[0]), inputs); - const output = inferenceHandler.run( - createComputeOutputProgramInfoLoader(inferenceHandler, inputs[0], epsilon, meanAndVariance.dims), - [inputs[0], meanAndVariance, inputs[1], inputs[2]]); - return [output]; - }; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; + +export const instanceNormalization: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + epsilon: number, +): Tensor[] => { + validateInputs(inputs); + + const meanAndVariance = inferenceHandler.run(createMeanAndVarianceProgramInfoLoader(inputs[0]), inputs); + const output = inferenceHandler.run( + createComputeOutputProgramInfoLoader(inferenceHandler, inputs[0], epsilon, meanAndVariance.dims), + [inputs[0], meanAndVariance, inputs[1], inputs[2]], + ); + return [output]; +}; export const parseInstanceNormalizationAttributes: OperatorInitialization = (node: Graph.Node): number => - node.attributes.getFloat('epsilon', 1e-5); + node.attributes.getFloat('epsilon', 1e-5); const meanAndVarianceProgramMetadata = { name: 'InstanceNormalization_MeanAndVariance', @@ -66,14 +70,14 @@ const createMeanAndVarianceProgramInfo = (metadata: ProgramMetadata, input: Tens }`; return { ...metadata, - output: {dims: outputShape, type: input.type, textureType: TextureType.packedLastDimension}, - shaderSource + output: { dims: outputShape, type: input.type, textureType: TextureType.packedLastDimension }, + shaderSource, }; }; const createMeanAndVarianceProgramInfoLoader = (input: Tensor): ProgramInfoLoader => ({ ...meanAndVarianceProgramMetadata, - get: () => createMeanAndVarianceProgramInfo(meanAndVarianceProgramMetadata, input) + get: () => createMeanAndVarianceProgramInfo(meanAndVarianceProgramMetadata, input), }); const computeOutputProgramMetadata = { @@ -82,14 +86,20 @@ const computeOutputProgramMetadata = { inputTypes: [TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked, TextureType.unpacked], }; -const createComputeOutputProgramInfo = - (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, input: Tensor, epsilon: number, - meanAndVarianceShape: readonly number[]): ProgramInfo => { - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const [textureWidth, textureHeight] = - inferenceHandler.calculateTextureWidthAndHeight(meanAndVarianceShape, TextureType.packedLastDimension); - const [meanAndVarianceWidth, meanAndVarianceHeight] = [textureWidth / 4, textureHeight]; - const shaderSource = ` +const createComputeOutputProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + metadata: 
ProgramMetadata, + input: Tensor, + epsilon: number, + meanAndVarianceShape: readonly number[], +): ProgramInfo => { + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight( + meanAndVarianceShape, + TextureType.packedLastDimension, + ); + const [meanAndVarianceWidth, meanAndVarianceHeight] = [textureWidth / 4, textureHeight]; + const shaderSource = ` vec4 get_MeanAndVariance(int[2] mv) { int offset = indicesToOffset_MeanAndVariance(mv); vec2 coords = offsetToCoords(offset, ${meanAndVarianceWidth}, ${meanAndVarianceHeight}); @@ -111,23 +121,26 @@ const createComputeOutputProgramInfo = return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b; }`; - return { - ...metadata, - output: {dims: input.dims, type: input.type, textureType: TextureType.unpacked}, - variables: [{name: 'epsilon', type: 'float', data: epsilon}], - shaderSource - }; - }; - -const createComputeOutputProgramInfoLoader = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, epsilon: number, meanAndVarianceShape: readonly number[]): - ProgramInfoLoader => { - const metadata = {...computeOutputProgramMetadata, cacheHint: `${epsilon}`}; - return { - ...metadata, - get: () => createComputeOutputProgramInfo(inferenceHandler, metadata, input, epsilon, meanAndVarianceShape) - }; - }; + return { + ...metadata, + output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked }, + variables: [{ name: 'epsilon', type: 'float', data: epsilon }], + shaderSource, + }; +}; + +const createComputeOutputProgramInfoLoader = ( + inferenceHandler: WebGLInferenceHandler, + input: Tensor, + epsilon: number, + meanAndVarianceShape: readonly number[], +): ProgramInfoLoader => { + const metadata = { ...computeOutputProgramMetadata, cacheHint: `${epsilon}` }; + return { + ...metadata, + get: () => createComputeOutputProgramInfo(inferenceHandler, metadata, input, epsilon, meanAndVarianceShape), + }; +}; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 3) { @@ -146,8 +159,11 @@ const validateInputs = (inputs: Tensor[]): void => { if (scale.dims[0] !== X.dims[1] || B.dims[0] !== X.dims[1]) { throw new Error('Input shapes are mismatched.'); } - if ((X.type !== 'float32' && X.type !== 'float64') || (scale.type !== 'float32' && scale.type !== 'float64') || - (B.type !== 'float32' && B.type !== 'float64')) { + if ( + (X.type !== 'float32' && X.type !== 'float64') || + (scale.type !== 'float32' && scale.type !== 'float64') || + (B.type !== 'float32' && B.type !== 'float64') + ) { throw new Error('Invalid input type.'); } if (inputs[0].dims.length !== 4) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts b/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts index 21dae1200e800..5942b698977ce 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types'; export interface LrnAttributes extends AttributeWithCacheKey { alpha: number; @@ -15,17 +15,20 @@ export interface LrnAttributes extends AttributeWithCacheKey { size: number; } -export const lrn: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: LrnAttributes): Tensor[] => { - validateInputs(inputs); +export const lrn: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: LrnAttributes, +): Tensor[] => { + validateInputs(inputs); - // if (inferenceHandler.session.pack) { - // return [inferenceHandler.run(createPackedLrnProgramInfoLoader(inferenceHandler, inputs, attributes), - // inputs)]; - // } else { - return [inferenceHandler.run(createLrnProgramInfoLoader(inputs, attributes), inputs)]; - //} - }; + // if (inferenceHandler.session.pack) { + // return [inferenceHandler.run(createPackedLrnProgramInfoLoader(inferenceHandler, inputs, attributes), + // inputs)]; + // } else { + return [inferenceHandler.run(createLrnProgramInfoLoader(inputs, attributes), inputs)]; + //} +}; export const parseLrnAttributes: OperatorInitialization = (node: Graph.Node): LrnAttributes => { const alpha = node.attributes.getFloat('alpha', 0.0001); @@ -33,13 +36,13 @@ export const parseLrnAttributes: OperatorInitialization = (node: const bias = node.attributes.getFloat('bias', 1.0); const size = node.attributes.getInt('size'); - return createAttributeWithCacheKey({alpha, beta, bias, size}); + return createAttributeWithCacheKey({ alpha, beta, bias, size }); }; const lrnProgramMetadata = { name: 'LRN', inputNames: ['X'], - inputTypes: [TextureType.unpacked] + inputTypes: [TextureType.unpacked], }; function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): ProgramInfo { @@ -70,13 +73,13 @@ function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): Prog return { ...lrnProgramMetadata, cacheHint: attributes.cacheKey, - output: {dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked}, + output: { dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked }, shaderSource, }; } export function createLrnProgramInfoLoader(inputs: Tensor[], attributes: LrnAttributes): ProgramInfoLoader { - return {...lrnProgramMetadata, cacheHint: attributes.cacheKey, get: () => createLrnProgramInfo(inputs, attributes)}; + return { ...lrnProgramMetadata, cacheHint: attributes.cacheKey, get: () => createLrnProgramInfo(inputs, attributes) }; } const validateInputs = (inputs: Tensor[]): void => { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index 0be6d1ba8bcd2..034b4fd6c2b04 100644 --- 
a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -1,61 +1,69 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {BroadcastUtil, ShapeUtil} from '../../../util'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; -import {getCoordsDataType, getGlChannels} from '../utils'; +import { Tensor } from '../../../tensor'; +import { BroadcastUtil, ShapeUtil } from '../../../util'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; +import { getCoordsDataType, getGlChannels } from '../utils'; -import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; -import {getBiasForMatmul} from './matmul'; +import { getActivationSnippet, InternalActivationAttributes } from './fuse-utils'; +import { getBiasForMatmul } from './matmul'; const createPackedMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul (packed)', inputNames: hasBias ? ['A', 'B', 'Bias'] : ['A', 'B'], - inputTypes: hasBias ? [TextureType.packed, TextureType.packed, TextureType.packed] : - [TextureType.packed, TextureType.packed], - cacheHint + inputTypes: hasBias + ? [TextureType.packed, TextureType.packed, TextureType.packed] + : [TextureType.packed, TextureType.packed], + cacheHint, }); -const createPackedMatmulProgramInfo = - (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], - activationAttributes: InternalActivationAttributes): ProgramInfo => { - const hasBias = inputs.length > 2; - const processBias = hasBias ? 'value += getBiasForMatmul();' : ''; - const aShape = inputs[0].dims; - const bShape = inputs[1].dims; - const outputShape = BroadcastUtil.calcShape(aShape, bShape, true); - const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims); - - if (!outputShape) { - throw new Error('Can\'t use matmul on the given tensors'); - } - const sharedDim = aShape[aShape.length - 1]; - const sharedDimIndex = Math.ceil(sharedDim / 2); - const aRank = aShape.length; - const bRank = bShape.length; - - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const coordsDataType = getCoordsDataType(outputShape.length); - const outRank = outputShape.length; - const allGlChannels = getGlChannels(); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes); - - const getBiasForMatmulSnippet = - hasBias ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, true)}` : ''; - - const getBcastedSamplerForMatmulSnippet = - isBroadcast ? `${getBcastSamplerForMatmul(coordsDataType, allGlChannels, inputs, outputShape)}` : ''; - - const getSamplerAInLoopSnippet = isBroadcast ? 'getAAtOutCoordsMatmul(i)' : `getA(${getA(allGlChannels, aRank)})`; - const getSamplerBInLoopSnippet = isBroadcast ? 'getBAtOutCoordsMatmul(i)' : `getB(${getB(allGlChannels, bRank)})`; - const getOutputCoordsSnippet = isBroadcast ? 
'' : `${coordsDataType} rc = +const createPackedMatmulProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + metadata: ProgramMetadata, + inputs: Tensor[], + activationAttributes: InternalActivationAttributes, +): ProgramInfo => { + const hasBias = inputs.length > 2; + const processBias = hasBias ? 'value += getBiasForMatmul();' : ''; + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + const outputShape = BroadcastUtil.calcShape(aShape, bShape, true); + const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims); + + if (!outputShape) { + throw new Error("Can't use matmul on the given tensors"); + } + const sharedDim = aShape[aShape.length - 1]; + const sharedDimIndex = Math.ceil(sharedDim / 2); + const aRank = aShape.length; + const bRank = bShape.length; + + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const coordsDataType = getCoordsDataType(outputShape.length); + const outRank = outputShape.length; + const allGlChannels = getGlChannels(); + const { activationFunction, applyActivation } = getActivationSnippet(activationAttributes); + + const getBiasForMatmulSnippet = hasBias + ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, true)}` + : ''; + + const getBcastedSamplerForMatmulSnippet = isBroadcast + ? `${getBcastSamplerForMatmul(coordsDataType, allGlChannels, inputs, outputShape)}` + : ''; + + const getSamplerAInLoopSnippet = isBroadcast ? 'getAAtOutCoordsMatmul(i)' : `getA(${getA(allGlChannels, aRank)})`; + const getSamplerBInLoopSnippet = isBroadcast ? 'getBAtOutCoordsMatmul(i)' : `getB(${getB(allGlChannels, bRank)})`; + const getOutputCoordsSnippet = isBroadcast + ? '' + : `${coordsDataType} rc = getOutputCoords(); int lastDim = rc.${allGlChannels[outRank - 1]}; rc.${allGlChannels[outRank - 1]} = rc.${allGlChannels[outRank - 2]}; rc.${allGlChannels[outRank - 2]} = lastDim; `; - const shaderSource = ` + const shaderSource = ` ${getBcastedSamplerForMatmulSnippet} ${getBiasForMatmulSnippet} ${activationFunction} @@ -74,26 +82,32 @@ const createPackedMatmulProgramInfo = ${applyActivation} ${glsl.output} = value; }`; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.packed}, - shaderSource, - hasMain: true - }; - }; - -export const createPackedMatmulProgramInfoLoader = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], - activationAttributes: InternalActivationAttributes): ProgramInfoLoader => { - const metadata = createPackedMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); - return { - ...metadata, - get: () => createPackedMatmulProgramInfo(inferenceHandler, metadata, inputs, activationAttributes) - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed }, + shaderSource, + hasMain: true, + }; +}; + +export const createPackedMatmulProgramInfoLoader = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + activationAttributes: InternalActivationAttributes, +): ProgramInfoLoader => { + const metadata = createPackedMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); + return { + ...metadata, + get: () => createPackedMatmulProgramInfo(inferenceHandler, metadata, inputs, activationAttributes), + }; +}; function getBcastSamplerForMatmul( - coordsDataType: string, allGlChannels: readonly string[], inputs: Tensor[], outShape: readonly number[]): string { + coordsDataType: string, + allGlChannels: 
readonly string[], + inputs: Tensor[], + outShape: readonly number[], +): string { let unpackedACoordsSnippet = []; let unpackedBCoordsSnippet = []; @@ -117,8 +131,8 @@ function getBcastSamplerForMatmul( const broadcastADims = BroadcastUtil.getBroadcastDims(inAShape, outShape); const broadcastBDims = BroadcastUtil.getBroadcastDims(inBShape, outShape); - const coordsASnippet = broadcastADims.map(d => `coords.${allGlChannels[d + rankADiff]} = 0;`).join('\n'); - const coordsBSnippet = broadcastBDims.map(d => `coords.${allGlChannels[d + rankBDiff]} = 0;`).join('\n'); + const coordsASnippet = broadcastADims.map((d) => `coords.${allGlChannels[d + rankADiff]} = 0;`).join('\n'); + const coordsBSnippet = broadcastBDims.map((d) => `coords.${allGlChannels[d + rankBDiff]} = 0;`).join('\n'); const swapDimSnippet = `int lastDim = coords.${allGlChannels[outRank - 1]}; coords.${allGlChannels[outRank - 1]} = coords.${allGlChannels[outRank - 2]}; coords.${allGlChannels[outRank - 2]} = lastDim;`; @@ -148,8 +162,7 @@ function getA(allGlChannels: string[], rank: number): string { for (let i = 0; i < rank - 2; i++) { res += `rc.${allGlChannels[i]}, `; } - res += `rc.${allGlChannels[rank - 2]}, ` + - 'i*2'; + res += `rc.${allGlChannels[rank - 2]}, ` + 'i*2'; return res; } @@ -158,7 +171,6 @@ function getB(allGlChannels: string[], rank: number): string { for (let i = 0; i < rank - 2; i++) { res += `rc.${allGlChannels[i]}, `; } - res += 'i*2, ' + - `rc.${allGlChannels[rank - 1]}`; + res += 'i*2, ' + `rc.${allGlChannels[rank - 1]}`; return res; } diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts index 523165f29f852..ea22d4b81a886 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts @@ -1,56 +1,64 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {BroadcastUtil, ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; -import {getCoordsDataType, getGlChannels} from '../utils'; - -import {getActivationSnippet, InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; -import {createPackedMatmulProgramInfoLoader} from './matmul-pack'; - -export const matMul: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: InternalActivationAttributes): Tensor[] => { - validateInputs(inputs); - - if (inferenceHandler.session.pack) { - return [inferenceHandler.run( - createPackedMatmulProgramInfoLoader(inferenceHandler, inputs, attributes), inputs)]; - } else { - return [inferenceHandler.run(createMatmulProgramInfoLoader(inputs, attributes), inputs)]; - } - }; - -export const parseMatMulAttributes: OperatorInitialization = - (node: Graph.Node): InternalActivationAttributes => parseInternalActivationAttributes(node.attributes); +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { BroadcastUtil, ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; +import { getCoordsDataType, getGlChannels } from '../utils'; + +import { getActivationSnippet, InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils'; +import { createPackedMatmulProgramInfoLoader } from './matmul-pack'; + +export const matMul: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: InternalActivationAttributes, +): Tensor[] => { + validateInputs(inputs); + + if (inferenceHandler.session.pack) { + return [inferenceHandler.run(createPackedMatmulProgramInfoLoader(inferenceHandler, inputs, attributes), inputs)]; + } else { + return [inferenceHandler.run(createMatmulProgramInfoLoader(inputs, attributes), inputs)]; + } +}; + +export const parseMatMulAttributes: OperatorInitialization = ( + node: Graph.Node, +): InternalActivationAttributes => parseInternalActivationAttributes(node.attributes); const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul', inputNames: hasBias ? ['A', 'B', 'Bias'] : ['A', 'B'], - inputTypes: hasBias ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] : - [TextureType.unpacked, TextureType.unpacked], - cacheHint + inputTypes: hasBias + ? 
[TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] + : [TextureType.unpacked, TextureType.unpacked], + cacheHint, }); function createMatmulProgramInfo( - metadata: ProgramMetadata, inputs: Tensor[], activationAttributes: InternalActivationAttributes): ProgramInfo { + metadata: ProgramMetadata, + inputs: Tensor[], + activationAttributes: InternalActivationAttributes, +): ProgramInfo { const aShape = inputs[0].dims; const bShape = inputs[1].dims; const outputShape = BroadcastUtil.calcShape(aShape, bShape, true); if (!outputShape) { - throw new Error('Can\'t use matmul on the given tensors'); + throw new Error("Can't use matmul on the given tensors"); } const coordsDataType = getCoordsDataType(outputShape.length); const allGlChannels = getGlChannels(); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes); + const { activationFunction, applyActivation } = getActivationSnippet(activationAttributes); const hasBias = inputs.length > 2; const processBias = hasBias ? 'value += getBiasForMatmul();' : ''; - const getBiasForMatmulSnippet = - hasBias ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)}` : ''; + const getBiasForMatmulSnippet = hasBias + ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)}` + : ''; const rank = outputShape.length; const arank = aShape.length; @@ -77,15 +85,17 @@ function createMatmulProgramInfo( }`; return { ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, shaderSource, }; } export function createMatmulProgramInfoLoader( - inputs: Tensor[], activationAttributes: InternalActivationAttributes): ProgramInfoLoader { + inputs: Tensor[], + activationAttributes: InternalActivationAttributes, +): ProgramInfoLoader { const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); - return {...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes)}; + return { ...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes) }; } const validateInputs = (inputs: Tensor[]): void => { @@ -97,8 +107,10 @@ const validateInputs = (inputs: Tensor[]): void => { throw new Error('shared dimension does not match.'); } - if ((inputs[0].type !== 'float32' && inputs[0].type !== 'float64') || - (inputs[1].type !== 'float32' && inputs[1].type !== 'float64')) { + if ( + (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') || + (inputs[1].type !== 'float32' && inputs[1].type !== 'float64') + ) { throw new Error('inputs should be float type'); } @@ -108,8 +120,12 @@ const validateInputs = (inputs: Tensor[]): void => { }; export function getBiasForMatmul( - coordsDataType: string, allGlChannels: readonly string[], inShape: readonly number[], outShape: readonly number[], - isPacked: boolean): string { + coordsDataType: string, + allGlChannels: readonly string[], + inShape: readonly number[], + outShape: readonly number[], + isPacked: boolean, +): string { let unpackedCoordsSnippet = ''; const inRank = inShape.length; const outRank = outShape.length; @@ -120,21 +136,22 @@ export function getBiasForMatmul( unpackedCoordsSnippet = inShape.map((_s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); } const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape); - const coordsSnippet = broadcastDims.map(d => 
`coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n'); + const coordsSnippet = broadcastDims.map((d) => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n'); const inSize = ShapeUtil.size(inShape); const isInputScalar = inSize === 1; let output = 'vec4(outputValue.xx, outputValue.yy)'; if (isInputScalar) { output = 'vec4(outputValue.x)'; } - const getBiasForMatmulSource = isPacked ? ` + const getBiasForMatmulSource = isPacked + ? ` vec4 getBiasForMatmul() { ${coordsDataType} coords = getOutputCoords(); ${coordsSnippet} vec4 outputValue = getBias(${unpackedCoordsSnippet}); return ${output}; -}` : - ` +}` + : ` float getBiasForMatmul() { ${coordsDataType} coords = getOutputCoords(); ${coordsSnippet} diff --git a/js/web/lib/onnxjs/backends/webgl/ops/pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/pack.ts index 37ef8c8fe2435..745455089ddc5 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/pack.ts @@ -1,18 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, TextureType} from '../types'; -import {getCoordsDataType} from '../utils'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types'; +import { getCoordsDataType } from '../utils'; -import {getChannels} from './packing-utils'; +import { getChannels } from './packing-utils'; const packProgramMetadata = { name: 'pack', inputNames: ['A'], - inputTypes: [TextureType.unpackedReversed] + inputTypes: [TextureType.unpackedReversed], }; const createPackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => { @@ -54,13 +54,15 @@ const createPackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): P return { ...packProgramMetadata, hasMain: true, - output: {dims: input.dims, type: input.type, textureType: TextureType.packed}, - shaderSource + output: { dims: input.dims, type: input.type, textureType: TextureType.packed }, + shaderSource, }; }; -export const createPackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => - ({...packProgramMetadata, get: () => createPackProgramInfo(handler, input)}); +export const createPackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({ + ...packProgramMetadata, + get: () => createPackProgramInfo(handler, input), +}); /** * check output coordinate location and return false if it is outside input's width/height boundary diff --git a/js/web/lib/onnxjs/backends/webgl/ops/packing-utils.ts b/js/web/lib/onnxjs/backends/webgl/ops/packing-utils.ts index d391b77b7752d..29740b86952e5 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/packing-utils.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/packing-utils.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {getGlChannels} from '../utils'; +import { getGlChannels } from '../utils'; export function getVecChannels(name: string, rank: number): string[] { - return getGlChannels(rank).map(d => `${name}.${d}`); + return getGlChannels(rank).map((d) => `${name}.${d}`); } export function getChannels(name: string, rank: number): string[] { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/pad.ts b/js/web/lib/onnxjs/backends/webgl/ops/pad.ts index f0a0bc21cd77e..5a18ccd15b69c 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/pad.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/pad.ts @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {getGlsl, Glsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { getGlsl, Glsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; export interface PadAttributes extends AttributeWithCacheKey { readonly mode: string; @@ -22,67 +22,82 @@ const padProgramMetadata = { inputTypes: [TextureType.unpacked], }; -export const padV2: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: PadAttributes): Tensor[] => { - validateInputsV2(inputs); - const output = inferenceHandler.run( - { - ...padProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createPadProgramInfo(inferenceHandler, inputs[0], attributes) - }, - inputs); - return [output]; - }; +export const padV2: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: PadAttributes, +): Tensor[] => { + validateInputsV2(inputs); + const output = inferenceHandler.run( + { + ...padProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createPadProgramInfo(inferenceHandler, inputs[0], attributes), + }, + inputs, + ); + return [output]; +}; export const parsePadAttributesV2: OperatorInitialization = (node: Graph.Node): PadAttributes => { const mode = node.attributes.getString('mode', 'constant'); const value = node.attributes.getFloat('value', 0.0); const pads = node.attributes.getInts('pads'); - return createAttributeWithCacheKey({mode, value, pads}); + return createAttributeWithCacheKey({ mode, value, pads }); }; -export const padV11: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], mode: string): Tensor[] => { - validateInputsV11(inputs); - const attrubutes = generatePadAttributesFromInputs(inferenceHandler, inputs, mode); - return padV2(inferenceHandler, [inputs[0]], attrubutes); - }; +export const padV11: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + mode: string, +): Tensor[] => { + validateInputsV11(inputs); + const attrubutes = 
generatePadAttributesFromInputs(inferenceHandler, inputs, mode); + return padV2(inferenceHandler, [inputs[0]], attrubutes); +}; export const parsePadAttributesV11: OperatorInitialization = (node: Graph.Node): string => - node.attributes.getString('mode', 'constant'); - -const generatePadAttributesFromInputs = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], mode: string): PadAttributes => { - if (!inferenceHandler.session.isInitializer(inputs[1].dataId) || - (inputs.length >= 3 && !inferenceHandler.session.isInitializer(inputs[2].dataId))) { - throw new Error('dynamic pad attributes are not allowed'); - } + node.attributes.getString('mode', 'constant'); + +const generatePadAttributesFromInputs = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + mode: string, +): PadAttributes => { + if ( + !inferenceHandler.session.isInitializer(inputs[1].dataId) || + (inputs.length >= 3 && !inferenceHandler.session.isInitializer(inputs[2].dataId)) + ) { + throw new Error('dynamic pad attributes are not allowed'); + } - const pads = Array.from(inputs[1].integerData); - const value = (inputs.length >= 3) ? inputs[2].floatData[0] : 0.0; + const pads = Array.from(inputs[1].integerData); + const value = inputs.length >= 3 ? inputs[2].floatData[0] : 0.0; - return createAttributeWithCacheKey({mode, pads, value}); - }; + return createAttributeWithCacheKey({ mode, pads, value }); +}; -const createPadProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: PadAttributes): ProgramInfo => { - const outputShape = ShapeUtil.padShape(input.dims.slice(), attributes.pads); - const rank = outputShape.length; - const padFunction = getPadFunction(inferenceHandler, input, attributes); - const shaderSource = ` +const createPadProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + input: Tensor, + attributes: PadAttributes, +): ProgramInfo => { + const outputShape = ShapeUtil.padShape(input.dims.slice(), attributes.pads); + const rank = outputShape.length; + const padFunction = getPadFunction(inferenceHandler, input, attributes); + const shaderSource = ` ${padFunction} float process(int[${rank}] indices) { return padA(indices); }`; - return { - name: 'Pad', - inputNames: ['A'], - inputTypes: [TextureType.unpacked], - output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + name: 'Pad', + inputNames: ['A'], + inputTypes: [TextureType.unpacked], + output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const validateInputsV2 = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { @@ -122,20 +137,26 @@ const getPadFunction = (inferenceHandler: WebGLInferenceHandler, input: Tensor, } }; -const getPadConstant = - (glsl: Glsl, shape: readonly number[], strides: readonly number[], width: number, height: number, pads: number[], - value: number): string => { - const rank = shape.length; - let block = ''; - for (let i = rank - 1; i >= 0; --i) { - block += ` +const getPadConstant = ( + glsl: Glsl, + shape: readonly number[], + strides: readonly number[], + width: number, + height: number, + pads: number[], + value: number, +): string => { + const rank = shape.length; + let block = ''; + for (let i = rank - 1; i >= 0; --i) { + block += ` k = m[${i}] - ${pads[i]}; if (k < 0) return constant; if (k >= ${shape[i]}) return constant; offset += k * ${strides[i]}; `; - } - return ` + } + return ` float padA(int m[${rank}]) { const float constant = 
float(${value}); int offset = 0; @@ -146,16 +167,21 @@ const getPadConstant = return value; } `; - }; - -const getPadReflect = - (glsl: Glsl, shape: readonly number[], strides: readonly number[], width: number, height: number, pads: number[]): - string => { - const rank = shape.length; +}; - let block = ''; - for (let i = rank - 1; i >= 0; --i) { - block += ` +const getPadReflect = ( + glsl: Glsl, + shape: readonly number[], + strides: readonly number[], + width: number, + height: number, + pads: number[], +): string => { + const rank = shape.length; + + let block = ''; + for (let i = rank - 1; i >= 0; --i) { + block += ` k = m[${i}] - ${pads[i]}; if (k < 0) { k = -k; } { @@ -165,8 +191,8 @@ const getPadReflect = } offset += k * ${strides[i]}; `; - } - return ` + } + return ` float padA(int m[${rank}]) { int offset = 0; int k = 0; @@ -176,23 +202,28 @@ const getPadReflect = return value; } `; - }; - -const getPadEdge = - (glsl: Glsl, shape: readonly number[], strides: readonly number[], width: number, height: number, pads: number[]): - string => { - const rank = shape.length; +}; - let block = ''; - for (let i = rank - 1; i >= 0; --i) { - block += ` +const getPadEdge = ( + glsl: Glsl, + shape: readonly number[], + strides: readonly number[], + width: number, + height: number, + pads: number[], +): string => { + const rank = shape.length; + + let block = ''; + for (let i = rank - 1; i >= 0; --i) { + block += ` k = m[${i}] - ${pads[i]}; if (k < 0) k = 0; if (k >= ${shape[i]}) k = ${shape[i] - 1}; offset += k * ${strides[i]}; `; - } - return ` + } + return ` float padA(int m[${rank}]) { int offset = 0; int k = 0; @@ -202,4 +233,4 @@ const getPadEdge = return value; } `; - }; +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/pool.ts b/js/web/lib/onnxjs/backends/webgl/ops/pool.ts index d7b07fcc57a3d..c603080fb0de1 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/pool.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/pool.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {PoolConvUtil, ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramMetadata, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { PoolConvUtil, ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramMetadata, TextureType } from '../types'; export interface AveragePoolAttributes extends AttributeWithCacheKey { readonly autoPad: string; @@ -18,157 +18,218 @@ export interface AveragePoolAttributes extends AttributeWithCacheKey { readonly pads: readonly number[]; } -export const averagePool: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: AveragePoolAttributes): Tensor[] => { - validateInputs(inputs); - const metadata = - {name: 'AveragePool', inputNames: ['X'], inputTypes: [TextureType.unpacked], cacheHint: attributes.cacheKey}; - const output = inferenceHandler.run( - {...metadata, get: () => createAveragePoolProgramInfo(inputs, metadata, false, attributes)}, inputs); - return [output]; - }; +export const averagePool: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: AveragePoolAttributes, +): Tensor[] => { + validateInputs(inputs); + const metadata = { + name: 'AveragePool', + inputNames: ['X'], + inputTypes: [TextureType.unpacked], + cacheHint: attributes.cacheKey, + }; + const output = inferenceHandler.run( + { ...metadata, get: () => createAveragePoolProgramInfo(inputs, metadata, false, attributes) }, + inputs, + ); + return [output]; +}; -export const parseAveragePoolAttributes: OperatorInitialization = - (node: Graph.Node): AveragePoolAttributes => { - const autoPad = node.attributes.getString('auto_pad', 'NOTSET'); - const ceilMode = node.attributes.getInt('ceil_mode', 0); - const countIncludePad = (node.attributes.getInt('count_include_pad', 0) === 0 ? false : true); - const kernelShape = node.attributes.getInts('kernel_shape'); - const strides = node.attributes.getInts('strides', []); - const pads = node.attributes.getInts('pads', []); +export const parseAveragePoolAttributes: OperatorInitialization = ( + node: Graph.Node, +): AveragePoolAttributes => { + const autoPad = node.attributes.getString('auto_pad', 'NOTSET'); + const ceilMode = node.attributes.getInt('ceil_mode', 0); + const countIncludePad = node.attributes.getInt('count_include_pad', 0) === 0 ? 
false : true; + const kernelShape = node.attributes.getInts('kernel_shape'); + const strides = node.attributes.getInts('strides', []); + const pads = node.attributes.getInts('pads', []); - // TODO: support attribute 'ceil_mode' - if (ceilMode !== 0) { - throw new Error('using ceil() in shape computation is not yet supported for AveragePool'); - } + // TODO: support attribute 'ceil_mode' + if (ceilMode !== 0) { + throw new Error('using ceil() in shape computation is not yet supported for AveragePool'); + } - return createAttributeWithCacheKey({autoPad, ceilMode, countIncludePad, kernelShape, strides, pads}); - }; + return createAttributeWithCacheKey({ autoPad, ceilMode, countIncludePad, kernelShape, strides, pads }); +}; -const createAveragePoolProgramInfo = - (inputs: Tensor[], metadata: ProgramMetadata, isGlobalOperator: boolean, attributes: AveragePoolAttributes): - ProgramInfo => { - const [adjustedAttributes, outputShape] = - getAdjustedPoolAttributesAndOutputShape(inputs, attributes, isGlobalOperator); - const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); - const op1 = 'value += _X(x);'; - let op2 = ''; - if (adjustedAttributes.countIncludePad) { - op2 += `value /= float(${kernelSize});`; - } else { - op2 += `value /= float(${kernelSize} - pad);`; - } - const poolingCode = generatePoolingCode(inputs[0].dims, adjustedAttributes, op1, op2, '0.0'); - const shaderSource = ` +const createAveragePoolProgramInfo = ( + inputs: Tensor[], + metadata: ProgramMetadata, + isGlobalOperator: boolean, + attributes: AveragePoolAttributes, +): ProgramInfo => { + const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape( + inputs, + attributes, + isGlobalOperator, + ); + const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); + const op1 = 'value += _X(x);'; + let op2 = ''; + if (adjustedAttributes.countIncludePad) { + op2 += `value /= float(${kernelSize});`; + } else { + op2 += `value /= float(${kernelSize} - pad);`; + } + const poolingCode = generatePoolingCode(inputs[0].dims, adjustedAttributes, op1, op2, '0.0'); + const shaderSource = ` ${poolingCode} `; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; -export const globalAveragePool: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: AveragePoolAttributes): Tensor[] => { - validateInputs(inputs); - const metadata = { - name: 'GlobalAveragePool', - inputNames: ['X'], - inputTypes: [TextureType.unpacked], - cacheHint: `${attributes.countIncludePad}` - }; - const output = inferenceHandler.run( - {...metadata, get: () => createAveragePoolProgramInfo(inputs, metadata, true, attributes)}, inputs); - return [output]; - }; +export const globalAveragePool: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: AveragePoolAttributes, +): Tensor[] => { + validateInputs(inputs); + const metadata = { + name: 'GlobalAveragePool', + inputNames: ['X'], + inputTypes: [TextureType.unpacked], + cacheHint: `${attributes.countIncludePad}`, + }; + const output = inferenceHandler.run( + { ...metadata, get: () => createAveragePoolProgramInfo(inputs, metadata, true, attributes) }, + inputs, + ); + return [output]; +}; -export const parseGlobalAveragePoolAttributes: OperatorInitialization = - (node: 
Graph.Node): AveragePoolAttributes => { - const countIncludePad = (node.attributes.getInt('count_include_pad', 0) === 0 ? false : true); - return createAttributeWithCacheKey( - {autoPad: '', ceilMode: 0, countIncludePad, kernelShape: [], strides: [], pads: []}); - }; +export const parseGlobalAveragePoolAttributes: OperatorInitialization = ( + node: Graph.Node, +): AveragePoolAttributes => { + const countIncludePad = node.attributes.getInt('count_include_pad', 0) === 0 ? false : true; + return createAttributeWithCacheKey({ + autoPad: '', + ceilMode: 0, + countIncludePad, + kernelShape: [], + strides: [], + pads: [], + }); +}; export interface MaxPoolAttributes extends AveragePoolAttributes { readonly storageOrder: number; readonly dilations: number[]; } -export const maxPool: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: MaxPoolAttributes): Tensor[] => { - validateInputs(inputs); - const metadata = - {name: 'MaxPool', inputNames: ['X'], inputTypes: [TextureType.unpacked], cacheHint: attributes.cacheKey}; - const output = inferenceHandler.run( - {...metadata, get: () => createMaxPoolProgramInfo(inputs, metadata, false, attributes)}, inputs); - return [output]; - }; +export const maxPool: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: MaxPoolAttributes, +): Tensor[] => { + validateInputs(inputs); + const metadata = { + name: 'MaxPool', + inputNames: ['X'], + inputTypes: [TextureType.unpacked], + cacheHint: attributes.cacheKey, + }; + const output = inferenceHandler.run( + { ...metadata, get: () => createMaxPoolProgramInfo(inputs, metadata, false, attributes) }, + inputs, + ); + return [output]; +}; -export const parseMaxPoolAttributes: OperatorInitialization = - (node: Graph.Node): MaxPoolAttributes => { - const autoPad = node.attributes.getString('auto_pad', 'NOTSET'); - const ceilMode = node.attributes.getInt('ceil_mode', 0); - const kernelShape = node.attributes.getInts('kernel_shape'); - const strides = node.attributes.getInts('strides', []); - const pads = node.attributes.getInts('pads', []); - const storageOrder = node.attributes.getInt('storage_order', 0); - const dilations = node.attributes.getInts('dilations', []); +export const parseMaxPoolAttributes: OperatorInitialization = ( + node: Graph.Node, +): MaxPoolAttributes => { + const autoPad = node.attributes.getString('auto_pad', 'NOTSET'); + const ceilMode = node.attributes.getInt('ceil_mode', 0); + const kernelShape = node.attributes.getInts('kernel_shape'); + const strides = node.attributes.getInts('strides', []); + const pads = node.attributes.getInts('pads', []); + const storageOrder = node.attributes.getInt('storage_order', 0); + const dilations = node.attributes.getInts('dilations', []); - // TODO: support attribute 'ceil_mode' and 'storage_order' - if (storageOrder !== 0) { - throw new Error('column major storage order is not yet supported for MaxPool'); - } - if (ceilMode !== 0) { - throw new Error('using ceil() in shape computation is not yet supported for MaxPool'); - } + // TODO: support attribute 'ceil_mode' and 'storage_order' + if (storageOrder !== 0) { + throw new Error('column major storage order is not yet supported for MaxPool'); + } + if (ceilMode !== 0) { + throw new Error('using ceil() in shape computation is not yet supported for MaxPool'); + } - return createAttributeWithCacheKey( - {autoPad, ceilMode, countIncludePad: false, kernelShape, strides, pads, storageOrder, dilations}); - }; + return 
createAttributeWithCacheKey({ + autoPad, + ceilMode, + countIncludePad: false, + kernelShape, + strides, + pads, + storageOrder, + dilations, + }); +}; -const createMaxPoolProgramInfo = - (inputs: Tensor[], metadata: ProgramMetadata, isGlobalOperator: boolean, attributes: MaxPoolAttributes): - ProgramInfo => { - const [adjustedAttributes, outputShape] = - getAdjustedPoolAttributesAndOutputShape(inputs, attributes, isGlobalOperator); - const op1 = ` +const createMaxPoolProgramInfo = ( + inputs: Tensor[], + metadata: ProgramMetadata, + isGlobalOperator: boolean, + attributes: MaxPoolAttributes, +): ProgramInfo => { + const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape( + inputs, + attributes, + isGlobalOperator, + ); + const op1 = ` value = max(_X(x), value); `; - const op2 = ''; - const poolingCode = generatePoolingCode(inputs[0].dims, adjustedAttributes, op1, op2, '-1e5'); - const shaderSource = ` + const op2 = ''; + const poolingCode = generatePoolingCode(inputs[0].dims, adjustedAttributes, op1, op2, '-1e5'); + const shaderSource = ` ${poolingCode} `; - return { - ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...metadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; -const getAdjustedPoolAttributesAndOutputShape = - (inputs: Tensor[], attributes: AveragePoolAttributes|MaxPoolAttributes, isGlobalOperator: boolean): - [AveragePoolAttributes|MaxPoolAttributes, number[]] => { - const inputShape = inputs[0].dims.slice(); - const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); - const kernelShape = attributes.kernelShape.slice(); - const strides = attributes.strides.slice(); - const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : []; - const pads = attributes.pads.slice(); - PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShape, kernelShape, strides, dilations, pads); +const getAdjustedPoolAttributesAndOutputShape = ( + inputs: Tensor[], + attributes: AveragePoolAttributes | MaxPoolAttributes, + isGlobalOperator: boolean, +): [AveragePoolAttributes | MaxPoolAttributes, number[]] => { + const inputShape = inputs[0].dims.slice(); + const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); + const kernelShape = attributes.kernelShape.slice(); + const strides = attributes.strides.slice(); + const dilations: number[] = hasDilations ? 
(attributes as MaxPoolAttributes).dilations.slice() : []; + const pads = attributes.pads.slice(); + PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShape, kernelShape, strides, dilations, pads); - const outputShape = PoolConvUtil.computePoolOutputShape( - isGlobalOperator, inputShape, strides, dilations, kernelShape, pads, attributes.autoPad); + const outputShape = PoolConvUtil.computePoolOutputShape( + isGlobalOperator, + inputShape, + strides, + dilations, + kernelShape, + pads, + attributes.autoPad, + ); - const newAttributes = Object.assign({}, attributes); - if (hasDilations) { - Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey}); - } else { - Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey}); - } - return [newAttributes, outputShape]; - }; + const newAttributes = Object.assign({}, attributes); + if (hasDilations) { + Object.assign(newAttributes, { kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey }); + } else { + Object.assign(newAttributes, { kernelShape, strides, pads, cacheKey: attributes.cacheKey }); + } + return [newAttributes, outputShape]; +}; const globalMaxPoolAttributes = { autoPad: '', @@ -179,23 +240,24 @@ const globalMaxPoolAttributes = { pads: [], storageOrder: 0, dilations: [], - cacheKey: '' + cacheKey: '', }; const globalMaxPoolMetadata = { name: 'GlobalMaxPool', inputNames: ['X'], - inputTypes: [TextureType.unpacked] + inputTypes: [TextureType.unpacked], }; export const globalMaxPool = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputs(inputs); const output = inferenceHandler.run( - { - ...globalMaxPoolMetadata, - get: () => createMaxPoolProgramInfo(inputs, globalMaxPoolMetadata, true, globalMaxPoolAttributes) - }, - inputs); + { + ...globalMaxPoolMetadata, + get: () => createMaxPoolProgramInfo(inputs, globalMaxPoolMetadata, true, globalMaxPoolAttributes), + }, + inputs, + ); return [output]; }; @@ -208,21 +270,25 @@ const validateInputs = (inputs: Tensor[]): void => { } }; -const generatePoolingCode = - (inputDims: readonly number[], attributes: AveragePoolAttributes, op1: string, op2: string, start: string): - string => { - const rank = inputDims.length; - if (attributes.kernelShape.length <= 2) { - const kw = attributes.kernelShape[attributes.kernelShape.length - 1]; - const sw = attributes.strides[attributes.strides.length - 1]; - const pwStart = attributes.pads[attributes.pads.length / 2 - 1]; - const pwEnd = attributes.pads[attributes.pads.length - 1]; - const dimW = inputDims[rank - 1]; - let codeW = ''; - let codeH = ''; - let codeHEnd = ''; - if (pwStart + pwEnd !== 0) { - codeW = ` +const generatePoolingCode = ( + inputDims: readonly number[], + attributes: AveragePoolAttributes, + op1: string, + op2: string, + start: string, +): string => { + const rank = inputDims.length; + if (attributes.kernelShape.length <= 2) { + const kw = attributes.kernelShape[attributes.kernelShape.length - 1]; + const sw = attributes.strides[attributes.strides.length - 1]; + const pwStart = attributes.pads[attributes.pads.length / 2 - 1]; + const pwEnd = attributes.pads[attributes.pads.length - 1]; + const dimW = inputDims[rank - 1]; + let codeW = ''; + let codeH = ''; + let codeHEnd = ''; + if (pwStart + pwEnd !== 0) { + codeW = ` for (int i = 0; i < ${kw}; i++) { x[${rank} - 1] = indices[${rank} - 1] * ${sw} - ${pwStart} + i; if (x[${rank} - 1] < 0 || x[${rank} - 1] >= ${dimW}) { @@ -231,22 +297,22 @@ const 
generatePoolingCode = } ${op1} }`; - } else { - codeW = ` + } else { + codeW = ` for (int i = 0; i < ${kw}; i++) { x[${rank} - 1] = indices[${rank} - 1] * ${sw} - ${pwStart} + i; ${op1} }`; - } + } - if (attributes.kernelShape.length === 2) { - const kh = attributes.kernelShape[attributes.kernelShape.length - 2]; - const sh = attributes.strides[attributes.strides.length - 2]; - const phStart = attributes.pads[attributes.pads.length / 2 - 2]; - const phEnd = attributes.pads[attributes.pads.length - 2]; - const dimH = inputDims[rank - 2]; - if (phStart + phEnd !== 0) { - codeH = ` + if (attributes.kernelShape.length === 2) { + const kh = attributes.kernelShape[attributes.kernelShape.length - 2]; + const sh = attributes.strides[attributes.strides.length - 2]; + const phStart = attributes.pads[attributes.pads.length / 2 - 2]; + const phEnd = attributes.pads[attributes.pads.length - 2]; + const dimH = inputDims[rank - 2]; + if (phStart + phEnd !== 0) { + codeH = ` for (int j = 0; j < ${kh}; j++) { x[${rank} - 2] = indices[${rank} - 2] * ${sh} - ${phStart} + j; if (x[${rank} - 2] < 0 || x[${rank} - 2] >= ${dimH}) { @@ -254,18 +320,18 @@ const generatePoolingCode = continue; } `; - } else { - codeH = ` + } else { + codeH = ` for (int j = 0; j < ${kh}; j++) { x[${rank} - 2] = indices[${rank} - 2] * ${sh} - ${phStart} + j; `; - } - codeHEnd = ` + } + codeHEnd = ` } `; - } + } - const poolingCode = ` + const poolingCode = ` float process(int indices[${rank}]) { int x[${rank}]; copyVec(indices, x); @@ -279,21 +345,21 @@ const generatePoolingCode = return value; } `; - return poolingCode; - } else { - const kernelSize = ShapeUtil.size(attributes.kernelShape); - const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape); - const stridesRank = kernelStrides.length; - const padsRank = attributes.pads.length; - const offsetToIndicesFunction = offsetToIndices(stridesRank); - const copyInputDims = copyArray(inputDims, 'inputDims'); - const copyPads = copyArray(attributes.pads, 'pads'); - const copyKernelStrides = copyArray(kernelStrides, 'kernelStrides'); - const copyStrides = copyArray(attributes.strides, 'strides'); - const hasPads = attributes.pads.reduce((sum, cur) => sum + cur); - let padCode = ''; - if (hasPads) { - padCode = ` + return poolingCode; + } else { + const kernelSize = ShapeUtil.size(attributes.kernelShape); + const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape); + const stridesRank = kernelStrides.length; + const padsRank = attributes.pads.length; + const offsetToIndicesFunction = offsetToIndices(stridesRank); + const copyInputDims = copyArray(inputDims, 'inputDims'); + const copyPads = copyArray(attributes.pads, 'pads'); + const copyKernelStrides = copyArray(kernelStrides, 'kernelStrides'); + const copyStrides = copyArray(attributes.strides, 'strides'); + const hasPads = attributes.pads.reduce((sum, cur) => sum + cur); + let padCode = ''; + if (hasPads) { + padCode = ` if (x[j] >= inputDims[j] || x[j] < 0) { pad++; isPad = true; @@ -303,13 +369,13 @@ const generatePoolingCode = if (!isPad) { ${op1} }`; - } else { - padCode = ` + } else { + padCode = ` } ${op1} `; - } - const poolingCode = ` + } + const poolingCode = ` ${offsetToIndicesFunction} float process(int indices[${rank}]) { int x[${rank}]; @@ -340,9 +406,9 @@ const generatePoolingCode = return value; } `; - return poolingCode; - } - }; + return poolingCode; + } +}; const copyArray = (array: readonly number[], arrayName: string): string => { let block = ''; diff --git 
a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts index c9ea460a6f1fc..b0ddfb4b44b96 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {NUMBER_TYPES, OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramMetadata, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramMetadata, TextureType } from '../types'; export interface ReduceAttributes extends AttributeWithCacheKey { readonly axes: number[]; @@ -17,69 +17,78 @@ export interface ReduceAttributes extends AttributeWithCacheKey { // return [init ops, reduce ops, final ops] type ReduceOp = (inputs: Tensor[], axes: number[]) => string[]; -const reduce = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, name: string, - reduceOp: ReduceOp): Tensor[] => { - validateInputs(inputs); - - const reduceProgramMetadata = { - name, - inputNames: ['A'], - inputTypes: [TextureType.unpacked], - }; - - const output = inferenceHandler.run( - { - ...reduceProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => - createReduceProgramInfo(inferenceHandler, inputs, attributes, name, reduceOp, reduceProgramMetadata) - }, - inputs); - return [output]; - }; +const reduce = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, + name: string, + reduceOp: ReduceOp, +): Tensor[] => { + validateInputs(inputs); + + const reduceProgramMetadata = { + name, + inputNames: ['A'], + inputTypes: [TextureType.unpacked], + }; + + const output = inferenceHandler.run( + { + ...reduceProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createReduceProgramInfo(inferenceHandler, inputs, attributes, name, reduceOp, reduceProgramMetadata), + }, + inputs, + ); + return [output]; +}; export const parseReduceAttributes: OperatorInitialization = (node: Graph.Node): ReduceAttributes => { const axes = node.attributes.getInts('axes', []); const keepDims = node.attributes.getInt('keepdims', 1) === 1; - return createAttributeWithCacheKey({axes, keepDims}); + return createAttributeWithCacheKey({ axes, keepDims }); }; -const createReduceProgramInfo = - (_handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, _name: string, reduceOp: ReduceOp, - reduceProgramMetadata: ProgramMetadata): ProgramInfo => { - const outputShape: number[] = []; - const iRank = inputs[0].dims.length || 1; - - const idxCopy = []; // copy output indexes to input indexes - - const axes = ShapeUtil.normalizeAxes(attributes.axes, inputs[0].dims.length); - const ops = reduceOp(inputs, axes); - let reduceOps = ops[1]; - - for (let k = 0; k < 
inputs[0].dims.length; k++) { - // if this axis is reduced - if (axes.indexOf(k) >= 0 || axes.length === 0) { - if (attributes.keepDims) { - outputShape.push(1); - } // else { remove the axis from outputShape; } - - // loop over the d-th axis - reduceOps = ` +const createReduceProgramInfo = ( + _handler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, + _name: string, + reduceOp: ReduceOp, + reduceProgramMetadata: ProgramMetadata, +): ProgramInfo => { + const outputShape: number[] = []; + const iRank = inputs[0].dims.length || 1; + + const idxCopy = []; // copy output indexes to input indexes + + const axes = ShapeUtil.normalizeAxes(attributes.axes, inputs[0].dims.length); + const ops = reduceOp(inputs, axes); + let reduceOps = ops[1]; + + for (let k = 0; k < inputs[0].dims.length; k++) { + // if this axis is reduced + if (axes.indexOf(k) >= 0 || axes.length === 0) { + if (attributes.keepDims) { + outputShape.push(1); + } // else { remove the axis from outputShape; } + + // loop over the d-th axis + reduceOps = ` for(int j${k} = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) { inputIdx[${k}] = j${k}; ${reduceOps} }`; - } else { - idxCopy.push(`inputIdx[${k}] = outputIdx[${outputShape.length}];`); + } else { + idxCopy.push(`inputIdx[${k}] = outputIdx[${outputShape.length}];`); - outputShape.push(inputs[0].dims[k]); - } - } + outputShape.push(inputs[0].dims[k]); + } + } - const oRank = outputShape.length || 1; + const oRank = outputShape.length || 1; - const shaderSource = ` + const shaderSource = ` float process(int outputIdx[${oRank}]) { float value; // final result int inputIdx[${iRank}]; // addressing input data @@ -90,12 +99,12 @@ const createReduceProgramInfo = return value; }`; - return { - ...reduceProgramMetadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...reduceProgramMetadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[]): void => { // TODO: support Reduce* operators with 2 inputs. 
@@ -108,71 +117,92 @@ const validateInputs = (inputs: Tensor[]): void => { } }; -export const reduceSum: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { - const reduceOp: ReduceOp = (): string[] => ['value = 0.0;', 'value += _A(inputIdx);', '']; - return reduce(inferenceHandler, inputs, attributes, 'ReduceSum', reduceOp); - }; - -export const reduceMean: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { - const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => { - let size = 1.0; - for (let k = 0; k < inputs[0].dims.length; k++) { - if (axes.indexOf(k) >= 0 || axes.length === 0) { - size *= inputs[0].dims[k]; - } - } - - return ['value = 0.0;', 'value += _A(inputIdx);', `value /= ${size}.;`]; // ensure real number with `.` - }; - return reduce(inferenceHandler, inputs, attributes, 'ReduceMean', reduceOp); - }; - -export const reduceMax: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { - const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => { - const idxZero = []; - for (let k = 0; k < inputs[0].dims.length; k++) { - if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIdx[${k}] = 0;`); // first element - } - } - - return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = max(value, _A(inputIdx));', '']; - }; - return reduce(inferenceHandler, inputs, attributes, 'ReduceMax', reduceOp); - }; - -export const reduceMin: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { - const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => { - const idxZero = []; - for (let k = 0; k < inputs[0].dims.length; k++) { - if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`inputIdx[${k}] = 0;`); // first element - } - } - - return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = min(value, _A(inputIdx));', '']; - }; - return reduce(inferenceHandler, inputs, attributes, 'ReduceMin', reduceOp); - }; - -export const reduceProd: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { - const reduceOp: ReduceOp = (): string[] => ['value = 1.0;', 'value *= _A(inputIdx);', '']; - return reduce(inferenceHandler, inputs, attributes, 'ReduceProd', reduceOp); - }; - -export const reduceLogSum: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { - const reduceOp: ReduceOp = (): string[] => ['value = 0.0;', 'value += _A(inputIdx);', 'value = log(value);']; - return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSum', reduceOp); - }; - -export const reduceLogSumSquare: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { - const reduceOp: ReduceOp = (): string[] => ['float t; value = 0.0;', 't = _A(inputIdx); value += t * t;', '']; - return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', reduceOp); - }; +export const reduceSum: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, +): Tensor[] => { + const reduceOp: ReduceOp = (): string[] => ['value = 0.0;', 'value += 
_A(inputIdx);', '']; + return reduce(inferenceHandler, inputs, attributes, 'ReduceSum', reduceOp); +}; + +export const reduceMean: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, +): Tensor[] => { + const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => { + let size = 1.0; + for (let k = 0; k < inputs[0].dims.length; k++) { + if (axes.indexOf(k) >= 0 || axes.length === 0) { + size *= inputs[0].dims[k]; + } + } + + return ['value = 0.0;', 'value += _A(inputIdx);', `value /= ${size}.;`]; // ensure real number with `.` + }; + return reduce(inferenceHandler, inputs, attributes, 'ReduceMean', reduceOp); +}; + +export const reduceMax: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, +): Tensor[] => { + const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => { + const idxZero = []; + for (let k = 0; k < inputs[0].dims.length; k++) { + if (axes.indexOf(k) >= 0 || axes.length === 0) { + idxZero.push(`inputIdx[${k}] = 0;`); // first element + } + } + + return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = max(value, _A(inputIdx));', '']; + }; + return reduce(inferenceHandler, inputs, attributes, 'ReduceMax', reduceOp); +}; + +export const reduceMin: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, +): Tensor[] => { + const reduceOp: ReduceOp = (inputs: Tensor[], axes: number[]): string[] => { + const idxZero = []; + for (let k = 0; k < inputs[0].dims.length; k++) { + if (axes.indexOf(k) >= 0 || axes.length === 0) { + idxZero.push(`inputIdx[${k}] = 0;`); // first element + } + } + + return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = min(value, _A(inputIdx));', '']; + }; + return reduce(inferenceHandler, inputs, attributes, 'ReduceMin', reduceOp); +}; + +export const reduceProd: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, +): Tensor[] => { + const reduceOp: ReduceOp = (): string[] => ['value = 1.0;', 'value *= _A(inputIdx);', '']; + return reduce(inferenceHandler, inputs, attributes, 'ReduceProd', reduceOp); +}; + +export const reduceLogSum: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, +): Tensor[] => { + const reduceOp: ReduceOp = (): string[] => ['value = 0.0;', 'value += _A(inputIdx);', 'value = log(value);']; + return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSum', reduceOp); +}; + +export const reduceLogSumSquare: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: ReduceAttributes, +): Tensor[] => { + const reduceOp: ReduceOp = (): string[] => ['float t; value = 0.0;', 't = _A(inputIdx); value += t * t;', '']; + return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', reduceOp); +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts b/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts index bc7e823610d84..5de23c7f6799c 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts @@ -1,44 +1,51 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; - -import {unpackFromChannel} from './packing-utils'; - -const createPackedReshape3DProgramMetadata = (outputShape3D: readonly number[]) => - ({name: 'Reshape (packed)', inputTypes: [TextureType.packed], inputNames: ['A'], cacheHint: `${outputShape3D}`}); - -const createPackedReshape3DProgramInfo = - (handler: WebGLInferenceHandler, input3D: Tensor, metadata: ProgramMetadata, outputShape3D: readonly number[]): - ProgramInfo => { - const inputShape3D = input3D.dims as [number, number, number]; - const squeezedOutputShape = outputShape3D as [number, number, number]; - - let mainLoop = ''; - for (let i = 0; i < 4; i++) { - let outputCoords = ''; - switch (i) { - case 0: - outputCoords = 'outputCoords = rc;'; - break; - case 1: - outputCoords = 'outputCoords = ivec3(rc.x, rc.y+1, rc.z);'; - break; - case 2: - outputCoords = 'outputCoords = ivec3(rc.x, rc.y, rc.z+1);'; - break; - case 3: - outputCoords = 'outputCoords = ivec3(rc.x, rc.y+1, rc.z+1);'; - break; - default: - throw new Error(); - } - - mainLoop += ` +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; + +import { unpackFromChannel } from './packing-utils'; + +const createPackedReshape3DProgramMetadata = (outputShape3D: readonly number[]) => ({ + name: 'Reshape (packed)', + inputTypes: [TextureType.packed], + inputNames: ['A'], + cacheHint: `${outputShape3D}`, +}); + +const createPackedReshape3DProgramInfo = ( + handler: WebGLInferenceHandler, + input3D: Tensor, + metadata: ProgramMetadata, + outputShape3D: readonly number[], +): ProgramInfo => { + const inputShape3D = input3D.dims as [number, number, number]; + const squeezedOutputShape = outputShape3D as [number, number, number]; + + let mainLoop = ''; + for (let i = 0; i < 4; i++) { + let outputCoords = ''; + switch (i) { + case 0: + outputCoords = 'outputCoords = rc;'; + break; + case 1: + outputCoords = 'outputCoords = ivec3(rc.x, rc.y+1, rc.z);'; + break; + case 2: + outputCoords = 'outputCoords = ivec3(rc.x, rc.y, rc.z+1);'; + break; + case 3: + outputCoords = 'outputCoords = ivec3(rc.x, rc.y+1, rc.z+1);'; + break; + default: + throw new Error(); + } + + mainLoop += ` ${outputCoords} ${i > 0 ? 'if(outputCoords.y < rows && outputCoords.z < cols){' : ''} int flattenedIndex = getFlattenedIndex(outputCoords); @@ -50,10 +57,10 @@ const createPackedReshape3DProgramInfo = ${i > 0 ? 
'}' : ''} `; - } - const glsl = getGlsl(handler.session.backend.glContext.version); + } + const glsl = getGlsl(handler.session.backend.glContext.version); - const shaderSource = ` + const shaderSource = ` ${getReshapedInputCoords(inputShape3D)} ${getFlattenedIndexFrom3D(squeezedOutputShape)} ${unpackFromChannel()} @@ -72,19 +79,22 @@ const createPackedReshape3DProgramInfo = } `; - return { - ...metadata, - output: {dims: squeezedOutputShape, type: input3D.type, textureType: TextureType.packed}, - shaderSource, - hasMain: true - }; - }; - -export const createPackedReshape3DProgramInfoLoader = - (handler: WebGLInferenceHandler, input3D: Tensor, outputShape3D: readonly number[]): ProgramInfoLoader => { - const metadata = createPackedReshape3DProgramMetadata(outputShape3D); - return {...metadata, get: () => createPackedReshape3DProgramInfo(handler, input3D, metadata, outputShape3D)}; - }; + return { + ...metadata, + output: { dims: squeezedOutputShape, type: input3D.type, textureType: TextureType.packed }, + shaderSource, + hasMain: true, + }; +}; + +export const createPackedReshape3DProgramInfoLoader = ( + handler: WebGLInferenceHandler, + input3D: Tensor, + outputShape3D: readonly number[], +): ProgramInfoLoader => { + const metadata = createPackedReshape3DProgramMetadata(outputShape3D); + return { ...metadata, get: () => createPackedReshape3DProgramInfo(handler, input3D, metadata, outputShape3D) }; +}; export function processDims3D(shape: ArrayLike): [number, number, number] { if (shape.length === 0) { @@ -111,13 +121,17 @@ export function processDims3D(shape: ArrayLike): [number, number, number // treated as no-op. export function isReshapeCheap(dims: readonly number[], reshapedDims: readonly number[]) { let isCheapReshape = false; - if (dims.length === 0 || reshapedDims.length === 0) { // scalar + if (dims.length === 0 || reshapedDims.length === 0) { + // scalar isCheapReshape = true; - } else if (dims.length < 2 || reshapedDims.length < 2) { // 1D + } else if (dims.length < 2 || reshapedDims.length < 2) { + // 1D isCheapReshape = dims[dims.length - 1] === reshapedDims[reshapedDims.length - 1]; - } else { // 2D + - isCheapReshape = dims[dims.length - 1] === reshapedDims[reshapedDims.length - 1] && - dims[dims.length - 2] === reshapedDims[reshapedDims.length - 2]; + } else { + // 2D + + isCheapReshape = + dims[dims.length - 1] === reshapedDims[reshapedDims.length - 1] && + dims[dims.length - 2] === reshapedDims[reshapedDims.length - 2]; } return isCheapReshape; @@ -128,14 +142,15 @@ function getReshapedInputCoords(shape: [number, number, number]): string { const coords = ['b', 'r', 'c']; const index = 'index'; const coordsFromIndexSnippet = strides - .map((stride, i) => { - const line1 = `int ${coords[i]} = ${index} / ${stride}`; - const line2 = i === strides.length - 1 ? - `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` : - `index -= ${coords[i]} * ${stride}`; - return `${line1}; ${line2};`; - }) - .join(''); + .map((stride, i) => { + const line1 = `int ${coords[i]} = ${index} / ${stride}`; + const line2 = + i === strides.length - 1 + ? 
`int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` + : `index -= ${coords[i]} * ${stride}`; + return `${line1}; ${line2};`; + }) + .join(''); return ` ivec3 inputCoordsFromReshapedOutCoords(int index) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/reshape.ts b/js/web/lib/onnxjs/backends/webgl/ops/reshape.ts index 792fccc9d6d41..2fd66472d9d16 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reshape.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reshape.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; export const reshape = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, inputs[1].integerData); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/resize-packed.ts b/js/web/lib/onnxjs/backends/webgl/ops/resize-packed.ts index c0d485d95f036..03f36f7ac6ca4 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/resize-packed.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/resize-packed.ts @@ -1,102 +1,110 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; -import {getCoordsDataType} from '../utils'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; +import { getCoordsDataType } from '../utils'; -import {unpackFromChannel} from './packing-utils'; -import {parseUpsampleAttributes, scalesValidation, UpsampleAttributes, validateInputs} from './upsample'; +import { unpackFromChannel } from './packing-utils'; +import { parseUpsampleAttributes, scalesValidation, UpsampleAttributes, validateInputs } from './upsample'; const resizeProgramMetadata = { name: 'Resize', inputNames: ['A'], - inputTypes: [TextureType.packed] + inputTypes: [TextureType.packed], }; -export const resize: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: UpsampleAttributes): Tensor[] => { - validateInputs(inputs, attributes); - const output = inferenceHandler.run( - { - ...resizeProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createPackedResizeProgramInfo(inferenceHandler, inputs, attributes) - }, - inputs); - return [output]; - }; +export const resize: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: UpsampleAttributes, +): Tensor[] => { + validateInputs(inputs, attributes); + const output = inferenceHandler.run( + { + ...resizeProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createPackedResizeProgramInfo(inferenceHandler, inputs, attributes), + }, + inputs, + ); + return [output]; +}; -export const 
parseResizeAttributesV10: OperatorInitialization = - (node: Graph.Node): UpsampleAttributes => parseUpsampleAttributes(node, 10); - -export const parseResizeAttributesV11: OperatorInitialization = - (node: Graph.Node): UpsampleAttributes => parseUpsampleAttributes(node, 11); - -const createPackedResizeProgramInfo = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: UpsampleAttributes): ProgramInfo => { - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const [scales, outputShape] = prepareInputs(inputs, attributes); - - const isSame = - scales.every((s: number) => s === 1) && attributes.coordinateTransformMode !== 'tf_crop_and_resize'; - if (isSame) { - return { - ...resizeProgramMetadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.packed}, - hasMain: true, - shaderSource: `void main() { +export const parseResizeAttributesV10: OperatorInitialization = ( + node: Graph.Node, +): UpsampleAttributes => parseUpsampleAttributes(node, 10); + +export const parseResizeAttributesV11: OperatorInitialization = ( + node: Graph.Node, +): UpsampleAttributes => parseUpsampleAttributes(node, 11); + +const createPackedResizeProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: UpsampleAttributes, +): ProgramInfo => { + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const [scales, outputShape] = prepareInputs(inputs, attributes); + + const isSame = scales.every((s: number) => s === 1) && attributes.coordinateTransformMode !== 'tf_crop_and_resize'; + if (isSame) { + return { + ...resizeProgramMetadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed }, + hasMain: true, + shaderSource: `void main() { vec4 v = ${glsl.texture2D}(X, TexCoords); ${glsl.output} = v; - }` - }; - } + }`, + }; + } - const dim = outputShape.length; - if (dim < 2) { - throw new Error(`output dimension should be at least 2, but got ${dim}`); - } + const dim = outputShape.length; + if (dim < 2) { + throw new Error(`output dimension should be at least 2, but got ${dim}`); + } - const outputHeight = outputShape[dim - 2]; - const outputWidth = outputShape[dim - 1]; + const outputHeight = outputShape[dim - 2]; + const outputWidth = outputShape[dim - 1]; - const inputShape = inputs[0].dims; - if (dim !== inputShape.length) { - throw new Error(`output dimension should match input ${inputShape.length}, but got ${dim}`); - } - const inputHeight = inputShape[dim - 2]; - const inputWidth = inputShape[dim - 1]; + const inputShape = inputs[0].dims; + if (dim !== inputShape.length) { + throw new Error(`output dimension should match input ${inputShape.length}, but got ${dim}`); + } + const inputHeight = inputShape[dim - 2]; + const inputWidth = inputShape[dim - 1]; - const scalesHeight = scales[dim - 2]; - const scalesWidth = scales[dim - 1]; + const scalesHeight = scales[dim - 2]; + const scalesWidth = scales[dim - 1]; - let getSourceFracIndex = ''; + let getSourceFracIndex = ''; - if (attributes.mode !== 'linear') { - // TODO: support other modes - throw new Error(`resize (packed) does not support mode: '${attributes.mode}'`); - } - switch (attributes.coordinateTransformMode) { - case 'asymmetric': - getSourceFracIndex = ` + if (attributes.mode !== 'linear') { + // TODO: support other modes + throw new Error(`resize (packed) does not support mode: '${attributes.mode}'`); + } + switch (attributes.coordinateTransformMode) { + case 'asymmetric': + getSourceFracIndex = 
` vec4 getSourceFracIndex(ivec4 coords) { return vec4(coords) / scaleWHWH; } `; - break; - case 'half_pixel': - getSourceFracIndex = ` + break; + case 'half_pixel': + getSourceFracIndex = ` vec4 getSourceFracIndex(ivec4 coords) { return (vec4(coords) + 0.5) / scaleWHWH - 0.5; } `; - break; - case 'pytorch_half_pixel': - getSourceFracIndex = ` + break; + case 'pytorch_half_pixel': + getSourceFracIndex = ` vec4 getSourceFracIndex(ivec4 coords) { vec4 fcoords = vec4(coords); return vec4( @@ -107,9 +115,9 @@ const createPackedResizeProgramInfo = ); } `; - break; - case 'align_corners': - getSourceFracIndex = ` + break; + case 'align_corners': + getSourceFracIndex = ` vec4 getSourceFracIndex(ivec4 coords) { vec4 resized = vec4(${outputWidth}.0 - 1.0, ${outputHeight}.0 - 1.0, ${outputWidth}.0 - 1.0, ${outputHeight}.0 - 1.0); @@ -119,19 +127,20 @@ const createPackedResizeProgramInfo = return vec4(coords) * new_scale; } `; - break; - default: - // TODO:supporting other coordinateTransformModes - throw new Error(`resize (packed) does not support coordinateTransformMode: \ + break; + default: + // TODO:supporting other coordinateTransformModes + throw new Error(`resize (packed) does not support coordinateTransformMode: \ '${attributes.coordinateTransformMode}'`); - } + } - const coordsDataType = getCoordsDataType(dim); - const unpackChannel = unpackFromChannel(); - const shaderSource = ` + const coordsDataType = getCoordsDataType(dim); + const unpackChannel = unpackFromChannel(); + const shaderSource = ` const vec2 inputWH = vec2(${inputHeight}.0, ${inputWidth}.0); const vec4 scaleWHWH = vec4(float(${scalesHeight}), float(${scalesWidth}), float(${scalesHeight}), float(${ - scalesWidth})); + scalesWidth + })); ${unpackChannel} ${getSourceFracIndex} float getAValue(int x10, int r, int c, int d) { @@ -197,21 +206,20 @@ const createPackedResizeProgramInfo = ${glsl.output} = vec4(newValue); } `; - return { - ...resizeProgramMetadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.packed}, - hasMain: true, - shaderSource - }; - }; - + return { + ...resizeProgramMetadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed }, + hasMain: true, + shaderSource, + }; +}; const prepareInputs = (inputs: Tensor[], attributes: UpsampleAttributes): [readonly number[], readonly number[]] => { const x = inputs[0]; const xDims = x.dims; let scales = attributes.scales; - let outputSizes: number[]|undefined; + let outputSizes: number[] | undefined; if (scales.length === 0) { const scalesTensor = inputs[attributes.scalesInputIdx]; if (scalesTensor && scalesTensor.size !== 0) { @@ -234,7 +242,7 @@ const prepareInputs = (inputs: Tensor[], attributes: UpsampleAttributes): [reado } } - const yDims = outputSizes || (xDims.map((dim, i) => Math.floor(dim * scales[i]))); + const yDims = outputSizes || xDims.map((dim, i) => Math.floor(dim * scales[i])); return [scales, yDims]; }; @@ -245,24 +253,28 @@ const parseScalesData = (scale: Tensor, mode: string, isResize: boolean): number return scales; }; -const parseScalesDataFromOutputSize = - (yDims: readonly number[], xDims: readonly number[], mode: string, isResize: boolean): number[] => { - const length = xDims.length; - const scales = new Array(length); - - for (let i = 0, end = length; i < end; i++) { - if (xDims[i] === 0) { - if (yDims[i] !== 0) { - throw new Error('Input dim is zero but required output dim is non-zero.'); - } - scales[i] = 1; - } else { - scales[i] = yDims[i] / xDims[i]; - } +const 
parseScalesDataFromOutputSize = ( + yDims: readonly number[], + xDims: readonly number[], + mode: string, + isResize: boolean, +): number[] => { + const length = xDims.length; + const scales = new Array(length); + + for (let i = 0, end = length; i < end; i++) { + if (xDims[i] === 0) { + if (yDims[i] !== 0) { + throw new Error('Input dim is zero but required output dim is non-zero.'); } - scalesValidation(scales, mode, isResize); - return scales; - }; + scales[i] = 1; + } else { + scales[i] = yDims[i] / xDims[i]; + } + } + scalesValidation(scales, mode, isResize); + return scales; +}; // roi data is not used yet. but leave here for future usage. // const getRoi = (inputs: Tensor[], attributes: UpsampleAttributes) : number[] => { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts index c2d703ed04fa0..24453d14f35ae 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; export const shape = (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputs(inputs); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts index 81fc1b7076fdb..f147a22cccc5f 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {NUMBER_TYPES, OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; export interface SliceAttributes extends AttributeWithCacheKey { readonly axes: number[]; @@ -18,68 +18,75 @@ export interface SliceAttributes extends AttributeWithCacheKey { const sliceProgramMetadata = { name: 'Slice', inputNames: ['A'], - inputTypes: [TextureType.unpacked] + inputTypes: [TextureType.unpacked], }; -export const slice: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SliceAttributes): Tensor[] => { - validateInputs(inputs); - const output = inferenceHandler.run( - { - ...sliceProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createSliceProgramInfo(inferenceHandler, inputs[0], attributes) - }, - inputs); - return [output]; - }; +export const slice: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: SliceAttributes, +): Tensor[] => { + validateInputs(inputs); + const output = inferenceHandler.run( + { + ...sliceProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createSliceProgramInfo(inferenceHandler, inputs[0], attributes), + }, + inputs, + ); + return [output]; +}; export const parseSliceAttributes: OperatorInitialization = (node: Graph.Node): SliceAttributes => { const starts = node.attributes.getInts('starts'); const ends = node.attributes.getInts('ends'); const axes = node.attributes.getInts('axes', []); - return createAttributeWithCacheKey({starts, ends, axes}); + return createAttributeWithCacheKey({ starts, ends, axes }); }; -const createSliceProgramInfo = - (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SliceAttributes): ProgramInfo => { - const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((_val, i) => i) : attributes.axes; - const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length); - const starts = attributes.starts.map((start, i) => { - if (start > input.dims[normalizedAxes[i]] - 1) { - return input.dims[normalizedAxes[i]]; - } - return ShapeUtil.normalizeAxis(start, input.dims[normalizedAxes[i]]); - }); - const ends = attributes.ends.map((end, i) => { - if (end > input.dims[normalizedAxes[i]] - 1) { - return input.dims[normalizedAxes[i]]; - } - return ShapeUtil.normalizeAxis(end, input.dims[normalizedAxes[i]]); - }); +const createSliceProgramInfo = ( + _inferenceHandler: WebGLInferenceHandler, + input: Tensor, + attributes: SliceAttributes, +): ProgramInfo => { + const axes = attributes.axes.length === 0 ? 
input.dims.slice(0).map((_val, i) => i) : attributes.axes; + const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length); + const starts = attributes.starts.map((start, i) => { + if (start > input.dims[normalizedAxes[i]] - 1) { + return input.dims[normalizedAxes[i]]; + } + return ShapeUtil.normalizeAxis(start, input.dims[normalizedAxes[i]]); + }); + const ends = attributes.ends.map((end, i) => { + if (end > input.dims[normalizedAxes[i]] - 1) { + return input.dims[normalizedAxes[i]]; + } + return ShapeUtil.normalizeAxis(end, input.dims[normalizedAxes[i]]); + }); - const outputShape = input.dims.slice(); + const outputShape = input.dims.slice(); - const sliceOps: string[] = []; - for (let i = 0; i < normalizedAxes.length; i++) { - outputShape[normalizedAxes[i]] = ends[i] - starts[i]; - if (starts[i] > 0) { - sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`); - } // else { sliceOps.push(`outputIdx[${normalizedAxes[i]}] += 0;`); } - } + const sliceOps: string[] = []; + for (let i = 0; i < normalizedAxes.length; i++) { + outputShape[normalizedAxes[i]] = ends[i] - starts[i]; + if (starts[i] > 0) { + sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`); + } // else { sliceOps.push(`outputIdx[${normalizedAxes[i]}] += 0;`); } + } - const rank = outputShape.length; - const shaderSource = ` + const rank = outputShape.length; + const shaderSource = ` float process(int outputIdx[${rank}]) { ${sliceOps.join('\n ')} return _A(outputIdx); }`; - return { - ...sliceProgramMetadata, - output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...sliceProgramMetadata, + output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { @@ -94,34 +101,39 @@ export const sliceV10 = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor validateInputsV10(inputs); const attributes = generateSliceAttributesFromInputs(inferenceHandler, inputs); const output = inferenceHandler.run( - { - ...sliceProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createSliceProgramInfo(inferenceHandler, inputs[0], attributes) - }, - [inputs[0]]); + { + ...sliceProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createSliceProgramInfo(inferenceHandler, inputs[0], attributes), + }, + [inputs[0]], + ); return [output]; }; -const generateSliceAttributesFromInputs = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): SliceAttributes => { - if (!inferenceHandler.session.isInitializer(inputs[1].dataId) || - !inferenceHandler.session.isInitializer(inputs[2].dataId) || - (inputs.length >= 4 && !inferenceHandler.session.isInitializer(inputs[3].dataId)) || - (inputs.length >= 5 && !inferenceHandler.session.isInitializer(inputs[4].dataId))) { - throw new Error('dynamic slice attributes are not allowed'); - } +const generateSliceAttributesFromInputs = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], +): SliceAttributes => { + if ( + !inferenceHandler.session.isInitializer(inputs[1].dataId) || + !inferenceHandler.session.isInitializer(inputs[2].dataId) || + (inputs.length >= 4 && !inferenceHandler.session.isInitializer(inputs[3].dataId)) || + (inputs.length >= 5 && !inferenceHandler.session.isInitializer(inputs[4].dataId)) + ) { + throw new Error('dynamic slice attributes are not allowed'); + } - if (inputs.length >= 5 && inputs[4].integerData.some((i: number) => 
i !== 1)) { - throw new Error('currently non-1 steps is not supported for Slice'); - } + if (inputs.length >= 5 && inputs[4].integerData.some((i: number) => i !== 1)) { + throw new Error('currently non-1 steps is not supported for Slice'); + } - const starts = Array.from(inputs[1].integerData); - const ends = Array.from(inputs[2].integerData); - const axes = inputs.length >= 4 ? Array.from(inputs[3].integerData) : []; - const cacheKey = `${axes};${starts};${ends}`; - return {starts, ends, axes, cacheKey}; - }; + const starts = Array.from(inputs[1].integerData); + const ends = Array.from(inputs[2].integerData); + const axes = inputs.length >= 4 ? Array.from(inputs[3].integerData) : []; + const cacheKey = `${axes};${starts};${ends}`; + return { starts, ends, axes, cacheKey }; +}; const validateInputsV10 = (inputs: Tensor[]): void => { if (!inputs || inputs.length < 3 || inputs.length > 5) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/softmax.ts b/js/web/lib/onnxjs/backends/webgl/ops/softmax.ts index 585fbf7bbf01b..67143c3ac0fa4 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/softmax.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/softmax.ts @@ -1,16 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; -import {transpose, TransposeAttributes} from './transpose'; +import { transpose, TransposeAttributes } from './transpose'; export interface SoftmaxAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -34,24 +34,29 @@ const softmaxProgramMetadata = { inputTypes: [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked], }; -export const softmax: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SoftmaxAttributes): Tensor[] => { - validateInputs(inputs); +export const softmax: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: SoftmaxAttributes, +): Tensor[] => { + validateInputs(inputs); - const inputShape = inputs[0].dims.slice(); - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); - const logicalRowCount = ShapeUtil.sizeToDimension(inputShape, axis); - const featureCount = ShapeUtil.sizeFromDimension(inputShape, axis); + const inputShape = inputs[0].dims.slice(); + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); + const logicalRowCount = ShapeUtil.sizeToDimension(inputShape, axis); + const featureCount = ShapeUtil.sizeFromDimension(inputShape, axis); - const output = computeSoftmax(inferenceHandler, inputs, attributes, logicalRowCount, 
featureCount); - return output; - }; + const output = computeSoftmax(inferenceHandler, inputs, attributes, logicalRowCount, featureCount); + return output; +}; -export const parseSoftmaxAttributes: OperatorInitialization = - (node: Graph.Node): SoftmaxAttributes => createAttributeWithCacheKey({axis: node.attributes.getInt('axis', 1)}); +export const parseSoftmaxAttributes: OperatorInitialization = ( + node: Graph.Node, +): SoftmaxAttributes => createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', 1) }); -export const parseSoftmaxAttributesV13: OperatorInitialization = - (node: Graph.Node): SoftmaxAttributes => createAttributeWithCacheKey({axis: node.attributes.getInt('axis', -1)}); +export const parseSoftmaxAttributesV13: OperatorInitialization = ( + node: Graph.Node, +): SoftmaxAttributes => createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', -1) }); // The "semantic" meaning of axis has changed in opset-13. // Please compare: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softmax @@ -59,98 +64,136 @@ export const parseSoftmaxAttributesV13: OperatorInitialization = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SoftmaxAttributes): Tensor[] => { - validateInputs(inputs); - - const inputShape = inputs[0].dims.slice(); - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); - const rank = inputShape.length; - - const isTransposeRequired = (axis !== rank - 1) ? true : false; - const transposedInputShape: number[] = []; - let perm: number[] = []; - let transposedInputs: Tensor[] = []; - let transposeAttribute: TransposeAttributes; - - if (isTransposeRequired) { - perm = Array.from({length: rank}).map((_, i) => i); - - // swap the innermost dim with the dim corresponding to axis - perm[axis] = rank - 1; - perm[rank - 1] = axis; - - perm.map(p => transposedInputShape.push(inputShape[p])); - - transposeAttribute = createAttributeWithCacheKey({perm}); - transposedInputs = transpose(inferenceHandler, inputs, transposeAttribute); - } - - const logicalRowCount = isTransposeRequired ? ShapeUtil.sizeToDimension(transposedInputShape, rank - 1) : - ShapeUtil.sizeToDimension(inputShape, rank - 1); - const featureCount = isTransposeRequired ? ShapeUtil.sizeFromDimension(transposedInputShape, rank - 1) : - ShapeUtil.sizeFromDimension(inputShape, rank - 1); - - const output = computeSoftmax( - inferenceHandler, isTransposeRequired ? 
transposedInputs : inputs, attributes, logicalRowCount, featureCount); - - if (isTransposeRequired) { - const reversedOutput = transpose(inferenceHandler, output, transposeAttribute!); - return reversedOutput; - } else { - return output; - } - }; - -const computeSoftmax = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SoftmaxAttributes, logicalRowCount: number, - featureCount: number): Tensor[] => { - const computeMaxProgramInfo = - createComputeMaxProgramInfo(inferenceHandler, inputs[0], logicalRowCount, featureCount, [logicalRowCount]); - const max = inferenceHandler.run( - {...softmaxComputeMaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeMaxProgramInfo}, - inputs); - - const computeScaleProgramInfo = createComputScaleProgramInfo( - inferenceHandler, inputs[0], logicalRowCount, featureCount, computeMaxProgramInfo.output.dims, - [logicalRowCount]); - const scale = inferenceHandler.run( - {...softmaxComputeScaleProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeScaleProgramInfo}, - [inputs[0], max]); - - const softMaxProgramInfo = createSoftMaxProgramInfo( - inferenceHandler, inputs[0], logicalRowCount, featureCount, computeMaxProgramInfo.output.dims, - computeScaleProgramInfo.output.dims); - const output = inferenceHandler.run( - {...softmaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => softMaxProgramInfo}, - [inputs[0], max, scale]); - return [output]; - }; +export const softmaxV13: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: SoftmaxAttributes, +): Tensor[] => { + validateInputs(inputs); + + const inputShape = inputs[0].dims.slice(); + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); + const rank = inputShape.length; + + const isTransposeRequired = axis !== rank - 1 ? true : false; + const transposedInputShape: number[] = []; + let perm: number[] = []; + let transposedInputs: Tensor[] = []; + let transposeAttribute: TransposeAttributes; + + if (isTransposeRequired) { + perm = Array.from({ length: rank }).map((_, i) => i); + + // swap the innermost dim with the dim corresponding to axis + perm[axis] = rank - 1; + perm[rank - 1] = axis; + + perm.map((p) => transposedInputShape.push(inputShape[p])); + + transposeAttribute = createAttributeWithCacheKey({ perm }); + transposedInputs = transpose(inferenceHandler, inputs, transposeAttribute); + } + + const logicalRowCount = isTransposeRequired + ? ShapeUtil.sizeToDimension(transposedInputShape, rank - 1) + : ShapeUtil.sizeToDimension(inputShape, rank - 1); + const featureCount = isTransposeRequired + ? ShapeUtil.sizeFromDimension(transposedInputShape, rank - 1) + : ShapeUtil.sizeFromDimension(inputShape, rank - 1); + + const output = computeSoftmax( + inferenceHandler, + isTransposeRequired ? 
transposedInputs : inputs, + attributes, + logicalRowCount, + featureCount, + ); + + if (isTransposeRequired) { + const reversedOutput = transpose(inferenceHandler, output, transposeAttribute!); + return reversedOutput; + } else { + return output; + } +}; + +const computeSoftmax = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: SoftmaxAttributes, + logicalRowCount: number, + featureCount: number, +): Tensor[] => { + const computeMaxProgramInfo = createComputeMaxProgramInfo( + inferenceHandler, + inputs[0], + logicalRowCount, + featureCount, + [logicalRowCount], + ); + const max = inferenceHandler.run( + { ...softmaxComputeMaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeMaxProgramInfo }, + inputs, + ); + + const computeScaleProgramInfo = createComputScaleProgramInfo( + inferenceHandler, + inputs[0], + logicalRowCount, + featureCount, + computeMaxProgramInfo.output.dims, + [logicalRowCount], + ); + const scale = inferenceHandler.run( + { ...softmaxComputeScaleProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeScaleProgramInfo }, + [inputs[0], max], + ); + + const softMaxProgramInfo = createSoftMaxProgramInfo( + inferenceHandler, + inputs[0], + logicalRowCount, + featureCount, + computeMaxProgramInfo.output.dims, + computeScaleProgramInfo.output.dims, + ); + const output = inferenceHandler.run( + { ...softmaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => softMaxProgramInfo }, + [inputs[0], max, scale], + ); + return [output]; +}; /** * Create a texture that contains the maximum value of each of the 'N' rows */ -const createComputeMaxProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number, - outputShape: number[]): ProgramInfo => { - const [textureWidth, textureHeight] = - inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked); - const rank = outputShape.length; - - if (logicalRowCount < 1 || featureCount < 1) { - throw new Error('Logical row count N and feature count D must be greater than or equal to 1'); - } - - if (outputShape.length !== 1) { - throw new Error('Dimensionality of the output should be 1'); - } - - if (outputShape[0] !== logicalRowCount) { - throw new Error('Shape of the output should be equal to logical row count'); - } - - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const shaderSource = ` +const createComputeMaxProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + input: Tensor, + logicalRowCount: number, + featureCount: number, + outputShape: number[], +): ProgramInfo => { + const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight( + input.dims, + TextureType.unpacked, + ); + const rank = outputShape.length; + + if (logicalRowCount < 1 || featureCount < 1) { + throw new Error('Logical row count N and feature count D must be greater than or equal to 1'); + } + + if (outputShape.length !== 1) { + throw new Error('Dimensionality of the output should be 1'); + } + + if (outputShape[0] !== logicalRowCount) { + throw new Error('Shape of the output should be equal to logical row count'); + } + + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const shaderSource = ` float process(int[${rank}] indices) { int logical_row_start_offset = indices[0] * ${featureCount}; @@ -166,45 +209,52 @@ const createComputeMaxProgramInfo = return max; }`; - return { - ...softmaxComputeMaxProgramMetadata, - output: {dims: outputShape, 
type: input.type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...softmaxComputeMaxProgramMetadata, + output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; /** * Create a texture that contains the normalization factor for each of the 'N' rows */ -const createComputScaleProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number, - maxElementPerLogicalRow: readonly number[], outputShape: number[]): ProgramInfo => { - const [textureWidth, textureHeight] = - inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked); - const rank = outputShape.length; - - if (logicalRowCount < 1 || featureCount < 1) { - throw new Error('Logical row count N and feature count D must be greater than or equal to 1'); - } - - if (outputShape.length !== 1) { - throw new Error('Dimensionality of the output should be 1'); - } - - if (outputShape[0] !== logicalRowCount) { - throw new Error('Shape of the output should be equal to logical row count'); - } - - if (maxElementPerLogicalRow.length !== 1) { - throw new Error('Dimensionality of the intermediate results should be 1'); - } - - if (maxElementPerLogicalRow[0] !== logicalRowCount) { - throw new Error('Shape of the intermediate results should be equal to logical row count'); - } - - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const shaderSource = ` +const createComputScaleProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + input: Tensor, + logicalRowCount: number, + featureCount: number, + maxElementPerLogicalRow: readonly number[], + outputShape: number[], +): ProgramInfo => { + const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight( + input.dims, + TextureType.unpacked, + ); + const rank = outputShape.length; + + if (logicalRowCount < 1 || featureCount < 1) { + throw new Error('Logical row count N and feature count D must be greater than or equal to 1'); + } + + if (outputShape.length !== 1) { + throw new Error('Dimensionality of the output should be 1'); + } + + if (outputShape[0] !== logicalRowCount) { + throw new Error('Shape of the output should be equal to logical row count'); + } + + if (maxElementPerLogicalRow.length !== 1) { + throw new Error('Dimensionality of the intermediate results should be 1'); + } + + if (maxElementPerLogicalRow[0] !== logicalRowCount) { + throw new Error('Shape of the intermediate results should be equal to logical row count'); + } + + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const shaderSource = ` float process(int[${rank}] indices) { int logical_row_start_offset = indices[0] * ${featureCount}; @@ -218,33 +268,40 @@ const createComputScaleProgramInfo = return norm_factor; }`; - return { - ...softmaxComputeScaleProgramMetadata, - output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked}, - shaderSource - }; - }; - -const createSoftMaxProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number, - maxElementPerLogicalRow: readonly number[], normalizationPerLogicalRow: readonly number[]): ProgramInfo => { - const [textureWidth, textureHeight] = - inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked); - const rank = input.dims.length; - - if (logicalRowCount < 1 || featureCount < 1) { - throw new Error('Logical row count N and feature count D must be greater 
than or equal to 1'); - } - - if (maxElementPerLogicalRow.length !== 1 || normalizationPerLogicalRow.length !== 1) { - throw new Error('Dimensionality of the intermediate results should be 1'); - } - - if (maxElementPerLogicalRow[0] !== logicalRowCount || normalizationPerLogicalRow[0] !== logicalRowCount) { - throw new Error('Shape of the intermediate results should be equal to logical row count'); - } - - const shaderSource = ` + return { + ...softmaxComputeScaleProgramMetadata, + output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; + +const createSoftMaxProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + input: Tensor, + logicalRowCount: number, + featureCount: number, + maxElementPerLogicalRow: readonly number[], + normalizationPerLogicalRow: readonly number[], +): ProgramInfo => { + const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight( + input.dims, + TextureType.unpacked, + ); + const rank = input.dims.length; + + if (logicalRowCount < 1 || featureCount < 1) { + throw new Error('Logical row count N and feature count D must be greater than or equal to 1'); + } + + if (maxElementPerLogicalRow.length !== 1 || normalizationPerLogicalRow.length !== 1) { + throw new Error('Dimensionality of the intermediate results should be 1'); + } + + if (maxElementPerLogicalRow[0] !== logicalRowCount || normalizationPerLogicalRow[0] !== logicalRowCount) { + throw new Error('Shape of the intermediate results should be equal to logical row count'); + } + + const shaderSource = ` float process(int[${rank}] indices) { // get offset of current logical tensor index from the 2-D texture coordinates (TexCoords) @@ -264,12 +321,12 @@ const createSoftMaxProgramInfo = return exp(_A(indices) - _Max(logical_row_index)) / norm_factor; }`; - return { - ...softmaxProgramMetadata, - output: {dims: input.dims, type: input.type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...softmaxProgramMetadata, + output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/split.ts b/js/web/lib/onnxjs/backends/webgl/ops/split.ts index 2ab14563d80e2..47cda68e1cbac 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/split.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/split.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil, SplitUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil, SplitUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; export interface SplitAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -21,68 +21,90 @@ const splitProgramMetadata = { inputTypes: [TextureType.unpacked], }; -export const split: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SplitAttributes): Tensor[] => { - validateInputs(inputs); +export const split: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: SplitAttributes, +): Tensor[] => { + validateInputs(inputs); - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length); - const count = getProgramCount(inferenceHandler, inputs, axis, attributes); - const output: Tensor[] = []; - for (let i = 0; i < count; ++i) { - output.push(inferenceHandler.run( - { - ...splitProgramMetadata, - cacheHint: `${attributes.cacheKey};${i}`, - get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, axis, i) - }, - inputs)); - } + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length); + const count = getProgramCount(inferenceHandler, inputs, axis, attributes); + const output: Tensor[] = []; + for (let i = 0; i < count; ++i) { + output.push( + inferenceHandler.run( + { + ...splitProgramMetadata, + cacheHint: `${attributes.cacheKey};${i}`, + get: () => createSplitProgramInfo(inferenceHandler, inputs[0], attributes, axis, i), + }, + inputs, + ), + ); + } - return output; - }; + return output; +}; export const parseSplitAttributes: OperatorInitialization = (node: Graph.Node): SplitAttributes => { const axis = node.attributes.getInt('axis', 0); const split = node.attributes.getInts('split', []); const numOutputs = node.outputs.length; - return createAttributeWithCacheKey({axis, split, numOutputs}); + return createAttributeWithCacheKey({ axis, split, numOutputs }); }; -const getProgramCount = - (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes): number => { - const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs); - return offsets.length; - }; +const getProgramCount = ( + _inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + axis: number, + attributes: SplitAttributes, +): number => { + const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs); + return offsets.length; +}; -const createSplitProgramInfo = - (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number): - ProgramInfo => { - const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, 
attributes.split, attributes.numOutputs); - const offset = offsets[index]; - const outputShape = shapes[index]; - const rank = outputShape.length; - const shaderSource = ` +const createSplitProgramInfo = ( + _inferenceHandler: WebGLInferenceHandler, + input: Tensor, + attributes: SplitAttributes, + axis: number, + index: number, +): ProgramInfo => { + const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs); + const offset = offsets[index]; + const outputShape = shapes[index]; + const rank = outputShape.length; + const shaderSource = ` float process(int indices[${rank}]) { indices[${axis}] += ${offset}; return _A(indices); } `; - return { - ...splitProgramMetadata, - cacheHint: `${attributes.cacheKey}:${index}`, - output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...splitProgramMetadata, + cacheHint: `${attributes.cacheKey}:${index}`, + output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Split requires one input.'); } - if (inputs[0].type !== 'int8' && inputs[0].type !== 'uint8' && inputs[0].type !== 'int16' && - inputs[0].type !== 'uint16' && inputs[0].type !== 'int32' && inputs[0].type !== 'uint32' && - inputs[0].type !== 'float32' && inputs[0].type !== 'float64' && inputs[0].type !== 'bool') { + if ( + inputs[0].type !== 'int8' && + inputs[0].type !== 'uint8' && + inputs[0].type !== 'int16' && + inputs[0].type !== 'uint16' && + inputs[0].type !== 'int32' && + inputs[0].type !== 'uint32' && + inputs[0].type !== 'float32' && + inputs[0].type !== 'float64' && + inputs[0].type !== 'bool' + ) { throw new Error('Invalid input type.'); } }; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/squeeze.ts b/js/web/lib/onnxjs/backends/webgl/ops/squeeze.ts index 73b143b1def62..21a1180c32158 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/squeeze.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/squeeze.ts @@ -1,19 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; - -export const squeeze: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axes: number[]): Tensor[] => { - validateInputs(inputs); - const outputShape = ShapeUtil.squeezeShape(inputs[0].dims, axes); - const output = inferenceHandler.reshapeUnpacked(inputs[0], outputShape); - return [output]; - }; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; + +export const squeeze: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + axes: number[], +): Tensor[] => { + validateInputs(inputs); + const outputShape = ShapeUtil.squeezeShape(inputs[0].dims, axes); + const output = inferenceHandler.reshapeUnpacked(inputs[0], outputShape); + return [output]; +}; export const squeezeV13 = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputsV13(inputs); @@ -21,7 +24,7 @@ export const squeezeV13 = (inferenceHandler: WebGLInferenceHandler, inputs: Tens }; export const parseSqueezeAttributes: OperatorInitialization = (node: Graph.Node): number[] => - node.attributes.getInts('axes'); + node.attributes.getInts('axes'); const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { @@ -41,4 +44,4 @@ const validateInputsV13 = (inputs: Tensor[]): void => { if (inputs[1].type !== 'int32') { throw new Error('Invalid input type.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts index 2c25b10c5872c..0ca009dcef368 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramMetadata, TextureType} from '../types'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramMetadata, TextureType } from '../types'; export const sum = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputs(inputs); @@ -12,32 +12,37 @@ export const sum = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): const sumProgramMetadata = { name: 'Sum', inputNames: inputs.map((_v, i) => `X${i}`), - inputTypes: new Array(inputs.length).fill(TextureType.unpacked) + inputTypes: new Array(inputs.length).fill(TextureType.unpacked), }; const output = inferenceHandler.run( - {...sumProgramMetadata, get: () => createSumProgramInfo(inferenceHandler, inputs, sumProgramMetadata)}, inputs); + { ...sumProgramMetadata, get: () => createSumProgramInfo(inferenceHandler, inputs, sumProgramMetadata) }, + inputs, + ); return [output]; }; -const createSumProgramInfo = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], sumProgramMetadata: ProgramMetadata): ProgramInfo => { - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const outputShape = inputs[0].dims.slice(); - const sumLine = inputs.map((_v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); - const shaderSource = ` +const createSumProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + sumProgramMetadata: ProgramMetadata, +): ProgramInfo => { + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const outputShape = inputs[0].dims.slice(); + const sumLine = inputs.map((_v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); + const shaderSource = ` void main() { vec4 result = ${sumLine}; ${glsl.output} = result; } `; - return { - ...sumProgramMetadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - hasMain: true, - shaderSource - }; - }; + return { + ...sumProgramMetadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + hasMain: true, + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length === 0) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts index 1d2cba7d9d75f..e91c6afe105bc 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {NUMBER_TYPES} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramMetadata, TextureType} from '../types'; +import { NUMBER_TYPES } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramMetadata, TextureType } from '../types'; export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputs(inputs); @@ -16,36 +16,40 @@ export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): }; const output = inferenceHandler.run( - {...tileProgramMetadata, get: () => createTileProgramInfo(inferenceHandler, inputs, tileProgramMetadata)}, - inputs); + { ...tileProgramMetadata, get: () => createTileProgramInfo(inferenceHandler, inputs, tileProgramMetadata) }, + inputs, + ); return [output]; }; -const createTileProgramInfo = - (_handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - const outputShape = new Array(inputShape.length); +const createTileProgramInfo = ( + _handler: WebGLInferenceHandler, + inputs: Tensor[], + tileProgramMetadata: ProgramMetadata, +): ProgramInfo => { + const inputShape = inputs[0].dims.slice(); + const outputShape = new Array(inputShape.length); - const tileOps: string[] = []; - for (let i = 0; i < inputShape.length; i++) { - outputShape[i] = inputShape[i] * inputs[1].numberData[i]; - tileOps.push(`inputIdx[${i}] = int(mod(float(outputIdx[${i}]), ${inputShape[i]}.));`); - } + const tileOps: string[] = []; + for (let i = 0; i < inputShape.length; i++) { + outputShape[i] = inputShape[i] * inputs[1].numberData[i]; + tileOps.push(`inputIdx[${i}] = int(mod(float(outputIdx[${i}]), ${inputShape[i]}.));`); + } - const rank = outputShape.length; - const shaderSource = ` + const rank = outputShape.length; + const shaderSource = ` float process(int outputIdx[${rank}]) { int inputIdx[${rank}]; ${tileOps.join('\n')} return _A(inputIdx); } `; - return { - ...tileProgramMetadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...tileProgramMetadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 2) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts index d3e7b3c0823be..6eceedca46f77 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -16,51 +16,59 @@ export interface TransposeAttributes extends AttributeWithCacheKey { const transposeProgramMetadata = { name: 'Transpose', inputNames: ['A'], - inputTypes: [TextureType.unpacked] + inputTypes: [TextureType.unpacked], }; -export const transpose: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: TransposeAttributes): Tensor[] => { - validateInputs(inputs); - const output = inferenceHandler.run( - { - ...transposeProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createTransposeProgramInfo(inferenceHandler, inputs[0], attributes.perm) - }, - inputs); - return [output]; - }; +export const transpose: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: TransposeAttributes, +): Tensor[] => { + validateInputs(inputs); + const output = inferenceHandler.run( + { + ...transposeProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createTransposeProgramInfo(inferenceHandler, inputs[0], attributes.perm), + }, + inputs, + ); + return [output]; +}; -export const parseTransposeAttributes: OperatorInitialization = - (node: Graph.Node): TransposeAttributes => createAttributeWithCacheKey({perm: node.attributes.getInts('perm', [])}); +export const parseTransposeAttributes: OperatorInitialization = ( + node: Graph.Node, +): TransposeAttributes => createAttributeWithCacheKey({ perm: node.attributes.getInts('perm', []) }); -const createTransposeProgramInfo = - (_inferenceHandler: WebGLInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { - const inputShape = input.dims; - perm = getAdjustedPerm(inputShape, perm); - const unpackedOutputShape = getOutputShape(inputShape, perm); - const rank = inputShape.length; - // A dims=[${inputs[0].dims.toString()}] - // out Dims=[${unpackedOutputShape.toString()}] - // based on perm=[${perm.toString()}] - const shaderSource = ` +const createTransposeProgramInfo = ( + _inferenceHandler: WebGLInferenceHandler, + input: Tensor, + perm: number[], +): ProgramInfo => { + const inputShape = input.dims; + perm = getAdjustedPerm(inputShape, perm); + const unpackedOutputShape = getOutputShape(inputShape, perm); + const rank = inputShape.length; + // A dims=[${inputs[0].dims.toString()}] + // out Dims=[${unpackedOutputShape.toString()}] + // based on perm=[${perm.toString()}] + const shaderSource = ` ${getPermFunctionBody('perm', perm, rank)} float process(int indices[${rank}]) { int a[${rank}]; perm(a, indices); return _A(a); }`; - return { - ...transposeProgramMetadata, - output: {dims: 
unpackedOutputShape, type: input.type, textureType: TextureType.unpacked}, - shaderSource - }; - }; + return { + ...transposeProgramMetadata, + output: { dims: unpackedOutputShape, type: input.type, textureType: TextureType.unpacked }, + shaderSource, + }; +}; const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] => { if (perm && perm.length !== inputShape.length) { - perm = [...(inputShape.keys())].reverse(); + perm = [...inputShape.keys()].reverse(); } return perm; }; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/uint8-encode.ts b/js/web/lib/onnxjs/backends/webgl/ops/uint8-encode.ts index 76811de7b88b7..dcd0c80c57e01 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/uint8-encode.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/uint8-encode.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {TextureData, TextureType} from '../types'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { TextureData, TextureType } from '../types'; export const encodeAsUint8 = (inferenceHandler: WebGLInferenceHandler, input: TextureData): TextureData => { const outputShape = input.shape; @@ -63,9 +63,9 @@ export const encodeAsUint8 = (inferenceHandler: WebGLInferenceHandler, input: Te name: 'Uint8Encode', inputTypes: [TextureType.unpacked], inputNames: ['X'], - output: {dims: outputShape, type: input.tensor.type, textureType: TextureType.downloadUint8AsFloat}, + output: { dims: outputShape, type: input.tensor.type, textureType: TextureType.downloadUint8AsFloat }, shaderSource, - hasMain: true + hasMain: true, }; return inferenceHandler.executeProgram(programInfo, [input.tensor]); }; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/unary-op.ts b/js/web/lib/onnxjs/backends/webgl/ops/unary-op.ts index d8bba35021e9f..77b7c027d3f63 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/unary-op.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/unary-op.ts @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {Tensor} from '../../../tensor'; -import {MAX_CLIP, MIN_CLIP} from '../../../util'; -import {FunctionType, GlslValueFunction} from '../glsl-definitions'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { Tensor } from '../../../tensor'; +import { MAX_CLIP, MIN_CLIP } from '../../../util'; +import { FunctionType, GlslValueFunction } from '../glsl-definitions'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types'; export function glslAbs(): GlslValueFunction { return glslBuiltinUnary('abs'); @@ -40,7 +40,7 @@ export function glslElu(alpha: number): GlslValueFunction { return vec4(${name}_(v.x), ${name}_(v.y), ${name}_(v.z), ${name}_(v.w)); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslExp(): GlslValueFunction { return glslBuiltinUnary('exp'); @@ -61,7 +61,7 @@ export function glslClip(min: number, max: number): GlslValueFunction { return clamp(v, min, max); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslIdentity(): GlslValueFunction { const name = 'indentity'; @@ -73,7 +73,7 @@ export function glslIdentity(): GlslValueFunction { return v; } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslLeakyRelu(alpha: number): GlslValueFunction { const name = 'leakyRelu'; @@ -87,7 +87,7 @@ export function glslLeakyRelu(alpha: number): GlslValueFunction { return vec4(${name}_(v.x), ${name}_(v.y), ${name}_(v.z), ${name}_(v.w)); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslLog(): GlslValueFunction { return glslBuiltinUnary('log'); @@ -102,7 +102,7 @@ export function glslNeg(): GlslValueFunction { return -v; } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslNot(): GlslValueFunction { const name = 'not'; @@ -120,7 +120,7 @@ export function glslNot(): GlslValueFunction { return bvec4(!v.x, !v.y, !v.z, !v.w); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslSin(): GlslValueFunction { return glslBuiltinUnary('sin'); @@ -135,7 +135,7 @@ export function glslRelu(): GlslValueFunction { return max( v, 0.0 ); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslSigmoid(): GlslValueFunction { const name = 'sigmoid'; @@ -147,7 +147,7 @@ export function glslSigmoid(): GlslValueFunction { return 1.0 / (1.0 + exp(-v)); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } export function glslSqrt(): GlslValueFunction { return glslBuiltinUnary('sqrt'); @@ -169,7 +169,7 @@ export 
function glslTanh(): GlslValueFunction { return (v - 1.) / (v + 1.); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } function glslBuiltinUnary(name: string): GlslValueFunction { const body = ` @@ -180,22 +180,25 @@ function glslBuiltinUnary(name: string): GlslValueFunction { return ${name}(v); } `; - return {body, name, type: FunctionType.ValueBased}; + return { body, name, type: FunctionType.ValueBased }; } ///// ///// ///// -const createElementwiseProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, input: Tensor, glslFunc: GlslValueFunction): - ProgramInfo => { - const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked; - const glsl = getGlsl(handler.session.backend.glContext.version); - return { - ...metadata, - output: {dims: input.dims, type: input.type, textureType}, - shaderSource: ` +const createElementwiseProgramInfo = ( + handler: WebGLInferenceHandler, + metadata: ProgramMetadata, + input: Tensor, + glslFunc: GlslValueFunction, +): ProgramInfo => { + const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked; + const glsl = getGlsl(handler.session.backend.glContext.version); + return { + ...metadata, + output: { dims: input.dims, type: input.type, textureType }, + shaderSource: ` ${glslFunc.body} void main() { vec4 v = ${glsl.texture2D}(A, TexCoords); @@ -203,43 +206,59 @@ const createElementwiseProgramInfo = ${glsl.output} = v; } `, - hasMain: true - }; - }; + hasMain: true, + }; +}; -const createElementwiseProgramInfoLoader = - (handler: WebGLInferenceHandler, input: Tensor, glslFunc: GlslValueFunction, cacheKey?: string): - ProgramInfoLoader => { - const textureType = handler.session.pack ? TextureType.packed : TextureType.unpacked; - const metadata = {name: glslFunc.name, inputTypes: [textureType], inputNames: ['A'], cacheHint: cacheKey}; - return {...metadata, get: () => createElementwiseProgramInfo(handler, metadata, input, glslFunc)}; - }; +const createElementwiseProgramInfoLoader = ( + handler: WebGLInferenceHandler, + input: Tensor, + glslFunc: GlslValueFunction, + cacheKey?: string, +): ProgramInfoLoader => { + const textureType = handler.session.pack ? 
TextureType.packed : TextureType.unpacked; + const metadata = { name: glslFunc.name, inputTypes: [textureType], inputNames: ['A'], cacheHint: cacheKey }; + return { ...metadata, get: () => createElementwiseProgramInfo(handler, metadata, input, glslFunc) }; +}; -export const abs = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAbs()), inputs)]; +export const abs = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAbs()), inputs), +]; -export const acos = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAcos()), inputs)]; +export const acos = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAcos()), inputs), +]; -export const asin = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAsin()), inputs)]; +export const asin = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAsin()), inputs), +]; -export const atan = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAtan()), inputs)]; +export const atan = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslAtan()), inputs), +]; export interface ClipAttributes extends AttributeWithCacheKey { readonly min: number; readonly max: number; } -export const clip = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ClipAttributes): Tensor[] => [handler.run( - createElementwiseProgramInfoLoader( - handler, inputs[0], glslClip(attributes.min, attributes.max), attributes.cacheKey), - inputs)]; - -export const parseClipAttributes = (node: Graph.Node): ClipAttributes => createAttributeWithCacheKey( - {min: node.attributes.getFloat('min', MIN_CLIP), max: node.attributes.getFloat('max', MAX_CLIP)}); +export const clip = (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ClipAttributes): Tensor[] => [ + handler.run( + createElementwiseProgramInfoLoader( + handler, + inputs[0], + glslClip(attributes.min, attributes.max), + attributes.cacheKey, + ), + inputs, + ), +]; + +export const parseClipAttributes = (node: Graph.Node): ClipAttributes => + createAttributeWithCacheKey({ + min: node.attributes.getFloat('min', MIN_CLIP), + max: node.attributes.getFloat('max', MAX_CLIP), + }); export const clipV11 = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { const attributes = generateClipAttributesFromInputs(handler, inputs); @@ -247,78 +266,102 @@ export const clipV11 = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tenso }; const generateClipAttributesFromInputs = (handler: WebGLInferenceHandler, inputs: Tensor[]): ClipAttributes => { - if (inputs.length >= 3 && - (!handler.session.isInitializer(inputs[1].dataId) || !handler.session.isInitializer(inputs[2].dataId))) { + if ( + inputs.length >= 3 && + (!handler.session.isInitializer(inputs[1].dataId) || !handler.session.isInitializer(inputs[2].dataId)) + ) { throw new Error('dynamic clip attributes are not allowed'); } - const min = (inputs.length >= 3) ? 
inputs[1].numberData[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].numberData[0] : MAX_CLIP; - return createAttributeWithCacheKey({min, max}); + const min = inputs.length >= 3 ? inputs[1].numberData[0] : MIN_CLIP; + const max = inputs.length >= 3 ? inputs[2].numberData[0] : MAX_CLIP; + return createAttributeWithCacheKey({ min, max }); }; -export const ceil = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCeil()), inputs)]; +export const ceil = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCeil()), inputs), +]; -export const cos = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCos()), inputs)]; +export const cos = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCos()), inputs), +]; export interface EluAttributes extends AttributeWithCacheKey { readonly alpha: number; } -export const elu = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: EluAttributes): Tensor[] => [handler.run( - createElementwiseProgramInfoLoader(handler, inputs[0], glslElu(attributes.alpha), attributes.cacheKey), - inputs)]; +export const elu = (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: EluAttributes): Tensor[] => [ + handler.run( + createElementwiseProgramInfoLoader(handler, inputs[0], glslElu(attributes.alpha), attributes.cacheKey), + inputs, + ), +]; export const parseEluAttributes = (node: Graph.Node): EluAttributes => - createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 1.0)}); + createAttributeWithCacheKey({ alpha: node.attributes.getFloat('alpha', 1.0) }); -export const exp = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslExp()), inputs)]; +export const exp = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslExp()), inputs), +]; -export const floor = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslFloor()), inputs)]; +export const floor = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslFloor()), inputs), +]; -export const identity = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslIdentity()), inputs)]; +export const identity = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslIdentity()), inputs), +]; export interface LeakyReluAttributes extends AttributeWithCacheKey { readonly alpha: number; } -export const leakyRelu = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: LeakyReluAttributes): Tensor[] => [handler.run( - createElementwiseProgramInfoLoader(handler, inputs[0], glslLeakyRelu(attributes.alpha), attributes.cacheKey), - inputs)]; +export const leakyRelu = ( + handler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: LeakyReluAttributes, +): Tensor[] => [ + handler.run( + createElementwiseProgramInfoLoader(handler, inputs[0], 
glslLeakyRelu(attributes.alpha), attributes.cacheKey), + inputs, + ), +]; export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes => - createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)}); + createAttributeWithCacheKey({ alpha: node.attributes.getFloat('alpha', 0.01) }); -export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslLog()), inputs)]; +export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslLog()), inputs), +]; -export const neg = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNeg()), inputs)]; +export const neg = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNeg()), inputs), +]; -export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNot()), inputs)]; +export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNot()), inputs), +]; -export const relu = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslRelu()), inputs)]; +export const relu = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslRelu()), inputs), +]; -export const sigmoid = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSigmoid()), inputs)]; +export const sigmoid = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSigmoid()), inputs), +]; -export const sin = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSin()), inputs)]; +export const sin = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSin()), inputs), +]; -export const sqrt = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSqrt()), inputs)]; +export const sqrt = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSqrt()), inputs), +]; -export const tan = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTan()), inputs)]; +export const tan = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTan()), inputs), +]; -export const tanh = (handler: WebGLInferenceHandler, inputs: Tensor[]): - Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTanh()), inputs)]; +export const tanh = (handler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => [ + handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTanh()), inputs), +]; 
diff --git a/js/web/lib/onnxjs/backends/webgl/ops/unpack.ts b/js/web/lib/onnxjs/backends/webgl/ops/unpack.ts index db8b496bc260b..ffb5ff648df54 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/unpack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/unpack.ts @@ -1,18 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, ProgramInfoLoader, TextureType} from '../types'; -import {getCoordsDataType} from '../utils'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types'; +import { getCoordsDataType } from '../utils'; -import {getChannels, unpackFromChannel} from './packing-utils'; +import { getChannels, unpackFromChannel } from './packing-utils'; const unpackProgramMetadata = { name: 'unpack', inputNames: ['A'], - inputTypes: [TextureType.packed] + inputTypes: [TextureType.packed], }; export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => { @@ -22,7 +22,7 @@ export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: T const innerDims = channels.slice(-2); const coordsDataType = getCoordsDataType(rank); const unpackChannel = unpackFromChannel(); - const isScalar = (input.dims.length === 0); + const isScalar = input.dims.length === 0; const sourceCoords = isScalar ? '' : getSourceCoords(rank, channels); const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`; const glsl = getGlsl(handler.session.backend.glContext.version); @@ -41,13 +41,15 @@ export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: T return { ...unpackProgramMetadata, hasMain: true, - output: {dims: input.dims, type: input.type, textureType: TextureType.unpacked}, - shaderSource + output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked }, + shaderSource, }; }; -export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => - ({...unpackProgramMetadata, get: () => createUnpackProgramInfo(handler, input)}); +export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({ + ...unpackProgramMetadata, + get: () => createUnpackProgramInfo(handler, input), +}); function getSourceCoords(rank: number, dims: string[]): string { if (rank === 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/unsqueeze.ts b/js/web/lib/onnxjs/backends/webgl/ops/unsqueeze.ts index fcbba01de9831..5b6b22ace768e 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/unsqueeze.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/unsqueeze.ts @@ -1,19 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {ShapeUtil} from '../../../util'; -import {WebGLInferenceHandler} from '../inference-handler'; - -export const unsqueeze: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axes: number[]): Tensor[] => { - validateInputs(inputs); - const outputShape = ShapeUtil.unsqueezeShape(inputs[0].dims, axes); - const output = inferenceHandler.reshapeUnpacked(inputs[0], outputShape); - return [output]; - }; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { ShapeUtil } from '../../../util'; +import { WebGLInferenceHandler } from '../inference-handler'; + +export const unsqueeze: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + axes: number[], +): Tensor[] => { + validateInputs(inputs); + const outputShape = ShapeUtil.unsqueezeShape(inputs[0].dims, axes); + const output = inferenceHandler.reshapeUnpacked(inputs[0], outputShape); + return [output]; +}; export const unsqueezeV13 = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputsV13(inputs); @@ -21,7 +24,7 @@ export const unsqueezeV13 = (inferenceHandler: WebGLInferenceHandler, inputs: Te }; export const parseUnsqueezeAttributes: OperatorInitialization = (node: Graph.Node): number[] => - node.attributes.getInts('axes'); + node.attributes.getInts('axes'); const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/upsample.ts b/js/web/lib/onnxjs/backends/webgl/ops/upsample.ts index d7bb1393d2b2a..3dde0a48695be 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/upsample.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/upsample.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; -import {Graph} from '../../../graph'; -import {OperatorImplementation, OperatorInitialization} from '../../../operators'; -import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; -import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, TextureType} from '../types'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key'; +import { Graph } from '../../../graph'; +import { OperatorImplementation, OperatorInitialization } from '../../../operators'; +import { Tensor } from '../../../tensor'; +import { getGlsl } from '../glsl-source'; +import { WebGLInferenceHandler } from '../inference-handler'; +import { ProgramInfo, TextureType } from '../types'; export interface UpsampleAttributes extends AttributeWithCacheKey { readonly opset: number; @@ -33,27 +33,33 @@ const upsampleProgramMetadata = { inputTypes: [TextureType.unpacked], }; -export const upsample: OperatorImplementation = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: UpsampleAttributes): Tensor[] => { - validateInputs(inputs, attributes); - const output = inferenceHandler.run( - { - ...upsampleProgramMetadata, - cacheHint: attributes.cacheKey, - get: () => createUpsampleProgramInfo(inferenceHandler, inputs, attributes) - }, - inputs); - return [output]; - }; - -export const parseUpsampleAttributesV7: OperatorInitialization = - (node: Graph.Node): UpsampleAttributes => parseUpsampleAttributes(node, 7); - -export const parseUpsampleAttributesV9: OperatorInitialization = - (node: Graph.Node): UpsampleAttributes => parseUpsampleAttributes(node, 9); +export const upsample: OperatorImplementation = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: UpsampleAttributes, +): Tensor[] => { + validateInputs(inputs, attributes); + const output = inferenceHandler.run( + { + ...upsampleProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createUpsampleProgramInfo(inferenceHandler, inputs, attributes), + }, + inputs, + ); + return [output]; +}; + +export const parseUpsampleAttributesV7: OperatorInitialization = ( + node: Graph.Node, +): UpsampleAttributes => parseUpsampleAttributes(node, 7); + +export const parseUpsampleAttributesV9: OperatorInitialization = ( + node: Graph.Node, +): UpsampleAttributes => parseUpsampleAttributes(node, 9); export const parseUpsampleAttributes = (node: Graph.Node, opset: number): UpsampleAttributes => { - const isResize = (opset >= 10); + const isResize = opset >= 10; // processing node attributes const mode = node.attributes.getString('mode', 'nearest'); @@ -70,17 +76,24 @@ export const parseUpsampleAttributes = (node: Graph.Node, opset: number): Upsamp const extrapolationValue = node.attributes.getFloat('extrapolation_value', 0.0); const coordinateTransformMode = - opset > 10 ? node.attributes.getString('coordinate_transformation_mode', 'half_pixel') : 'asymmetric'; - if ([ - 'asymmetric', 'pytorch_half_pixel', 'tf_half_pixel_for_nn', 'align_corners', 'tf_crop_and_resize', 'half_pixel' - ].indexOf(coordinateTransformMode) === -1) { + opset > 10 ? 
node.attributes.getString('coordinate_transformation_mode', 'half_pixel') : 'asymmetric'; + if ( + [ + 'asymmetric', + 'pytorch_half_pixel', + 'tf_half_pixel_for_nn', + 'align_corners', + 'tf_crop_and_resize', + 'half_pixel', + ].indexOf(coordinateTransformMode) === -1 + ) { throw new Error(`coordinate_transform_mode '${coordinateTransformMode}' is not supported`); } - const needRoiInput = (coordinateTransformMode === 'tf_crop_and_resize'); + const needRoiInput = coordinateTransformMode === 'tf_crop_and_resize'; const useExtrapolation = needRoiInput; const nearestMode = - (mode === 'nearest' && opset >= 11) ? node.attributes.getString('nearest_mode', 'round_prefer_floor') : ''; + mode === 'nearest' && opset >= 11 ? node.attributes.getString('nearest_mode', 'round_prefer_floor') : ''; if (['round_prefer_floor', 'round_prefer_ceil', 'floor', 'ceil', ''].indexOf(nearestMode) === -1) { throw new Error(`nearest_mode '${nearestMode}' is not supported`); } @@ -92,7 +105,7 @@ export const parseUpsampleAttributes = (node: Graph.Node, opset: number): Upsamp } const useNearest2xOptimization = - (opset < 11) ? true : (mode === 'nearest' && coordinateTransformMode === 'asymmetric' && nearestMode === 'floor'); + opset < 11 ? true : mode === 'nearest' && coordinateTransformMode === 'asymmetric' && nearestMode === 'floor'; let roiInputIdx = 0; let scalesInputIdx = 0; @@ -127,37 +140,44 @@ export const parseUpsampleAttributes = (node: Graph.Node, opset: number): Upsamp useNearest2xOptimization, roiInputIdx, scalesInputIdx, - sizesInputIdx + sizesInputIdx, }); }; -const createUpsampleProgramInfo = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: UpsampleAttributes): ProgramInfo => { - const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); - const [inputWidth, inputHeight] = - inferenceHandler.calculateTextureWidthAndHeight(inputs[0].dims, TextureType.unpacked); - - const outputShape = inputs[0].dims.map((dim, i) => Math.floor(dim * attributes.scales[i])); - const [outputWidth, outputHeight] = - inferenceHandler.calculateTextureWidthAndHeight(outputShape, TextureType.unpacked); - const dim = outputShape.length; - - const outputPitches = new Array(dim); - const inputPitches = new Array(dim); - let precalculatedPitches = ` +const createUpsampleProgramInfo = ( + inferenceHandler: WebGLInferenceHandler, + inputs: Tensor[], + attributes: UpsampleAttributes, +): ProgramInfo => { + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); + const [inputWidth, inputHeight] = inferenceHandler.calculateTextureWidthAndHeight( + inputs[0].dims, + TextureType.unpacked, + ); + + const outputShape = inputs[0].dims.map((dim, i) => Math.floor(dim * attributes.scales[i])); + const [outputWidth, outputHeight] = inferenceHandler.calculateTextureWidthAndHeight( + outputShape, + TextureType.unpacked, + ); + const dim = outputShape.length; + + const outputPitches = new Array(dim); + const inputPitches = new Array(dim); + let precalculatedPitches = ` int output_pitches[${dim}]; int input_pitches[${dim}]; `; - for (let d = dim - 1; d >= 0; d--) { - outputPitches[d] = (d === dim - 1) ? 1 : outputPitches[d + 1] * outputShape[d + 1]; - inputPitches[d] = (d === dim - 1) ? 1 : inputPitches[d + 1] * inputs[0].dims[d + 1]; + for (let d = dim - 1; d >= 0; d--) { + outputPitches[d] = d === dim - 1 ? 1 : outputPitches[d + 1] * outputShape[d + 1]; + inputPitches[d] = d === dim - 1 ? 
1 : inputPitches[d + 1] * inputs[0].dims[d + 1]; - precalculatedPitches += ` + precalculatedPitches += ` output_pitches[${d}] = ${outputPitches[d]}; input_pitches[${d}] = ${inputPitches[d]}; `; - } - const getInputFloatFunction = ` + } + const getInputFloatFunction = ` float getInputFloat(int index) { vec2 coords = offsetToCoords(index, ${inputWidth}, ${inputHeight}); float value = getColorAsFloat(${glsl.texture2D}(X, coords)); @@ -165,9 +185,10 @@ const createUpsampleProgramInfo = } `; - const shaderSource = attributes.mode === 'nearest' ? - // nearest - ` + const shaderSource = + attributes.mode === 'nearest' + ? // nearest + ` ${getInputFloatFunction} float process(int indices[${dim}]) { int input_index = 0; @@ -190,10 +211,10 @@ const createUpsampleProgramInfo = } return getInputFloat(input_index); - }` : - dim === 4 ? - // bilinear 4D - ` + }` + : dim === 4 + ? // bilinear 4D + ` ${getInputFloatFunction} float process(int indices[4]) { int input_index = 0; @@ -247,9 +268,9 @@ const createUpsampleProgramInfo = float y0 = x00 + float(y_offset) * (x01 - x00) / float(scales[2]); float y1 = x10 + float(y_offset) * (x11 - x10) / float(scales[2]); return y0 + float(x_offset) * (y1 - y0) / float(scales[3]); - }` : - // bilinear 2D - ` + }` + : // bilinear 2D + ` ${getInputFloatFunction} float process(int indices[2]) { int input_index = 0; @@ -297,23 +318,28 @@ const createUpsampleProgramInfo = float y1 = x10 + float(y_offset) * (x11 - x10) / float(scales[0]); return y0 + float(x_offset) * (y1 - y0) / float(scales[1]); }`; - return { - ...upsampleProgramMetadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - shaderSource, - variables: [{ - name: 'scales', - type: 'int', - arrayLength: attributes.scales.length, - data: attributes.scales.map(x => Math.ceil(x)) - }] - }; - }; + return { + ...upsampleProgramMetadata, + output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked }, + shaderSource, + variables: [ + { + name: 'scales', + type: 'int', + arrayLength: attributes.scales.length, + data: attributes.scales.map((x) => Math.ceil(x)), + }, + ], + }; +}; export const validateInputs = (inputs: Tensor[], attribute: UpsampleAttributes): void => { - if (!inputs || (attribute.opset < 9 && inputs.length !== 1) || - (attribute.opset >= 9 && attribute.opset < 11 && inputs.length !== 2) || - (attribute.opset >= 11 && inputs.length < 2)) { + if ( + !inputs || + (attribute.opset < 9 && inputs.length !== 1) || + (attribute.opset >= 9 && attribute.opset < 11 && inputs.length !== 2) || + (attribute.opset >= 11 && inputs.length < 2) + ) { throw new Error('invalid inputs.'); } @@ -347,4 +373,4 @@ export const scalesValidation = (scales: number[], mode: string, isResize: boole in the ${isResize ? 'Resize' : 'Upsample'} opeartor.`); } } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/program-manager.ts b/js/web/lib/onnxjs/backends/webgl/program-manager.ts index d2d678fbb19b8..92edbefc3d487 100644 --- a/js/web/lib/onnxjs/backends/webgl/program-manager.ts +++ b/js/web/lib/onnxjs/backends/webgl/program-manager.ts @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {env} from 'onnxruntime-common'; +import { env } from 'onnxruntime-common'; -import {Logger, Profiler} from '../../instrument'; +import { Logger, Profiler } from '../../instrument'; -import {GlslPreprocessor} from './glsl-preprocessor'; -import {getVertexShaderSource} from './glsl-source'; -import {TextureLayoutStrategy} from './texture-layout-strategy'; -import {Artifact, ProgramInfo, ProgramVariable, TextureData, TextureLayout, VariableInfo} from './types'; -import {WebGLContext} from './webgl-context'; +import { GlslPreprocessor } from './glsl-preprocessor'; +import { getVertexShaderSource } from './glsl-source'; +import { TextureLayoutStrategy } from './texture-layout-strategy'; +import { Artifact, ProgramInfo, ProgramVariable, TextureData, TextureLayout, VariableInfo } from './types'; +import { WebGLContext } from './webgl-context'; /** * ProgramManager is the main class behind running computations @@ -21,47 +21,54 @@ import {WebGLContext} from './webgl-context'; * corresponding Location's in the binary program */ export class ProgramManager { - repo: Map; // this should be per-session object + repo: Map; // this should be per-session object vertexShader: WebGLShader; attributesBound: boolean; constructor( - public profiler: Readonly, public glContext: WebGLContext, - public textureLayoutStrategy: TextureLayoutStrategy) { + public profiler: Readonly, + public glContext: WebGLContext, + public textureLayoutStrategy: TextureLayoutStrategy, + ) { this.repo = new Map(); this.attributesBound = false; } - getArtifact(key: unknown): Artifact|undefined { + getArtifact(key: unknown): Artifact | undefined { return this.repo.get(key); } setArtifact(key: unknown, artifact: Artifact): void { this.repo.set(key, artifact); } run(buildArtifact: Artifact, inputs: TextureData[], output: TextureData): void { - this.profiler.event('op', `ProgramManager.run ${buildArtifact.programInfo.name ?? 'unknown kernel'}`, () => { - const gl = this.glContext.gl; - const program = buildArtifact.program; - gl.useProgram(program); - try { - this.bindOutput(output); - if (!this.attributesBound) { - this.bindAttributes(buildArtifact.attribLocations); + this.profiler.event( + 'op', + `ProgramManager.run ${buildArtifact.programInfo.name ?? 'unknown kernel'}`, + () => { + const gl = this.glContext.gl; + const program = buildArtifact.program; + gl.useProgram(program); + try { + this.bindOutput(output); + if (!this.attributesBound) { + this.bindAttributes(buildArtifact.attribLocations); + } + this.bindUniforms(buildArtifact.uniformLocations, buildArtifact.programInfo.variables ?? [], inputs); + } catch (err) { + Logger.error('ProgramManager', buildArtifact.programInfo.shaderSource); + throw err; } - this.bindUniforms(buildArtifact.uniformLocations, buildArtifact.programInfo.variables ?? 
[], inputs); - } catch (err) { - Logger.error('ProgramManager', buildArtifact.programInfo.shaderSource); - throw err; - } - this.profiler.event('backend', 'GlContext.draw()', () => { - this.glContext.draw(); - }); - }, this.glContext); + this.profiler.event('backend', 'GlContext.draw()', () => { + this.glContext.draw(); + }); + }, + this.glContext, + ); } dispose(): void { if (this.vertexShader) { this.glContext.deleteShader(this.vertexShader); } - this.repo.forEach(a => this.glContext.deleteProgram(a.program)); + this.repo.forEach((a) => this.glContext.deleteProgram(a.program)); } build(programInfo: ProgramInfo, inputTextureLayouts: TextureLayout[], outputTextureLayout: TextureLayout): Artifact { return this.profiler.event('backend', 'ProgramManager.build', () => { @@ -72,8 +79,11 @@ export class ProgramManager { programInfo, program, uniformLocations: this.getUniformLocations( - program, preprocessor.context.programInfo.inputNames, preprocessor.context.programInfo.variables), - attribLocations: this.getAttribLocations(program) + program, + preprocessor.context.programInfo.inputNames, + preprocessor.context.programInfo.variables, + ), + attribLocations: this.getAttribLocations(program), }; return artifact; }); @@ -85,9 +95,12 @@ export class ProgramManager { this.vertexShader = this.glContext.compileShader(vertexShaderScript, this.glContext.gl.VERTEX_SHADER); } if (env.debug) { - Logger.verbose('ProrgramManager', `FragShader: + Logger.verbose( + 'ProrgramManager', + `FragShader: ${fragShaderScript} -`); +`, + ); } const fragShader = this.glContext.compileShader(fragShaderScript, this.glContext.gl.FRAGMENT_SHADER); const program = this.glContext.createProgram(this.vertexShader, fragShader); @@ -98,8 +111,9 @@ ${fragShaderScript} const width = td.width; const height = td.height; Logger.verbose( - 'ProrgramManager', - `Binding output texture to Framebuffer: w/h=${width}/${height}, shape=${td.shape}, type=${td.tensor.type}`); + 'ProrgramManager', + `Binding output texture to Framebuffer: w/h=${width}/${height}, shape=${td.shape}, type=${td.tensor.type}`, + ); this.glContext.attachFramebuffer(td.texture, width, height); } bindAttributes(attribLocations: Artifact.AttribLocations): void { @@ -108,12 +122,15 @@ ${fragShaderScript} this.glContext.setVertexAttributes(positionHandle, textureCoordHandle); this.attributesBound = true; } - bindUniforms(uniformLocations: Artifact.UniformLocations, variables: ProgramVariable[], textures: TextureData[]): - void { + bindUniforms( + uniformLocations: Artifact.UniformLocations, + variables: ProgramVariable[], + textures: TextureData[], + ): void { const gl = this.glContext.gl; let texturePosition = 0; - for (const {name, type, location, arrayLength} of uniformLocations) { - const value = variables.find(v => v.name === name)?.data; + for (const { name, type, location, arrayLength } of uniformLocations) { + const value = variables.find((v) => v.name === name)?.data; if (type !== 'sampler2D' && !value) { throw new Error(`variable '${name}' does not have data defined in program info`); } @@ -147,20 +164,27 @@ ${fragShaderScript} getAttribLocations(program: WebGLProgram): Artifact.AttribLocations { return { position: this.getAttribLocation(program, 'position'), - textureCoord: this.getAttribLocation(program, 'textureCoord') + textureCoord: this.getAttribLocation(program, 'textureCoord'), }; } - getUniformLocations(program: WebGLProgram, samplers?: string[], variables?: VariableInfo[]): - Artifact.UniformLocations { + getUniformLocations( + program: WebGLProgram, 
+ samplers?: string[], + variables?: VariableInfo[], + ): Artifact.UniformLocations { const uniformLocations: Artifact.UniformLocations = []; if (samplers) { for (const sampler of samplers) { - uniformLocations.push({name: sampler, type: 'sampler2D', location: this.getUniformLocation(program, sampler)}); + uniformLocations.push({ + name: sampler, + type: 'sampler2D', + location: this.getUniformLocation(program, sampler), + }); } } if (variables) { for (const variable of variables) { - uniformLocations.push({...variable, location: this.getUniformLocation(program, variable.name)}); + uniformLocations.push({ ...variable, location: this.getUniformLocation(program, variable.name) }); } } return uniformLocations; diff --git a/js/web/lib/onnxjs/backends/webgl/session-handler.ts b/js/web/lib/onnxjs/backends/webgl/session-handler.ts index d1b8763cec7ef..936518db99e40 100644 --- a/js/web/lib/onnxjs/backends/webgl/session-handler.ts +++ b/js/web/lib/onnxjs/backends/webgl/session-handler.ts @@ -1,21 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {SessionHandler} from '../../backend'; -import {Graph} from '../../graph'; -import {Logger} from '../../instrument'; -import {Operator} from '../../operators'; -import {OpSet, resolveOperator} from '../../opset'; -import {Session} from '../../session'; -import {Tensor} from '../../tensor'; -import {WebGLBackend} from '../backend-webgl'; +import { SessionHandler } from '../../backend'; +import { Graph } from '../../graph'; +import { Logger } from '../../instrument'; +import { Operator } from '../../operators'; +import { OpSet, resolveOperator } from '../../opset'; +import { Session } from '../../session'; +import { Tensor } from '../../tensor'; +import { WebGLBackend } from '../backend-webgl'; -import {WebGLInferenceHandler} from './inference-handler'; -import {WEBGL_OP_RESOLVE_RULES} from './op-resolve-rules'; -import {ProgramManager} from './program-manager'; -import {PreferLogicalStrategy, TextureLayoutStrategy} from './texture-layout-strategy'; -import {TextureManager} from './texture-manager'; -import {TextureData} from './types'; +import { WebGLInferenceHandler } from './inference-handler'; +import { WEBGL_OP_RESOLVE_RULES } from './op-resolve-rules'; +import { ProgramManager } from './program-manager'; +import { PreferLogicalStrategy, TextureLayoutStrategy } from './texture-layout-strategy'; +import { TextureManager } from './texture-manager'; +import { TextureData } from './types'; export class WebGLSessionHandler implements SessionHandler { programManager: ProgramManager; @@ -28,12 +28,15 @@ export class WebGLSessionHandler implements SessionHandler { initializers: Set; pack?: boolean; - constructor(public readonly backend: WebGLBackend, public readonly context: Session.Context) { + constructor( + public readonly backend: WebGLBackend, + public readonly context: Session.Context, + ) { this.layoutStrategy = new PreferLogicalStrategy(backend.glContext.maxTextureSize); this.programManager = new ProgramManager(this.context.profiler, backend.glContext, this.layoutStrategy); - this.textureManager = new TextureManager( - backend.glContext, this.layoutStrategy, this.context.profiler, - {reuseTextures: backend.textureCacheMode === 'full'}); + this.textureManager = new TextureManager(backend.glContext, this.layoutStrategy, this.context.profiler, { + reuseTextures: backend.textureCacheMode === 'full', + }); this.packedTextureDataCache = new Map(); this.unpackedTextureDataCache = new Map(); 
this.pack = backend.pack; @@ -45,7 +48,10 @@ export class WebGLSessionHandler implements SessionHandler { return new WebGLInferenceHandler(this); } onGraphInitialized(graph: Graph): void { - const initializers = graph.getValues().filter(v => v.from === -1 && v.tensor).map(v => v.tensor!.dataId); + const initializers = graph + .getValues() + .filter((v) => v.from === -1 && v.tensor) + .map((v) => v.tensor!.dataId); this.initializers = new Set(initializers); } isInitializer(tensorId: Tensor.Id): boolean { @@ -54,7 +60,7 @@ export class WebGLSessionHandler implements SessionHandler { addInitializer(tensorId: Tensor.Id): void { this.initializers.add(tensorId); } - getTextureData(tensorId: Tensor.Id, isPacked: boolean): TextureData|undefined { + getTextureData(tensorId: Tensor.Id, isPacked: boolean): TextureData | undefined { if (isPacked) { return this.packedTextureDataCache.get(tensorId); } else { @@ -72,13 +78,13 @@ export class WebGLSessionHandler implements SessionHandler { dispose(): void { this.programManager.dispose(); this.textureManager.clearActiveTextures(); - this.packedTextureDataCache.forEach(td => this.textureManager.releaseTexture(td, true)); + this.packedTextureDataCache.forEach((td) => this.textureManager.releaseTexture(td, true)); this.packedTextureDataCache = new Map(); - this.unpackedTextureDataCache.forEach(td => this.textureManager.releaseTexture(td, true)); + this.unpackedTextureDataCache.forEach((td) => this.textureManager.releaseTexture(td, true)); this.unpackedTextureDataCache = new Map(); } resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator { const op = resolveOperator(node, opsets, WEBGL_OP_RESOLVE_RULES); - return {impl: op.opImpl, context: op.opInit ? op.opInit(node, graph) : node}; + return { impl: op.opImpl, context: op.opInit ? op.opInit(node, graph) : node }; } } diff --git a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts index 4b0cf3f037921..51b73a7023d28 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Logger} from '../../instrument'; +import { Logger } from '../../instrument'; export declare namespace Encoder { export interface DataTypeMap { @@ -70,7 +70,7 @@ export class RedFloat32DataEncoder implements DataEncoder { Logger.warning('Encoder', 'Source data too small. 
Allocating larger array'); source = src as Float32Array; result = this.allocate(textureSize * this.channelSize) as Float32Array; - source.forEach((v, i) => result[i] = v); + source.forEach((v, i) => (result[i] = v)); } else { source = src as Float32Array; result = source; @@ -110,7 +110,7 @@ export class RGBAFloatDataEncoder implements DataEncoder { if (this.channelSize === 1) { Logger.verbose('Encoder', 'Exploding into a larger array'); dest = this.allocate(textureSize) as Float32Array; - src.forEach((v, i) => dest[i * 4] = v); + src.forEach((v, i) => (dest[i * 4] = v)); } return dest; } @@ -134,7 +134,7 @@ export class Uint8DataEncoder implements DataEncoder { constructor(gl: WebGLRenderingContext, channels = 1) { if (channels === 1) { this.internalFormat = gl.ALPHA; - this.format = gl.ALPHA; // not tested + this.format = gl.ALPHA; // not tested this.textureType = gl.UNSIGNED_BYTE; this.channelSize = channels; } else if (channels === 4) { diff --git a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts index f8e370747928c..b05a130e521d0 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Logger} from '../../instrument'; -import {assert} from '../../util'; +import { Logger } from '../../instrument'; +import { assert } from '../../util'; /** Layout preferences */ export interface WidthHeightPrefs { @@ -37,8 +37,9 @@ export class AlwaysKeepOriginalSizeStrategy implements TextureLayoutStrategy { // ignore preferences // continue with default layout Logger.verbose( - 'TextureLayout', - `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`); + 'TextureLayout', + `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`, + ); } else { return [wsize, hsize]; } @@ -89,8 +90,9 @@ export class PreferLogicalStrategy implements TextureLayoutStrategy { // ignore preferences // continue with default layout Logger.verbose( - 'TextureLayout', - `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`); + 'TextureLayout', + `Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`, + ); } else { return [wsize, hsize]; } @@ -104,8 +106,9 @@ export class PreferLogicalStrategy implements TextureLayoutStrategy { // they are from adjacent pairs of rows/cols within the same batch. So if a // tensor has 3 rows, we pretend it has 4 rows in order to account for the // fact that the texels containing the third row are half empty. - logShape = logShape.map( - (_d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]); + logShape = logShape.map((_d, i) => + i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i], + ); // Packed texture height is at least 2 (the channel height of a single // texel). 
@@ -130,12 +133,16 @@ export class PreferLogicalStrategy implements TextureLayoutStrategy { } else if (logShape.length === 3 && logShape[0] <= maxTextureSize && logShape[1] * logShape[2] <= maxTextureSize) { return [logShape[0], logShape[1] * logShape[2]]; } else if ( - logShape.length === 4 && logShape[0] * logShape[1] * logShape[2] <= maxTextureSize && - logShape[3] <= maxTextureSize) { + logShape.length === 4 && + logShape[0] * logShape[1] * logShape[2] <= maxTextureSize && + logShape[3] <= maxTextureSize + ) { return [logShape[0] * logShape[1] * logShape[2], logShape[3]]; } else if ( - logShape.length === 4 && logShape[0] <= maxTextureSize && - logShape[1] * logShape[2] * logShape[3] <= maxTextureSize) { + logShape.length === 4 && + logShape[0] <= maxTextureSize && + logShape[1] * logShape[2] * logShape[3] <= maxTextureSize + ) { return [logShape[0], logShape[1] * logShape[2] * logShape[3]]; } else { if (isPacked) { @@ -144,18 +151,18 @@ export class PreferLogicalStrategy implements TextureLayoutStrategy { // inner dimensions stay even, we rewrite size to equal the number of // texels. Then in the return statement we rehydrate the squarified // dimensions to channel units. - return sizeToSquarishShape(size / 4).map(d => d * 2) as [number, number]; + return sizeToSquarishShape(size / 4).map((d) => d * 2) as [number, number]; } return sizeToSquarishShape(size); } } } -export function squeezeShape(shape: number[], axis?: number[]): {newShape: number[]; keptDims: number[]} { +export function squeezeShape(shape: number[], axis?: number[]): { newShape: number[]; keptDims: number[] } { const newShape: number[] = []; const keptDims: number[] = []; const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; - const axes = (axis == null || isEmptyArray) ? null : parseAxisParam(axis, shape).sort(); + const axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort(); let j = 0; for (let i = 0; i < shape.length; ++i) { if (axes != null) { @@ -175,10 +182,10 @@ export function squeezeShape(shape: number[], axis?: number[]): {newShape: numbe keptDims.push(i); } } - return {newShape, keptDims}; + return { newShape, keptDims }; } -export function parseAxisParam(axis: number|number[], shape: number[]): number[] { +export function parseAxisParam(axis: number | number[], shape: number[]): number[] { const rank = shape.length; // Normalize input @@ -186,18 +193,15 @@ export function parseAxisParam(axis: number|number[], shape: number[]): number[] // Check for valid range assert( - axis.every(ax => ax >= -rank && ax < rank), - () => `All values in axis param must be in range [-${rank}, ${rank}) but ` + - `got axis ${axis}`); + axis.every((ax) => ax >= -rank && ax < rank), + () => `All values in axis param must be in range [-${rank}, ${rank}) but ` + `got axis ${axis}`, + ); // Check for only integers - assert( - axis.every(isInt), - () => 'All values in axis param must be integers but ' + - `got axis ${axis}`); + assert(axis.every(isInt), () => 'All values in axis param must be integers but ' + `got axis ${axis}`); // Handle negative axis. - return axis.map(a => a < 0 ? rank + a : a); + return axis.map((a) => (a < 0 ? 
rank + a : a)); } export function isInt(a: number): boolean { return a % 1 === 0; diff --git a/js/web/lib/onnxjs/backends/webgl/texture-layout.ts b/js/web/lib/onnxjs/backends/webgl/texture-layout.ts index 17ed44ec64fac..7b4068aed5d2c 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-layout.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-layout.ts @@ -1,70 +1,82 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {ShapeUtil} from '../../util'; +import { ShapeUtil } from '../../util'; -import {TextureLayoutStrategy, WidthHeightPrefs} from './texture-layout-strategy'; -import {TextureLayout, TextureType} from './types'; +import { TextureLayoutStrategy, WidthHeightPrefs } from './texture-layout-strategy'; +import { TextureLayout, TextureType } from './types'; -export const createTextureLayoutFromTextureType = - (textureLayoutStrategy: TextureLayoutStrategy, shape: readonly number[], - textureType: TextureType): TextureLayout => { - const channel = (textureType === TextureType.unpacked || textureType === TextureType.unpackedReversed) ? 1 : 4; - const isPacked = textureType === TextureType.packed; - const reverseWH = (textureType === TextureType.unpackedReversed || textureType === TextureType.packed); - const breakAxis = textureType === TextureType.packedLastDimension ? shape.length - 1 : undefined; - const unpackedShape = textureType === TextureType.packedLastDimension ? - shape.map((d, i) => i === shape.length - 1 ? d * 4 : d) : - undefined; - return createTextureLayoutFromShape( - textureLayoutStrategy, shape, channel, unpackedShape, {isPacked, reverseWH, breakAxis}); - }; +export const createTextureLayoutFromTextureType = ( + textureLayoutStrategy: TextureLayoutStrategy, + shape: readonly number[], + textureType: TextureType, +): TextureLayout => { + const channel = textureType === TextureType.unpacked || textureType === TextureType.unpackedReversed ? 1 : 4; + const isPacked = textureType === TextureType.packed; + const reverseWH = textureType === TextureType.unpackedReversed || textureType === TextureType.packed; + const breakAxis = textureType === TextureType.packedLastDimension ? shape.length - 1 : undefined; + const unpackedShape = + textureType === TextureType.packedLastDimension + ? shape.map((d, i) => (i === shape.length - 1 ? d * 4 : d)) + : undefined; + return createTextureLayoutFromShape(textureLayoutStrategy, shape, channel, unpackedShape, { + isPacked, + reverseWH, + breakAxis, + }); +}; -export const calculateTextureWidthAndHeight = - (textureLayoutStrategy: TextureLayoutStrategy, shape: readonly number[], textureType: TextureType): - [number, number] => { - const layout = createTextureLayoutFromTextureType(textureLayoutStrategy, shape, textureType); - return [layout.width, layout.height]; - }; +export const calculateTextureWidthAndHeight = ( + textureLayoutStrategy: TextureLayoutStrategy, + shape: readonly number[], + textureType: TextureType, +): [number, number] => { + const layout = createTextureLayoutFromTextureType(textureLayoutStrategy, shape, textureType); + return [layout.width, layout.height]; +}; /** * Create a TextureLayout object from shape. */ -export const createTextureLayoutFromShape = - (textureLayoutStrategy: TextureLayoutStrategy, shape: readonly number[], channels: 1|4 = 1, - unpackedShape?: readonly number[], prefs?: WidthHeightPrefs): TextureLayout => { - const isPacked = !!(prefs && prefs.isPacked); - const [width, height] = textureLayoutStrategy.computeTextureWH(isPacked ? 
unpackedShape || shape : shape, prefs); - const rank = shape.length; - let inferredDims = shape.slice(0); - if (rank === 0) { - inferredDims = [1]; - } - if (channels === 1) { - // unpackedShape will take `shape` and not `inferredDims` so as to create a scalar Tensor if need be - unpackedShape = shape; - } else if (isPacked) { - if (channels !== 4) { - throw new Error('a packed texture must be 4-channel'); - } - unpackedShape = shape; - if (rank > 0) { - inferredDims[rank - 1] = Math.ceil(inferredDims[rank - 1] / 2); - } - if (rank > 1) { - inferredDims[rank - 2] = Math.ceil(inferredDims[rank - 2] / 2); - } - } else if (!unpackedShape) { - throw new Error('Unpacked shape is needed when using channels > 1'); - } - return { - width, - height, - channels, - isPacked, - shape: inferredDims, - strides: ShapeUtil.computeStrides(inferredDims), - unpackedShape, - reversedWH: (prefs && prefs.reverseWH) - }; - }; +export const createTextureLayoutFromShape = ( + textureLayoutStrategy: TextureLayoutStrategy, + shape: readonly number[], + channels: 1 | 4 = 1, + unpackedShape?: readonly number[], + prefs?: WidthHeightPrefs, +): TextureLayout => { + const isPacked = !!(prefs && prefs.isPacked); + const [width, height] = textureLayoutStrategy.computeTextureWH(isPacked ? unpackedShape || shape : shape, prefs); + const rank = shape.length; + let inferredDims = shape.slice(0); + if (rank === 0) { + inferredDims = [1]; + } + if (channels === 1) { + // unpackedShape will take `shape` and not `inferredDims` so as to create a scalar Tensor if need be + unpackedShape = shape; + } else if (isPacked) { + if (channels !== 4) { + throw new Error('a packed texture must be 4-channel'); + } + unpackedShape = shape; + if (rank > 0) { + inferredDims[rank - 1] = Math.ceil(inferredDims[rank - 1] / 2); + } + if (rank > 1) { + inferredDims[rank - 2] = Math.ceil(inferredDims[rank - 2] / 2); + } + } else if (!unpackedShape) { + throw new Error('Unpacked shape is needed when using channels > 1'); + } + return { + width, + height, + channels, + isPacked, + shape: inferredDims, + strides: ShapeUtil.computeStrides(inferredDims), + unpackedShape, + reversedWH: prefs && prefs.reverseWH, + }; +}; diff --git a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts index effb65288dc1c..3aad95b33e3e4 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Logger, Profiler} from '../../instrument'; -import {Tensor} from '../../tensor'; +import { Logger, Profiler } from '../../instrument'; +import { Tensor } from '../../tensor'; -import {Encoder, EncoderUsage} from './texture-data-encoder'; -import {TextureLayoutStrategy} from './texture-layout-strategy'; -import {TextureData, TextureLayout} from './types'; -import {WebGLContext} from './webgl-context'; +import { Encoder, EncoderUsage } from './texture-data-encoder'; +import { TextureLayoutStrategy } from './texture-layout-strategy'; +import { TextureData, TextureLayout } from './types'; +import { WebGLContext } from './webgl-context'; export interface TextureManagerConfig { reuseTextures?: boolean; @@ -30,8 +30,11 @@ export class TextureManager { private readonly pendingRead: Map void>> = new Map(); constructor( - public glContext: WebGLContext, public layoutStrategy: TextureLayoutStrategy, public profiler: Readonly, - private config: TextureManagerConfig) { + public glContext: WebGLContext, + public layoutStrategy: TextureLayoutStrategy, + public profiler: Readonly, + private config: TextureManagerConfig, + ) { if (config.reuseTextures) { this.inUseTextures = new Map(); this.idleTextures = new Map(); @@ -39,7 +42,11 @@ export class TextureManager { } } createTextureFromLayout( - dataType: Tensor.DataType, layout: TextureLayout, data?: Tensor.NumberType, usage?: EncoderUsage) { + dataType: Tensor.DataType, + layout: TextureLayout, + data?: Tensor.NumberType, + usage?: EncoderUsage, + ) { const textureDataType = this.toEncoderType(dataType); const encoder = this.glContext.getEncoder(textureDataType, layout.channels || 1, usage); @@ -49,8 +56,8 @@ export class TextureManager { const width = layout.width; const height = layout.height; - let key: string|undefined; - let inUseTextures: WebGLTexture[]|undefined; + let key: string | undefined; + let inUseTextures: WebGLTexture[] | undefined; if (this.config.reuseTextures) { key = `${width}x${height}_${encoder.format}_${encoder.internalFormat}_${encoder.textureType}`; inUseTextures = this.inUseTextures.get(key); @@ -86,7 +93,13 @@ export class TextureManager { return this.profiler.event('backend', 'TextureManager.readTexture', () => { const dataSize = td.shape.reduce((a, b) => a * b) * channels!; const data = this.glContext.readTexture( - td.texture, td.width, td.height, dataSize, this.toEncoderType(dataType), channels!); + td.texture, + td.width, + td.height, + dataSize, + this.toEncoderType(dataType), + channels!, + ); return this.toTensorData(dataType, data); }); } @@ -97,7 +110,7 @@ export class TextureManager { } if (this.pendingRead.has(dataId)) { const subscribers = this.pendingRead.get(dataId); - return new Promise(resolve => subscribers?.push(resolve)); + return new Promise((resolve) => subscribers?.push(resolve)); } return this.profiler.event('backend', 'TextureManager.readTextureAsync', async () => { this.pendingRead.set(dataId, []); @@ -105,11 +118,17 @@ export class TextureManager { // add a fence waiting for the data to be ready await this.glContext.createAndWaitForFence(); const data = this.glContext.readTexture( - td.texture, td.width, td.height, dataSize, this.toEncoderType(dataType), channels!); + td.texture, + td.width, + td.height, + dataSize, + this.toEncoderType(dataType), + channels!, + ); const tensorData = this.toTensorData(dataType, data); const subscribers = this.pendingRead.get(dataId); this.pendingRead.delete(dataId); - subscribers?.forEach(resolve => resolve(tensorData)); + subscribers?.forEach((resolve) => 
resolve(tensorData)); return tensorData; }); } @@ -121,7 +140,7 @@ export class TextureManager { }); } releaseTexture(textureData: TextureData, deleteTexture?: boolean): void { - let key: string|undefined; + let key: string | undefined; if (this.config.reuseTextures) { key = this.textureLookup.get(textureData.texture); if (key) { @@ -172,11 +191,11 @@ export class TextureManager { throw new Error(`TensorData type ${dataType} is not supported`); } } - toTextureData(_dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined { + toTextureData(_dataType: Tensor.DataType, data: Tensor.NumberType | undefined): Encoder.DataArrayType | undefined { if (!data) { return undefined; } - return (data instanceof Float32Array) ? data : new Float32Array(data); + return data instanceof Float32Array ? data : new Float32Array(data); /* switch (dataType) { case 'int16': diff --git a/js/web/lib/onnxjs/backends/webgl/types.ts b/js/web/lib/onnxjs/backends/webgl/types.ts index 03124fd0b67bd..ed38090a0f820 100644 --- a/js/web/lib/onnxjs/backends/webgl/types.ts +++ b/js/web/lib/onnxjs/backends/webgl/types.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../tensor'; +import { Tensor } from '../../tensor'; /** * Layout info is used for mapping n-dimensional array to 2D textures @@ -14,7 +14,7 @@ export interface TextureLayout { /** * specify the number of value that encoded in a single pixel */ - channels: 1|2|3|4; + channels: 1 | 2 | 3 | 4; /** * whether in packed mode or not */ @@ -40,11 +40,11 @@ export interface TextureData extends TextureLayout { } export enum TextureType { - unpacked, // <-- normal unpacked texture - unpackedReversed, // <-- unpacked texture used in old ONNX.js implementation (deprecated) - packed, // <-- normal packed texture - downloadUint8AsFloat, // <-- ONLY used in texture downloading for iOS devices - packedLastDimension // <-- ONLY used in old ONNX.js Conv implementation for input W (deprecated) + unpacked, // <-- normal unpacked texture + unpackedReversed, // <-- unpacked texture used in old ONNX.js implementation (deprecated) + packed, // <-- normal packed texture + downloadUint8AsFloat, // <-- ONLY used in texture downloading for iOS devices + packedLastDimension, // <-- ONLY used in old ONNX.js Conv implementation for input W (deprecated) } export interface TensorInfo { @@ -55,10 +55,10 @@ export interface TensorInfo { } export interface ProgramVariable { - type: 'float'|'int'; + type: 'float' | 'int'; name: string; arrayLength?: number; - data: number|number[]; + data: number | number[]; } /** @@ -116,23 +116,23 @@ export interface ProgramInfo extends ProgramMetadata { } export interface VariableInfo { - type: 'float'|'int'; + type: 'float' | 'int'; name: string; arrayLength?: number; } export interface ProgramVariable { - type: 'float'|'int'; + type: 'float' | 'int'; name: string; arrayLength?: number; - data: number|number[]; + data: number | number[]; } /** * Information of uniforms that shader uses */ export interface UniformInfo { - type: 'sampler2D'|VariableInfo['type']; + type: 'sampler2D' | VariableInfo['type']; name: string; arrayLength?: number; } @@ -150,7 +150,7 @@ export interface Artifact { programInfo: ProgramInfo; program: WebGLProgram; uniformLocations: UniformLocation[]; - attribLocations: {position: number; textureCoord: number}; + attribLocations: { position: number; textureCoord: number }; } export declare namespace Artifact { type 
UniformLocations = Artifact['uniformLocations']; @@ -158,5 +158,5 @@ export declare namespace Artifact { } export interface UniformData { - [name: string]: number|number[]; + [name: string]: number | number[]; } diff --git a/js/web/lib/onnxjs/backends/webgl/utils.ts b/js/web/lib/onnxjs/backends/webgl/utils.ts index 1f2f2def50c7d..d2286cdd9e826 100644 --- a/js/web/lib/onnxjs/backends/webgl/utils.ts +++ b/js/web/lib/onnxjs/backends/webgl/utils.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {assert} from '../../util'; +import { assert } from '../../util'; /** * Given a non RGBA shape calculate the R version * It is assumed that the dimensions are multiples of given channels @@ -14,7 +14,10 @@ export function getPackedShape(unpackedShape: readonly number[]): readonly numbe } export async function repeatedTry( - checkFn: () => boolean, delayFn = (_counter: number) => 0, maxCounter?: number): Promise<void> { + checkFn: () => boolean, + delayFn = (_counter: number) => 0, + maxCounter?: number, +): Promise<void> { return new Promise((resolve, reject) => { let tryCount = 0; @@ -67,7 +70,7 @@ export function squeezeInputShape(inputShape: readonly number[], squeezedShape: /** Returns a list of squeezed parameters for shader functions */ export function getSqueezedParams(params: string[], keptDims: number[]): string { - return keptDims.map(d => params[d]).join(', '); + return keptDims.map((d) => params[d]).join(', '); } /** Returns the data type for different ranks. */ diff --git a/js/web/lib/onnxjs/backends/webgl/webgl-context-factory.ts b/js/web/lib/onnxjs/backends/webgl/webgl-context-factory.ts index 6bf12500ec8b5..bbf05a7b75a28 100644 --- a/js/web/lib/onnxjs/backends/webgl/webgl-context-factory.ts +++ b/js/web/lib/onnxjs/backends/webgl/webgl-context-factory.ts @@ -1,19 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Logger} from '../../instrument'; +import { Logger } from '../../instrument'; -import {WebGLContext} from './webgl-context'; +import { WebGLContext } from './webgl-context'; -const cache: {[contextId: string]: WebGLContext} = {}; +const cache: { [contextId: string]: WebGLContext } = {}; /** * This factory function creates proper WebGLRenderingContext based on * the current browsers capabilities * The order is from higher/most recent versions to most basic */ -export function createWebGLContext(contextId?: 'webgl'|'webgl2'): WebGLContext { - let context: WebGLContext|undefined; +export function createWebGLContext(contextId?: 'webgl' | 'webgl2'): WebGLContext { + let context: WebGLContext | undefined; if ((!contextId || contextId === 'webgl2') && 'webgl2' in cache) { context = cache.webgl2; } else if ((!contextId || contextId === 'webgl') && 'webgl' in cache) { @@ -55,7 +55,7 @@ export function createWebGLContext(contextId?: 'webgl'|'webgl2'): WebGLContext { return context; } -export function createNewWebGLContext(canvas: HTMLCanvasElement, contextId?: 'webgl'|'webgl2'): WebGLContext { +export function createNewWebGLContext(canvas: HTMLCanvasElement, contextId?: 'webgl' | 'webgl2'): WebGLContext { const contextAttributes: WebGLContextAttributes = { alpha: false, depth: false, @@ -63,9 +63,9 @@ export function createNewWebGLContext(canvas: HTMLCanvasElement, contextId?: 'we stencil: false, preserveDrawingBuffer: false, premultipliedAlpha: false, - failIfMajorPerformanceCaveat: false + failIfMajorPerformanceCaveat: false, }; - let gl: WebGLRenderingContext|null; + let gl: WebGLRenderingContext | null; const ca = contextAttributes; if (!contextId || contextId === 'webgl2') { gl = canvas.getContext('webgl2', ca); @@ -78,14 +78,15 @@ export function createNewWebGLContext(canvas: HTMLCanvasElement, contextId?: 'we } } if (!contextId || contextId === 'webgl') { - gl = canvas.getContext('webgl', ca) || canvas.getContext('experimental-webgl', ca) as WebGLRenderingContext; + gl = canvas.getContext('webgl', ca) || (canvas.getContext('experimental-webgl', ca) as WebGLRenderingContext); if (gl) { try { return new WebGLContext(gl, 1); } catch (err) { Logger.warning( - 'GlContextFactory', - `failed to create WebGLContext using contextId 'webgl' or 'experimental-webgl'. Error: ${err}`); + 'GlContextFactory', + `failed to create WebGLContext using contextId 'webgl' or 'experimental-webgl'. Error: ${err}`, + ); } } } @@ -94,7 +95,7 @@ export function createNewWebGLContext(canvas: HTMLCanvasElement, contextId?: 'we } // eslint-disable-next-line @typescript-eslint/naming-convention -declare let OffscreenCanvas: {new (width: number, height: number): HTMLCanvasElement}; +declare let OffscreenCanvas: { new (width: number, height: number): HTMLCanvasElement }; function createCanvas(): HTMLCanvasElement { if (typeof document === 'undefined') { diff --git a/js/web/lib/onnxjs/backends/webgl/webgl-context.ts b/js/web/lib/onnxjs/backends/webgl/webgl-context.ts index 744f206e38334..19684dec81b3d 100644 --- a/js/web/lib/onnxjs/backends/webgl/webgl-context.ts +++ b/js/web/lib/onnxjs/backends/webgl/webgl-context.ts @@ -1,19 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {env} from 'onnxruntime-common'; +import { env } from 'onnxruntime-common'; import * as DataEncoders from './texture-data-encoder'; -import {DataEncoder, Encoder, EncoderUsage} from './texture-data-encoder'; -import {repeatedTry} from './utils'; +import { DataEncoder, Encoder, EncoderUsage } from './texture-data-encoder'; +import { repeatedTry } from './utils'; export interface FenceContext { - query: WebGLSync|null; + query: WebGLSync | null; isFencePassed(): boolean; } type PollItem = { - isDoneFn: () => boolean; resolveFn: () => void; + isDoneFn: () => boolean; + resolveFn: () => void; }; export function linearSearchLastTrue(arr: Array<() => boolean>): number { @@ -32,7 +33,7 @@ export function linearSearchLastTrue(arr: Array<() => boolean>): number { */ export class WebGLContext { gl: WebGLRenderingContext; - version: 1|2; + version: 1 | 2; private vertexbuffer: WebGLBuffer; private framebuffer: WebGLFramebuffer; @@ -58,19 +59,19 @@ export class WebGLContext { // WebGL extensions // eslint-disable-next-line camelcase - textureFloatExtension: OES_texture_float|null; + textureFloatExtension: OES_texture_float | null; // eslint-disable-next-line camelcase - textureHalfFloatExtension: OES_texture_half_float|null; + textureHalfFloatExtension: OES_texture_half_float | null; // WebGL2 extensions - colorBufferFloatExtension: unknown|null; + colorBufferFloatExtension: unknown | null; // eslint-disable-next-line @typescript-eslint/naming-convention - disjointTimerQueryWebgl2Extension: {TIME_ELAPSED_EXT: GLenum; GPU_DISJOINT_EXT: GLenum}|null; + disjointTimerQueryWebgl2Extension: { TIME_ELAPSED_EXT: GLenum; GPU_DISJOINT_EXT: GLenum } | null; private disposed: boolean; private frameBufferBound = false; - constructor(gl: WebGLRenderingContext, version: 1|2) { + constructor(gl: WebGLRenderingContext, version: 1 | 2) { this.gl = gl; this.version = version; @@ -92,25 +93,40 @@ export class WebGLContext { gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); const buffer = data ? encoder.encode(data, width * height) : null; gl.texImage2D( - gl.TEXTURE_2D, - 0, // Level of detail. - encoder.internalFormat, width, height, - 0, // Always 0 in OpenGL ES. - encoder.format, encoder.textureType, buffer); + gl.TEXTURE_2D, + 0, // Level of detail. + encoder.internalFormat, + width, + height, + 0, // Always 0 in OpenGL ES. + encoder.format, + encoder.textureType, + buffer, + ); this.checkError(); return texture as WebGLTexture; } updateTexture( - texture: WebGLTexture, width: number, height: number, encoder: DataEncoder, data: Encoder.DataArrayType): void { + texture: WebGLTexture, + width: number, + height: number, + encoder: DataEncoder, + data: Encoder.DataArrayType, + ): void { const gl = this.gl; gl.bindTexture(gl.TEXTURE_2D, texture); const buffer = encoder.encode(data, width * height); gl.texSubImage2D( - gl.TEXTURE_2D, - 0, // level - 0, // xoffset - 0, // yoffset - width, height, encoder.format, encoder.textureType, buffer); + gl.TEXTURE_2D, + 0, // level + 0, // xoffset + 0, // yoffset + width, + height, + encoder.format, + encoder.textureType, + buffer, + ); this.checkError(); } attachFramebuffer(texture: WebGLTexture, width: number, height: number): void { @@ -118,16 +134,19 @@ export class WebGLContext { // Make it the target for framebuffer operations - including rendering. 
gl.bindTexture(gl.TEXTURE_2D, texture); gl.bindFramebuffer(gl.FRAMEBUFFER, this.framebuffer); - gl.framebufferTexture2D( - gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, - 0); // 0, we aren't using MIPMAPs + gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); // 0, we aren't using MIPMAPs this.checkError(); gl.viewport(0, 0, width, height); gl.scissor(0, 0, width, height); } readTexture( - texture: WebGLTexture, width: number, height: number, dataSize: number, dataType: Encoder.DataType, - channels: number): Encoder.DataArrayType { + texture: WebGLTexture, + width: number, + height: number, + dataSize: number, + dataType: Encoder.DataType, + channels: number, + ): Encoder.DataArrayType { const gl = this.gl; if (!channels) { channels = 1; @@ -139,9 +158,7 @@ export class WebGLContext { const buffer = encoder.allocate(width * height); // bind texture to framebuffer gl.bindTexture(gl.TEXTURE_2D, texture); - gl.framebufferTexture2D( - gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, - 0); // 0, we aren't using MIPMAPs + gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); // 0, we aren't using MIPMAPs // TODO: Check if framebuffer is ready gl.readPixels(0, 0, width, height, gl.RGBA, encoder.textureType, buffer); this.checkError(); @@ -156,7 +173,7 @@ export class WebGLContext { getActiveTexture(): string { const gl = this.gl; const n = gl.getParameter(this.gl.ACTIVE_TEXTURE); - return `TEXTURE${(n - gl.TEXTURE0)}`; + return `TEXTURE${n - gl.TEXTURE0}`; } getTextureBinding(): WebGLTexture { return this.gl.getParameter(this.gl.TEXTURE_BINDING_2D); @@ -174,10 +191,7 @@ export class WebGLContext { } this.checkError(); } - createProgram( - vertexShader: WebGLShader, - fragShader: WebGLShader, - ): WebGLProgram { + createProgram(vertexShader: WebGLShader, fragShader: WebGLShader): WebGLProgram { const gl = this.gl; const program = gl.createProgram()!; @@ -225,24 +239,24 @@ ${shaderSource}`); const error = gl.getError(); let label = ''; switch (error) { - case (gl.NO_ERROR): + case gl.NO_ERROR: return; - case (gl.INVALID_ENUM): + case gl.INVALID_ENUM: label = 'INVALID_ENUM'; break; - case (gl.INVALID_VALUE): + case gl.INVALID_VALUE: label = 'INVALID_VALUE'; break; - case (gl.INVALID_OPERATION): + case gl.INVALID_OPERATION: label = 'INVALID_OPERATION'; break; - case (gl.INVALID_FRAMEBUFFER_OPERATION): + case gl.INVALID_FRAMEBUFFER_OPERATION: label = 'INVALID_FRAMEBUFFER_OPERATION'; break; - case (gl.OUT_OF_MEMORY): + case gl.OUT_OF_MEMORY: label = 'OUT_OF_MEMORY'; break; - case (gl.CONTEXT_LOST_WEBGL): + case gl.CONTEXT_LOST_WEBGL: label = 'CONTEXT_LOST_WEBGL'; break; default: @@ -268,7 +282,10 @@ ${shaderSource}`); return new DataEncoders.RGBAFloatDataEncoder(this.gl, channels); } else { return new DataEncoders.RGBAFloatDataEncoder( - this.gl, channels, this.textureHalfFloatExtension!.HALF_FLOAT_OES); + this.gl, + channels, + this.textureHalfFloatExtension!.HALF_FLOAT_OES, + ); } case 'int': throw new Error('not implemented'); @@ -302,10 +319,26 @@ ${shaderSource}`); private createDefaultGeometry(): Float32Array { // Sets of x,y,z(=0),s,t coordinates. 
return new Float32Array([ - -1.0, 1.0, 0.0, 0.0, 1.0, // upper left - -1.0, -1.0, 0.0, 0.0, 0.0, // lower left - 1.0, 1.0, 0.0, 1.0, 1.0, // upper right - 1.0, -1.0, 0.0, 1.0, 0.0 // lower right + -1.0, + 1.0, + 0.0, + 0.0, + 1.0, // upper left + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, // lower left + 1.0, + 1.0, + 0.0, + 1.0, + 1.0, // upper right + 1.0, + -1.0, + 0.0, + 1.0, + 0.0, // lower right ]); } private createVertexbuffer(): WebGLBuffer { @@ -373,7 +406,7 @@ ${shaderSource}`); const texture = gl.createTexture(); gl.bindTexture(gl.TEXTURE_2D, texture); // eslint-disable-next-line @typescript-eslint/naming-convention - const internalFormat = this.version === 2 ? (gl as unknown as {RGBA32F: number}).RGBA32F : gl.RGBA; + const internalFormat = this.version === 2 ? (gl as unknown as { RGBA32F: number }).RGBA32F : gl.RGBA; gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null); // STEP.2 bind a frame buffer const frameBuffer = gl.createFramebuffer(); @@ -427,11 +460,11 @@ ${shaderSource}`); const gl = this.gl; - let texture: WebGLTexture|null|undefined; - let frameBuffer: WebGLFramebuffer|null|undefined; - let vertexShader: WebGLShader|null|undefined; - let fragmentShader: WebGLShader|null|undefined; - let program: WebGLProgram|null|undefined; + let texture: WebGLTexture | null | undefined; + let frameBuffer: WebGLFramebuffer | null | undefined; + let vertexShader: WebGLShader | null | undefined; + let fragmentShader: WebGLShader | null | undefined; + let program: WebGLProgram | null | undefined; try { texture = gl.createTexture(); @@ -439,7 +472,7 @@ ${shaderSource}`); gl.bindTexture(gl.TEXTURE_2D, texture); // eslint-disable-next-line @typescript-eslint/naming-convention - const internalFormat = this.version === 2 ? (gl as unknown as {RGBA32F: number}).RGBA32F : gl.RGBA; + const internalFormat = this.version === 2 ? (gl as unknown as { RGBA32F: number }).RGBA32F : gl.RGBA; gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null); gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); @@ -472,7 +505,6 @@ ${shaderSource}`); gl.drawArrays(gl.POINTS, 0, 1); return gl.getError() === gl.NO_ERROR; - } finally { gl.disable(gl.BLEND); @@ -523,7 +555,8 @@ ${shaderSource}`); } isTimerResultAvailable(query: WebGLQuery): boolean { - let available = false, disjoint = false; + let available = false, + disjoint = false; if (this.version === 2 && this.disjointTimerQueryWebgl2Extension) { const gl2 = this.gl as WebGL2RenderingContext; const ext = this.disjointTimerQueryWebgl2Extension; @@ -575,12 +608,15 @@ ${shaderSource}`); return status === gl2.ALREADY_SIGNALED || status === gl2.CONDITION_SATISFIED; }; } - return {query, isFencePassed}; + return { query, isFencePassed }; } async pollFence(fenceContext: FenceContext) { - return new Promise(resolve => { - void this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve()); + return new Promise((resolve) => { + void this.addItemToPoll( + () => fenceContext.isFencePassed(), + () => resolve(), + ); }); } @@ -588,16 +624,16 @@ ${shaderSource}`); pollItems(): void { // Find the last query that has finished. 
- const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn)); + const index = linearSearchLastTrue(this.itemsToPoll.map((x) => x.isDoneFn)); for (let i = 0; i <= index; ++i) { - const {resolveFn} = this.itemsToPoll[i]; + const { resolveFn } = this.itemsToPoll[i]; resolveFn(); } this.itemsToPoll = this.itemsToPoll.slice(index + 1); } private async addItemToPoll(isDoneFn: () => boolean, resolveFn: () => void) { - this.itemsToPoll.push({isDoneFn, resolveFn}); + this.itemsToPoll.push({ isDoneFn, resolveFn }); if (this.itemsToPoll.length > 1) { // We already have a running loop that polls. return; diff --git a/js/web/lib/onnxjs/execution-plan.ts b/js/web/lib/onnxjs/execution-plan.ts index e155ff123f79d..40d6417b22d3a 100644 --- a/js/web/lib/onnxjs/execution-plan.ts +++ b/js/web/lib/onnxjs/execution-plan.ts @@ -1,18 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {SessionHandler} from './backend'; -import {Graph} from './graph'; -import {Logger, Profiler} from './instrument'; -import {Operator} from './operators'; -import {Tensor} from './tensor'; +import { SessionHandler } from './backend'; +import { Graph } from './graph'; +import { Logger, Profiler } from './instrument'; +import { Operator } from './operators'; +import { Tensor } from './tensor'; class KernelOp { - constructor(public op: Operator, public node: Graph.Node) {} + constructor( + public op: Operator, + public node: Graph.Node, + ) {} } export class ExecutionPlan { - constructor(private graph: Graph, ops: Operator[], private profiler: Readonly<Profiler>) { + constructor( + private graph: Graph, + ops: Operator[], + private profiler: Readonly<Profiler>, + ) { this.initialize(ops); } @@ -32,8 +39,8 @@ export class ExecutionPlan { let resolved = true; for (const input of op.node.inputs) { if ( - !this._values[input] // not an initialized input - && this.graph.getInputIndices().indexOf(input) === -1 // not model input + !this._values[input] && // not an initialized input + this.graph.getInputIndices().indexOf(input) === -1 // not model input ) { resolved = false; break; @@ -47,7 +54,7 @@ export class ExecutionPlan { } reset() { - this._values = this.graph.getValues().map(i => i.tensor); + this._values = this.graph.getValues().map((i) => i.tensor); } async execute(sessionHandler: SessionHandler, modelInputs: Tensor[]): Promise<Tensor[]> { @@ -61,8 +68,11 @@ export class ExecutionPlan { // populate inputs value const graphInputs = this.graph.getInputIndices(); if (modelInputs.length !== graphInputs.length) { - throw new Error(`number of input tensors don't match the number of inputs to the model: actual: ${ - modelInputs.length} expected: ${graphInputs.length}`); + throw new Error( + `number of input tensors don't match the number of inputs to the model: actual: ${ + modelInputs.length + } expected: ${graphInputs.length}`, + ); } modelInputs.forEach((input, i) => { @@ -83,7 +93,7 @@ export class ExecutionPlan { const thisOp = this._ops[thisOpIndex]; // check input - const inputList = thisOp.node.inputs.map(i => this._values[i]); + const inputList = thisOp.node.inputs.map((i) => this._values[i]); if (inputList.indexOf(undefined) !== -1) { throw new Error(`unresolved input detected: op: ${thisOp.node}`); } @@ -91,12 +101,15 @@ export class ExecutionPlan { // run const inputTensors = inputList as Tensor[]; Logger.verbose( - 'ExecPlan', - `Running op:${thisOp.node.name} (${ - inputTensors.map((t, i) => `'${thisOp.node.inputs[i]}': ${t.type}[${t.dims.join(',')}]`).join(', ')})`); + 'ExecPlan', 
+ `Running op:${thisOp.node.name} (${inputTensors + .map((t, i) => `'${thisOp.node.inputs[i]}': ${t.type}[${t.dims.join(',')}]`) + .join(', ')})`, + ); - const outputList = await this.profiler.event( - 'node', thisOp.node.name, async () => thisOp.op.impl(inferenceHandler, inputTensors, thisOp.op.context)); + const outputList = await this.profiler.event('node', thisOp.node.name, async () => + thisOp.op.impl(inferenceHandler, inputTensors, thisOp.op.context), + ); // check output if (outputList.length !== thisOp.node.outputs.length) { @@ -154,7 +167,7 @@ export class ExecutionPlan { }); } - _values: Array<Tensor|undefined>; + _values: Array<Tensor | undefined>; _ops: KernelOp[]; _starter: number[]; } diff --git a/js/web/lib/onnxjs/graph.ts b/js/web/lib/onnxjs/graph.ts index d444be2bf7ce0..88a80ccbf196b 100644 --- a/js/web/lib/onnxjs/graph.ts +++ b/js/web/lib/onnxjs/graph.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Attribute} from './attribute'; -import {onnxruntime} from './ort-schema/flatbuffers/ort-generated'; -import {onnx} from './ort-schema/protobuf/onnx'; -import {Tensor} from './tensor'; -import {LongUtil, MAX_CLIP, MIN_CLIP, ProtoUtil} from './util'; +import { Attribute } from './attribute'; +import { onnxruntime } from './ort-schema/flatbuffers/ort-generated'; +import { onnx } from './ort-schema/protobuf/onnx'; +import { Tensor } from './tensor'; +import { LongUtil, MAX_CLIP, MIN_CLIP, ProtoUtil } from './util'; import ortFbs = onnxruntime.experimental.fbs; @@ -78,8 +78,8 @@ export const Graph = { /** * construct a graph from a graph protobuf type */ - from: (graphProto: onnx.IGraphProto|ortFbs.Graph, initializer?: Graph.Initializer) => - new GraphImpl(graphProto, initializer), + from: (graphProto: onnx.IGraphProto | ortFbs.Graph, initializer?: Graph.Initializer) => + new GraphImpl(graphProto, initializer), }; class Value implements Graph.Value { @@ -94,7 +94,7 @@ class Value implements Graph.Value { } } - _from?: number; // -1 represent from initializer + _from?: number; // -1 represent from initializer get from() { return this._from!; } @@ -107,7 +107,7 @@ class Value implements Graph.Value { } class Node implements Graph.Node { - constructor(_nodeProto: onnx.INodeProto|ortFbs.Node, name?: string) { + constructor(_nodeProto: onnx.INodeProto | ortFbs.Node, name?: string) { if (_nodeProto instanceof onnx.NodeProto) { this.name = _nodeProto.name; this.opType = _nodeProto.opType; @@ -142,7 +142,7 @@ class GraphImpl implements Graph, Graph.Transformer { private _nodes: Node[]; - constructor(graph: onnx.IGraphProto|ortFbs.Graph, graphInitializer?: Graph.Initializer) { + constructor(graph: onnx.IGraphProto | ortFbs.Graph, graphInitializer?: Graph.Initializer) { if (!graph) { throw new TypeError('graph is empty'); } @@ -181,7 +181,7 @@ class GraphImpl implements Graph, Graph.Transformer { return this._nodes; } - private buildGraph(graph: onnx.IGraphProto|ortFbs.Graph) { + private buildGraph(graph: onnx.IGraphProto | ortFbs.Graph) { // build the graph - will throw exceptions if something fatal is detected if (graph instanceof onnx.GraphProto) { this.buildGraphFromOnnxFormat(graph); @@ -228,8 +228,8 @@ class GraphImpl implements Graph, Graph.Transformer { if (index === undefined) { const value = new Value(); value.type = { - shape: {dims: ProtoUtil.tensorDimsFromProto(i.dims!)}, - tensorType: ProtoUtil.tensorDataTypeFromProto(i.dataType!) + shape: { dims: ProtoUtil.tensorDimsFromProto(i.dims!) 
}, + tensorType: ProtoUtil.tensorDataTypeFromProto(i.dataType!), }; index = this._allData.push(value) - 1; dataIndices.set(i.name!, index); @@ -267,7 +267,7 @@ class GraphImpl implements Graph, Graph.Transformer { for (const nodeProto of graph.node) { if (!nodeProto.name) { // assign a name to the node if it doesn't have one - for (let pick = 0;; pick++) { + for (let pick = 0; ; pick++) { const name = `unnamed_${nodeProto.opType}_${pick}`; if (!nodesIndices.has(name)) { nodeProto.name = name; @@ -333,8 +333,11 @@ class GraphImpl implements Graph, Graph.Transformer { const dataIndex = dataIndices.get(input); if (typeof dataIndex === 'undefined') { // handle exception when opset > 9 and roi / scales not given - if (input === '' && (nodeProto.input.length === 3 || nodeProto.input.length === 4) && - nodeProto.opType === 'Resize') { + if ( + input === '' && + (nodeProto.input.length === 3 || nodeProto.input.length === 4) && + nodeProto.opType === 'Resize' + ) { continue; } throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name}`); @@ -384,7 +387,7 @@ class GraphImpl implements Graph, Graph.Transformer { for (let k = 0; k < shape.dimLength()!; k++) { dims.push(LongUtil.longToNumber(shape.dim(k)!.value()!.dimValue()!)); } - value.type = {shape: {dims}, tensorType: type}; + value.type = { shape: { dims }, tensorType: type }; const currentIndex = this._allData.push(value) - 1; dataIndices.set(inputName, currentIndex); inputValueNames.push(inputName); @@ -399,7 +402,7 @@ class GraphImpl implements Graph, Graph.Transformer { const value = new Value(); const dims = ProtoUtil.tensorDimsFromORTFormat(initializer); const type = ProtoUtil.tensorDataTypeFromProto(initializer.dataType()); - value.type = {shape: {dims}, tensorType: type}; + value.type = { shape: { dims }, tensorType: type }; index = this._allData.push(value) - 1; dataIndices.set(initializer.name()!, index); } @@ -436,7 +439,7 @@ class GraphImpl implements Graph, Graph.Transformer { let name = nodeProto!.name(); if (!name) { // assign a name to the node if it doesn't have one - for (let pick = 0;; pick++) { + for (let pick = 0; ; pick++) { name = `unnamed_${nodeProto!.opType()}_${pick}`; if (!nodesIndices.has(name)) { // an unique name is found. break. 
@@ -518,9 +521,9 @@ class GraphImpl implements Graph, Graph.Transformer { private checkIsAcyclic() { // go through the graph and check for cycles or other fatal inconsistencies const starters: Set<number> = new Set(); - this._allInputIndices.forEach(i => { + this._allInputIndices.forEach((i) => { const data = this._allData[i]; - data._to.forEach(j => { + data._to.forEach((j) => { starters.add(j); }); }); @@ -545,7 +548,7 @@ class GraphImpl implements Graph, Graph.Transformer { throw new Error('node outputs should not be initialized'); } if (data._from !== nodeIndex) { - throw new Error('from property of the Value object doesn\'t match index of Node being processed'); + throw new Error("from property of the Value object doesn't match index of Node being processed"); } data._to.forEach((downstreamNodeIndex) => { // back edge found - cyclic @@ -600,10 +603,9 @@ class GraphImpl implements Graph, Graph.Transformer { this._nodes[nodePossition] = this._nodes[i]; } nodePossition++; - } else { // delete all output values - this._nodes[i].outputs.forEach(ind => { + this._nodes[i].outputs.forEach((ind) => { this._allData[ind]._from = -2; }); } @@ -656,7 +658,7 @@ class GraphImpl implements Graph, Graph.Transformer { } // find the node that the current value is linking to and update its input reference - this._allData[i].to.forEach(node => { + this._allData[i].to.forEach((node) => { ind = this._nodes[node].inputs.indexOf(i + offset); if (ind !== -1) { this._nodes[node].inputs[ind] = i; @@ -699,7 +701,7 @@ class GraphImpl implements Graph, Graph.Transformer { const delIndex = this._allData[node.inputs[i]].to.indexOf(nodeIndex); // should not happen if (delIndex === -1) { - throw new Error('The Value object doesn\'t have the current Node in it\'s \'to\' property '); + throw new Error("The Value object doesn't have the current Node in it's 'to' property "); } this._allData[node.inputs[i]].to.splice(delIndex, 1); } @@ -719,7 +721,7 @@ class GraphImpl implements Graph, Graph.Transformer { const replaceIndex = this._nodes[nodeIndex].inputs.indexOf(outputValueIndex); // should not happen if (replaceIndex === -1) { - throw new Error('The Node object doesn\'t have the output Value in it\'s \'inputs\' property '); + throw new Error("The Node object doesn't have the output Value in it's 'inputs' property "); } this._nodes[nodeIndex].inputs[replaceIndex] = inputValueIndex; this._allData[inputValueIndex].to.push(nodeIndex); @@ -741,7 +743,7 @@ class GraphImpl implements Graph, Graph.Transformer { } // the second output should not be referenced by any other node if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) { - throw new Error('Dropout nodes\'s second output should not be referenced by other nodes'); + throw new Error("Dropout nodes's second output should not be referenced by other nodes"); } this.deleteNode(nodeIndex); } @@ -781,24 +783,28 @@ class GraphImpl implements Graph, Graph.Transformer { if (child.opType === 'Clip') { if (child.inputs.length === 1) { try { - node.attributes.set( - 'activation_params', 'floats', - [child.attributes.getFloat('min'), child.attributes.getFloat('max')]); + node.attributes.set('activation_params', 'floats', [ + child.attributes.getFloat('min'), + child.attributes.getFloat('max'), + ]); } catch (e) { node.attributes.set('activation_params', 'floats', [MIN_CLIP, MAX_CLIP]); } } else if ( - child.inputs.length >= 3 
&& + this._allData[child.inputs[1]].tensor !== undefined && + this._allData[child.inputs[2]].tensor !== undefined + ) { node.attributes.set('activation_params', 'floats', [ - this._allData[child.inputs[1]].tensor!.floatData[0], this._allData[child.inputs[2]].tensor!.floatData[0] + this._allData[child.inputs[1]].tensor!.floatData[0], + this._allData[child.inputs[2]].tensor!.floatData[0], ]); } else { // Skip fusion with clip node since clip min and clip max are not coming from initializer continue; } } - node.attributes.set('activation', 'string', (child.opType)); + node.attributes.set('activation', 'string', child.opType); this.deleteNode(next[0]); } } diff --git a/js/web/lib/onnxjs/instrument.ts b/js/web/lib/onnxjs/instrument.ts index 4f865503d50ec..df6a1777054fd 100644 --- a/js/web/lib/onnxjs/instrument.ts +++ b/js/web/lib/onnxjs/instrument.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env} from 'onnxruntime-common'; +import { Env } from 'onnxruntime-common'; -import {WebGLContext} from './backends/webgl/webgl-context'; +import { WebGLContext } from './backends/webgl/webgl-context'; export declare namespace Logger { export interface SeverityTypeMap { @@ -16,7 +16,7 @@ export declare namespace Logger { export type Severity = keyof SeverityTypeMap; - export type Provider = 'none'|'console'; + export type Provider = 'none' | 'console'; /** * Logging config that used to control the behavior of logger @@ -121,28 +121,33 @@ const SEVERITY_VALUE = { info: 2000, warning: 4000, error: 5000, - fatal: 6000 + fatal: 6000, }; -const LOGGER_PROVIDER_MAP: {readonly [provider: string]: Readonly} = { +const LOGGER_PROVIDER_MAP: { readonly [provider: string]: Readonly } = { ['none']: new NoOpLoggerProvider(), - ['console']: new ConsoleLoggerProvider() + ['console']: new ConsoleLoggerProvider(), }; const LOGGER_DEFAULT_CONFIG = { provider: 'console', minimalSeverity: 'warning', logDateTime: true, - logSourceLocation: false + logSourceLocation: false, +}; +let LOGGER_CONFIG_MAP: { [category: string]: Readonly> } = { + ['']: LOGGER_DEFAULT_CONFIG as Required, }; -let LOGGER_CONFIG_MAP: - {[category: string]: Readonly>} = {['']: LOGGER_DEFAULT_CONFIG as Required}; function log(category: string): Logger.CategorizedLogger; function log(severity: Logger.Severity, content: string): void; function log(severity: Logger.Severity, category: string, content: string): void; function log(severity: Logger.Severity, arg1: string, arg2?: string): void; function log( - arg0: string|Logger.Severity, arg1?: string, arg2?: string|number, arg3?: number): Logger.CategorizedLogger|void { + arg0: string | Logger.Severity, + arg1?: string, + arg2?: string | number, + arg3?: number, +): Logger.CategorizedLogger | void { if (arg1 === undefined) { // log(category: string): Logger.CategorizedLogger; return createCategorizedLogger(arg0); @@ -169,7 +174,7 @@ function createCategorizedLogger(category: string): Logger.CategorizedLogger { info: log.info.bind(null, category), warning: log.warning.bind(null, category), error: log.error.bind(null, category), - fatal: log.fatal.bind(null, category) + fatal: log.fatal.bind(null, category), }; } @@ -233,9 +238,9 @@ namespace log { LOGGER_CONFIG_MAP[category] = { provider: config.provider || previousConfig.provider, minimalSeverity: config.minimalSeverity || previousConfig.minimalSeverity, - logDateTime: (config.logDateTime === undefined) ? 
previousConfig.logDateTime : config.logDateTime, - logSourceLocation: (config.logSourceLocation === undefined) ? previousConfig.logSourceLocation : - config.logSourceLocation + logDateTime: config.logDateTime === undefined ? previousConfig.logDateTime : config.logDateTime, + logSourceLocation: + config.logSourceLocation === undefined ? previousConfig.logSourceLocation : config.logSourceLocation, }; } @@ -261,10 +266,10 @@ export declare namespace Profiler { flushIntervalInMilliseconds?: number; } - export type EventCategory = 'session'|'node'|'op'|'backend'; + export type EventCategory = 'session' | 'node' | 'op' | 'backend'; export interface Event { - end(): void|Promise; + end(): void | Promise; } } // TODO @@ -272,8 +277,13 @@ export declare namespace Profiler { class Event implements Profiler.Event { constructor( - public category: Profiler.EventCategory, public name: string, public startTime: number, - private endCallback: (e: Event) => void|Promise, public timer?: WebGLQuery, public ctx?: WebGLContext) {} + public category: Profiler.EventCategory, + public name: string, + public startTime: number, + private endCallback: (e: Event) => void | Promise, + public timer?: WebGLQuery, + public ctx?: WebGLContext, + ) {} async end() { return this.endCallback(this); @@ -291,7 +301,11 @@ class Event implements Profiler.Event { class EventRecord { constructor( - public category: Profiler.EventCategory, public name: string, public startTime: number, public endTime: number) {} + public category: Profiler.EventCategory, + public name: string, + public startTime: number, + public endTime: number, + ) {} } export class Profiler { @@ -329,8 +343,12 @@ export class Profiler { event(category: Profiler.EventCategory, name: string, func: () => T, ctx?: WebGLContext): T; event(category: Profiler.EventCategory, name: string, func: () => Promise, ctx?: WebGLContext): Promise; - event(category: Profiler.EventCategory, name: string, func: () => T | Promise, ctx?: WebGLContext): T - |Promise { + event( + category: Profiler.EventCategory, + name: string, + func: () => T | Promise, + ctx?: WebGLContext, + ): T | Promise { const event = this._started ? 
this.begin(category, name, ctx) : undefined; let isPromise = false; @@ -340,33 +358,38 @@ export class Profiler { if (res && typeof (res as Promise).then === 'function') { isPromise = true; return new Promise((resolve, reject) => { - (res as Promise) - .then( - async value => { // fulfilled - if (event) { - await event.end(); - } - resolve(value); - }, - async reason => { // rejected - if (event) { - await event.end(); - } - reject(reason); - }); + (res as Promise).then( + async (value) => { + // fulfilled + if (event) { + await event.end(); + } + resolve(value); + }, + async (reason) => { + // rejected + if (event) { + await event.end(); + } + reject(reason); + }, + ); }); } if (!isPromise && event) { const eventRes = event.end(); if (eventRes && typeof eventRes.then === 'function') { return new Promise((resolve, reject) => { - (eventRes).then( - () => { // fulfilled - resolve(res); - }, - (reason) => { // rejected - reject(reason); - }); + eventRes.then( + () => { + // fulfilled + resolve(res); + }, + (reason) => { + // rejected + reject(reason); + }, + ); }); } } @@ -381,10 +404,10 @@ export class Profiler { if (ctx === undefined) { const startTime = now(); this.flush(startTime); - return new Event(category, name, startTime, e => this.endSync(e)); + return new Event(category, name, startTime, (e) => this.endSync(e)); } else { const timer: WebGLQuery = ctx.beginTimer(); - return new Event(category, name, 0, async e => this.end(e), timer, ctx); + return new Event(category, name, 0, async (e) => this.end(e), timer, ctx); } } @@ -407,18 +430,23 @@ export class Profiler { private logOneEvent(event: EventRecord) { Logger.verbose( - `Profiler.${event.category}`, - `${(event.endTime - event.startTime).toFixed(2)}ms on event '${event.name}' at ${event.endTime.toFixed(2)}`); + `Profiler.${event.category}`, + `${(event.endTime - event.startTime).toFixed(2)}ms on event '${event.name}' at ${event.endTime.toFixed(2)}`, + ); } private flush(currentTime: number) { - if (this._timingEvents.length - this._flushPointer >= this._flushBatchSize || - currentTime - this._flushTime >= this._flushIntervalInMilliseconds) { + if ( + this._timingEvents.length - this._flushPointer >= this._flushBatchSize || + currentTime - this._flushTime >= this._flushIntervalInMilliseconds + ) { // should flush when either batch size accumlated or interval elepsed - for (const previousPointer = this._flushPointer; this._flushPointer < previousPointer + this._flushBatchSize && - this._flushPointer < this._timingEvents.length; - this._flushPointer++) { + for ( + const previousPointer = this._flushPointer; + this._flushPointer < previousPointer + this._flushBatchSize && this._flushPointer < this._timingEvents.length; + this._flushPointer++ + ) { this.logOneEvent(this._timingEvents[this._flushPointer]); } @@ -444,4 +472,4 @@ export class Profiler { /** * returns a number to represent the current timestamp in a resolution as high as possible. */ -export const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now; +export const now = typeof performance !== 'undefined' && performance.now ? () => performance.now() : Date.now; diff --git a/js/web/lib/onnxjs/model.ts b/js/web/lib/onnxjs/model.ts index 8e689626011be..a43d419b70aa6 100644 --- a/js/web/lib/onnxjs/model.ts +++ b/js/web/lib/onnxjs/model.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {flatbuffers} from 'flatbuffers'; +import { flatbuffers } from 'flatbuffers'; -import {Graph} from './graph'; -import {OpSet} from './opset'; -import {onnxruntime} from './ort-schema/flatbuffers/ort-generated'; -import {onnx} from './ort-schema/protobuf/onnx'; -import {LongUtil} from './util'; +import { Graph } from './graph'; +import { OpSet } from './opset'; +import { onnxruntime } from './ort-schema/flatbuffers/ort-generated'; +import { onnx } from './ort-schema/protobuf/onnx'; +import { LongUtil } from './util'; import ortFbs = onnxruntime.experimental.fbs; @@ -16,7 +16,7 @@ export class Model { constructor() {} load(buf: Uint8Array, graphInitializer?: Graph.Initializer, isOrtFormat?: boolean): void { - let onnxError: Error|undefined; + let onnxError: Error | undefined; if (!isOrtFormat) { // isOrtFormat === false || isOrtFormat === undefined try { @@ -48,8 +48,10 @@ export class Model { throw new Error('only support ONNX model with IR_VERSION>=3'); } - this._opsets = - modelProto.opsetImport.map(i => ({domain: i.domain as string, version: LongUtil.longToNumber(i.version!)})); + this._opsets = modelProto.opsetImport.map((i) => ({ + domain: i.domain as string, + version: LongUtil.longToNumber(i.version!), + })); this._graph = Graph.from(modelProto.graph!, graphInitializer); } @@ -64,7 +66,7 @@ export class Model { this._opsets = []; for (let i = 0; i < ortModel.opsetImportLength(); i++) { const opsetId = ortModel.opsetImport(i)!; - this._opsets.push({domain: opsetId?.domain() as string, version: LongUtil.longToNumber(opsetId.version()!)}); + this._opsets.push({ domain: opsetId?.domain() as string, version: LongUtil.longToNumber(opsetId.version()!) }); } this._graph = Graph.from(ortModel.graph()!, graphInitializer); diff --git a/js/web/lib/onnxjs/operators.ts b/js/web/lib/onnxjs/operators.ts index 4d664f6dcda5a..289cf03570f0f 100644 --- a/js/web/lib/onnxjs/operators.ts +++ b/js/web/lib/onnxjs/operators.ts @@ -1,19 +1,27 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceHandler} from './backend'; -import {Graph} from './graph'; -import {Tensor} from './tensor'; +import { InferenceHandler } from './backend'; +import { Graph } from './graph'; +import { Tensor } from './tensor'; export type OperatorImplementation = (inferenceHandler: InferenceHandler, inputs: Tensor[], context: T) => Tensor[]; export type OperatorInitialization = (node: Graph.Node, graph: Graph) => T; export interface Operator { readonly impl: OperatorImplementation; - readonly context: Graph.Node|unknown; + readonly context: Graph.Node | unknown; } -export const NUMBER_TYPES: readonly Tensor.DataType[] = - ['float32', 'float64', 'int32', 'int16', 'int8', 'uint16', 'uint32', 'uint8']; +export const NUMBER_TYPES: readonly Tensor.DataType[] = [ + 'float32', + 'float64', + 'int32', + 'int16', + 'int8', + 'uint16', + 'uint32', + 'uint8', +]; export const INT_TYPES: readonly Tensor.DataType[] = ['int32', 'int16', 'int8', 'uint16', 'uint32', 'uint8']; export const FLOAT_TYPES: readonly Tensor.DataType[] = ['float32', 'float64']; diff --git a/js/web/lib/onnxjs/opset.ts b/js/web/lib/onnxjs/opset.ts index e7eb3251babc5..27bfe0a627596 100644 --- a/js/web/lib/onnxjs/opset.ts +++ b/js/web/lib/onnxjs/opset.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Graph} from './graph'; -import {OperatorImplementation, OperatorInitialization} from './operators'; +import { Graph } from './graph'; +import { OperatorImplementation, OperatorInitialization } from './operators'; export interface OpSet { domain: string; @@ -12,14 +12,14 @@ export declare namespace OpSet { /** * Domain of an opset, it can be an empty string(default value, represent for ai.onnx), or 'ai.onnx.ml' */ - type Domain = ''|'ai.onnx.ml'|'com.microsoft'; + type Domain = '' | 'ai.onnx.ml' | 'com.microsoft'; /** * A resolve rule consists of 4 or 5 items: opType, opSetDomain, versionSelector, operatorImplementation and * operatorInitialization (optional) */ - type ResolveRule = [ - string, Domain, string, OperatorImplementation - ]|[string, Domain, string, OperatorImplementation, OperatorInitialization]; + type ResolveRule = + | [string, Domain, string, OperatorImplementation] + | [string, Domain, string, OperatorImplementation, OperatorInitialization]; } export function resolveOperator(node: Graph.Node, opsets: readonly OpSet[], rules: readonly OpSet.ResolveRule[]) { @@ -30,20 +30,25 @@ export function resolveOperator(node: Graph.Node, opsets: readonly OpSet[], rule const opImpl = rule[3]; const opInit = rule[4]; - if (node.opType === opType) { // operator type matches + if (node.opType === opType) { + // operator type matches for (const opset of opsets) { // opset '' and 'ai.onnx' are considered the same. - if (opset.domain === domain || (opset.domain === 'ai.onnx' && domain === '')) { // opset domain found + if (opset.domain === domain || (opset.domain === 'ai.onnx' && domain === '')) { + // opset domain found if (matchSelector(opset.version, versionSelector)) { - return {opImpl, opInit}; + return { opImpl, opInit }; } } } } } - throw new TypeError(`cannot resolve operator '${node.opType}' with opsets: ${ - opsets.map(set => `${set.domain || 'ai.onnx'} v${set.version}`).join(', ')}`); + throw new TypeError( + `cannot resolve operator '${node.opType}' with opsets: ${opsets + .map((set) => `${set.domain || 'ai.onnx'} v${set.version}`) + .join(', ')}`, + ); } function matchSelector(version: number, selector: string): boolean { diff --git a/js/web/lib/onnxjs/ort-schema/flatbuffers/ort-generated.ts b/js/web/lib/onnxjs/ort-schema/flatbuffers/ort-generated.ts index 32758c2bfd8b7..c0c608d559f81 100644 --- a/js/web/lib/onnxjs/ort-schema/flatbuffers/ort-generated.ts +++ b/js/web/lib/onnxjs/ort-schema/flatbuffers/ort-generated.ts @@ -1,7 +1,7 @@ // automatically generated by the FlatBuffers compiler, do not modify /* eslint-disable */ -import {flatbuffers} from 'flatbuffers'; +import { flatbuffers } from 'flatbuffers'; /** * @enum {number} @@ -20,7 +20,7 @@ export namespace onnxruntime.experimental.fbs { TENSORS = 9, GRAPHS = 10, SPARSE_TENSOR = 11, - SPARSE_TENSORS = 12 + SPARSE_TENSORS = 12, } } @@ -28,7 +28,11 @@ export namespace onnxruntime.experimental.fbs { * @enum {number} */ export namespace onnxruntime.experimental.fbs { - export enum DimensionValueType {UNKNOWN = 0, VALUE = 1, PARAM = 2} + export enum DimensionValueType { + UNKNOWN = 0, + VALUE = 1, + PARAM = 2, + } } /** @@ -64,14 +68,22 @@ export namespace onnxruntime.experimental.fbs { * @enum {number} */ export namespace onnxruntime.experimental.fbs { - export enum NodeType {Primitive = 0, Fused = 1} + export enum NodeType { + Primitive = 0, + Fused = 1, + } } /** * @enum {number} */ export namespace onnxruntime.experimental.fbs { - export enum TypeInfoValue {NONE = 0, tensor_type = 1, sequence_type = 2, map_type = 3} + 
export enum TypeInfoValue { + NONE = 0, + tensor_type = 1, + sequence_type = 2, + map_type = 3, + } } /** @@ -79,7 +91,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class Shape { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -117,11 +129,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Dimension= obj * @returns onnxruntime.experimental.fbs.Dimension */ - dim(index: number, obj?: onnxruntime.experimental.fbs.Dimension): onnxruntime.experimental.fbs.Dimension|null { + dim(index: number, obj?: onnxruntime.experimental.fbs.Dimension): onnxruntime.experimental.fbs.Dimension | null { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? (obj || new onnxruntime.experimental.fbs.Dimension()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Dimension()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -189,7 +204,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class Dimension { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -226,20 +241,23 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.DimensionValue= obj * @returns onnxruntime.experimental.fbs.DimensionValue|null */ - value(obj?: onnxruntime.experimental.fbs.DimensionValue): onnxruntime.experimental.fbs.DimensionValue|null { + value(obj?: onnxruntime.experimental.fbs.DimensionValue): onnxruntime.experimental.fbs.DimensionValue | null { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? (obj || new onnxruntime.experimental.fbs.DimensionValue()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.DimensionValue()).__init( + this.bb!.__indirect(this.bb_pos + offset), + this.bb!, + ) + : null; } /** * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - denotation(): string|null; - denotation(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - denotation(optionalEncoding?: any): string|Uint8Array|null { + denotation(): string | null; + denotation(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + denotation(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 6); return offset ? 
this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -277,8 +295,10 @@ export namespace onnxruntime.experimental.fbs { } static createDimension( - builder: flatbuffers.Builder, valueOffset: flatbuffers.Offset, - denotationOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + valueOffset: flatbuffers.Offset, + denotationOffset: flatbuffers.Offset, + ): flatbuffers.Offset { Dimension.startDimension(builder); Dimension.addValue(builder, valueOffset); Dimension.addDenotation(builder, denotationOffset); @@ -291,7 +311,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class DimensionValue { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -329,8 +349,9 @@ export namespace onnxruntime.experimental.fbs { */ dimType(): onnxruntime.experimental.fbs.DimensionValueType { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? /** */ (this.bb!.readInt8(this.bb_pos + offset)) : - onnxruntime.experimental.fbs.DimensionValueType.UNKNOWN; + return offset + ? /** */ this.bb!.readInt8(this.bb_pos + offset) + : onnxruntime.experimental.fbs.DimensionValueType.UNKNOWN; } /** @@ -345,9 +366,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - dimParam(): string|null; - dimParam(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - dimParam(optionalEncoding?: any): string|Uint8Array|null { + dimParam(): string | null; + dimParam(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + dimParam(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 8); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -393,8 +414,11 @@ export namespace onnxruntime.experimental.fbs { } static createDimensionValue( - builder: flatbuffers.Builder, dimType: onnxruntime.experimental.fbs.DimensionValueType, - dimValue: flatbuffers.Long, dimParamOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + dimType: onnxruntime.experimental.fbs.DimensionValueType, + dimValue: flatbuffers.Long, + dimParamOffset: flatbuffers.Offset, + ): flatbuffers.Offset { DimensionValue.startDimensionValue(builder); DimensionValue.addDimType(builder, dimType); DimensionValue.addDimValue(builder, dimValue); @@ -408,7 +432,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class TensorTypeAndShape { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -436,8 +460,10 @@ export namespace onnxruntime.experimental.fbs { * @param TensorTypeAndShape= obj * @returns TensorTypeAndShape */ - static getSizePrefixedRootAsTensorTypeAndShape(bb: flatbuffers.ByteBuffer, obj?: TensorTypeAndShape): - TensorTypeAndShape { + static getSizePrefixedRootAsTensorTypeAndShape( + bb: flatbuffers.ByteBuffer, + obj?: TensorTypeAndShape, + ): TensorTypeAndShape { bb.setPosition(bb.position() + flatbuffers.SIZE_PREFIX_LENGTH); return (obj || new TensorTypeAndShape()).__init(bb.readInt32(bb.position()) + bb.position(), bb); } @@ -447,19 +473,20 @@ export namespace onnxruntime.experimental.fbs { */ elemType(): onnxruntime.experimental.fbs.TensorDataType { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? 
/** */ (this.bb!.readInt32(this.bb_pos + offset)) : - onnxruntime.experimental.fbs.TensorDataType.UNDEFINED; + return offset + ? /** */ this.bb!.readInt32(this.bb_pos + offset) + : onnxruntime.experimental.fbs.TensorDataType.UNDEFINED; } /** * @param onnxruntime.experimental.fbs.Shape= obj * @returns onnxruntime.experimental.fbs.Shape|null */ - shape(obj?: onnxruntime.experimental.fbs.Shape): onnxruntime.experimental.fbs.Shape|null { + shape(obj?: onnxruntime.experimental.fbs.Shape): onnxruntime.experimental.fbs.Shape | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.Shape()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Shape()).__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) + : null; } /** @@ -495,8 +522,10 @@ export namespace onnxruntime.experimental.fbs { } static createTensorTypeAndShape( - builder: flatbuffers.Builder, elemType: onnxruntime.experimental.fbs.TensorDataType, - shapeOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + elemType: onnxruntime.experimental.fbs.TensorDataType, + shapeOffset: flatbuffers.Offset, + ): flatbuffers.Offset { TensorTypeAndShape.startTensorTypeAndShape(builder); TensorTypeAndShape.addElemType(builder, elemType); TensorTypeAndShape.addShape(builder, shapeOffset); @@ -509,7 +538,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class MapType { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -547,19 +576,23 @@ export namespace onnxruntime.experimental.fbs { */ keyType(): onnxruntime.experimental.fbs.TensorDataType { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? /** */ (this.bb!.readInt32(this.bb_pos + offset)) : - onnxruntime.experimental.fbs.TensorDataType.UNDEFINED; + return offset + ? /** */ this.bb!.readInt32(this.bb_pos + offset) + : onnxruntime.experimental.fbs.TensorDataType.UNDEFINED; } /** * @param onnxruntime.experimental.fbs.TypeInfo= obj * @returns onnxruntime.experimental.fbs.TypeInfo|null */ - valueType(obj?: onnxruntime.experimental.fbs.TypeInfo): onnxruntime.experimental.fbs.TypeInfo|null { + valueType(obj?: onnxruntime.experimental.fbs.TypeInfo): onnxruntime.experimental.fbs.TypeInfo | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.TypeInfo()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? 
(obj || new onnxruntime.experimental.fbs.TypeInfo()).__init( + this.bb!.__indirect(this.bb_pos + offset), + this.bb!, + ) + : null; } /** @@ -595,8 +628,10 @@ export namespace onnxruntime.experimental.fbs { } static createMapType( - builder: flatbuffers.Builder, keyType: onnxruntime.experimental.fbs.TensorDataType, - valueTypeOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + keyType: onnxruntime.experimental.fbs.TensorDataType, + valueTypeOffset: flatbuffers.Offset, + ): flatbuffers.Offset { MapType.startMapType(builder); MapType.addKeyType(builder, keyType); MapType.addValueType(builder, valueTypeOffset); @@ -609,7 +644,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class SequenceType { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -646,11 +681,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.TypeInfo= obj * @returns onnxruntime.experimental.fbs.TypeInfo|null */ - elemType(obj?: onnxruntime.experimental.fbs.TypeInfo): onnxruntime.experimental.fbs.TypeInfo|null { + elemType(obj?: onnxruntime.experimental.fbs.TypeInfo): onnxruntime.experimental.fbs.TypeInfo | null { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? (obj || new onnxruntime.experimental.fbs.TypeInfo()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.TypeInfo()).__init( + this.bb!.__indirect(this.bb_pos + offset), + this.bb!, + ) + : null; } /** @@ -689,7 +727,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class EdgeEnd { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -732,8 +770,11 @@ export namespace onnxruntime.experimental.fbs { * @returns flatbuffers.Offset */ static createEdgeEnd( - builder: flatbuffers.Builder, node_index: number, src_arg_index: number, - dst_arg_index: number): flatbuffers.Offset { + builder: flatbuffers.Builder, + node_index: number, + src_arg_index: number, + dst_arg_index: number, + ): flatbuffers.Offset { builder.prep(4, 12); builder.writeInt32(dst_arg_index); builder.writeInt32(src_arg_index); @@ -747,7 +788,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class NodeEdge { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -793,11 +834,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.EdgeEnd= obj * @returns onnxruntime.experimental.fbs.EdgeEnd */ - inputEdges(index: number, obj?: onnxruntime.experimental.fbs.EdgeEnd): onnxruntime.experimental.fbs.EdgeEnd|null { + inputEdges(index: number, obj?: onnxruntime.experimental.fbs.EdgeEnd): onnxruntime.experimental.fbs.EdgeEnd | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.EdgeEnd()) - .__init(this.bb!.__vector(this.bb_pos + offset) + index * 12, this.bb!) : - null; + return offset + ? 
(obj || new onnxruntime.experimental.fbs.EdgeEnd()).__init( + this.bb!.__vector(this.bb_pos + offset) + index * 12, + this.bb!, + ) + : null; } /** @@ -813,11 +857,17 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.EdgeEnd= obj * @returns onnxruntime.experimental.fbs.EdgeEnd */ - outputEdges(index: number, obj?: onnxruntime.experimental.fbs.EdgeEnd): onnxruntime.experimental.fbs.EdgeEnd|null { + outputEdges( + index: number, + obj?: onnxruntime.experimental.fbs.EdgeEnd, + ): onnxruntime.experimental.fbs.EdgeEnd | null { let offset = this.bb!.__offset(this.bb_pos, 8); - return offset ? (obj || new onnxruntime.experimental.fbs.EdgeEnd()) - .__init(this.bb!.__vector(this.bb_pos + offset) + index * 12, this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.EdgeEnd()).__init( + this.bb!.__vector(this.bb_pos + offset) + index * 12, + this.bb!, + ) + : null; } /** @@ -885,8 +935,11 @@ export namespace onnxruntime.experimental.fbs { } static createNodeEdge( - builder: flatbuffers.Builder, nodeIndex: number, inputEdgesOffset: flatbuffers.Offset, - outputEdgesOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + nodeIndex: number, + inputEdgesOffset: flatbuffers.Offset, + outputEdgesOffset: flatbuffers.Offset, + ): flatbuffers.Offset { NodeEdge.startNodeEdge(builder); NodeEdge.addNodeIndex(builder, nodeIndex); NodeEdge.addInputEdges(builder, inputEdgesOffset); @@ -900,7 +953,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class Node { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -937,9 +990,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - name(): string|null; - name(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - name(optionalEncoding?: any): string|Uint8Array|null { + name(): string | null; + name(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + name(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -948,9 +1001,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - docString(): string|null; - docString(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - docString(optionalEncoding?: any): string|Uint8Array|null { + docString(): string | null; + docString(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + docString(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 6); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -959,9 +1012,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - domain(): string|null; - domain(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - domain(optionalEncoding?: any): string|Uint8Array|null { + domain(): string | null; + domain(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + domain(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 8); return offset ? 
this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -986,9 +1039,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - opType(): string|null; - opType(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - opType(optionalEncoding?: any): string|Uint8Array|null { + opType(): string | null; + opType(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + opType(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 14); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -998,17 +1051,18 @@ export namespace onnxruntime.experimental.fbs { */ type(): onnxruntime.experimental.fbs.NodeType { let offset = this.bb!.__offset(this.bb_pos, 16); - return offset ? /** */ (this.bb!.readInt32(this.bb_pos + offset)) : - onnxruntime.experimental.fbs.NodeType.Primitive; + return offset + ? /** */ this.bb!.readInt32(this.bb_pos + offset) + : onnxruntime.experimental.fbs.NodeType.Primitive; } /** * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - executionProviderType(): string|null; - executionProviderType(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - executionProviderType(optionalEncoding?: any): string|Uint8Array|null { + executionProviderType(): string | null; + executionProviderType(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + executionProviderType(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 18); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -1019,8 +1073,8 @@ export namespace onnxruntime.experimental.fbs { * @returns string|Uint8Array */ inputs(index: number): string; - inputs(index: number, optionalEncoding: flatbuffers.Encoding): string|Uint8Array; - inputs(index: number, optionalEncoding?: any): string|Uint8Array|null { + inputs(index: number, optionalEncoding: flatbuffers.Encoding): string | Uint8Array; + inputs(index: number, optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 20); return offset ? this.bb!.__string(this.bb!.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; } @@ -1039,8 +1093,8 @@ export namespace onnxruntime.experimental.fbs { * @returns string|Uint8Array */ outputs(index: number): string; - outputs(index: number, optionalEncoding: flatbuffers.Encoding): string|Uint8Array; - outputs(index: number, optionalEncoding?: any): string|Uint8Array|null { + outputs(index: number, optionalEncoding: flatbuffers.Encoding): string | Uint8Array; + outputs(index: number, optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 22); return offset ? this.bb!.__string(this.bb!.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; } @@ -1058,12 +1112,17 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Attribute= obj * @returns onnxruntime.experimental.fbs.Attribute */ - attributes(index: number, obj?: onnxruntime.experimental.fbs.Attribute): onnxruntime.experimental.fbs.Attribute - |null { + attributes( + index: number, + obj?: onnxruntime.experimental.fbs.Attribute, + ): onnxruntime.experimental.fbs.Attribute | null { let offset = this.bb!.__offset(this.bb_pos, 24); - return offset ? 
(obj || new onnxruntime.experimental.fbs.Attribute()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Attribute()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -1078,7 +1137,7 @@ export namespace onnxruntime.experimental.fbs { * @param number index * @returns number */ - inputArgCounts(index: number): number|null { + inputArgCounts(index: number): number | null { let offset = this.bb!.__offset(this.bb_pos, 26); return offset ? this.bb!.readInt32(this.bb!.__vector(this.bb_pos + offset) + index * 4) : 0; } @@ -1094,13 +1153,15 @@ export namespace onnxruntime.experimental.fbs { /** * @returns Int32Array */ - inputArgCountsArray(): Int32Array|null { + inputArgCountsArray(): Int32Array | null { let offset = this.bb!.__offset(this.bb_pos, 26); - return offset ? - new Int32Array( - this.bb!.bytes().buffer, this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), - this.bb!.__vector_len(this.bb_pos + offset)) : - null; + return offset + ? new Int32Array( + this.bb!.bytes().buffer, + this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), + this.bb!.__vector_len(this.bb_pos + offset), + ) + : null; } /** @@ -1109,8 +1170,8 @@ export namespace onnxruntime.experimental.fbs { * @returns string|Uint8Array */ implicitInputs(index: number): string; - implicitInputs(index: number, optionalEncoding: flatbuffers.Encoding): string|Uint8Array; - implicitInputs(index: number, optionalEncoding?: any): string|Uint8Array|null { + implicitInputs(index: number, optionalEncoding: flatbuffers.Encoding): string | Uint8Array; + implicitInputs(index: number, optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 28); return offset ? this.bb!.__string(this.bb!.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; } @@ -1294,7 +1355,7 @@ export namespace onnxruntime.experimental.fbs { * @param Array. 
data * @returns flatbuffers.Offset */ - static createInputArgCountsVector(builder: flatbuffers.Builder, data: number[]|Uint8Array): flatbuffers.Offset { + static createInputArgCountsVector(builder: flatbuffers.Builder, data: number[] | Uint8Array): flatbuffers.Offset { builder.startVector(4, data.length, 4); for (let i = data.length - 1; i >= 0; i--) { builder.addInt32(data[i]); @@ -1349,11 +1410,21 @@ export namespace onnxruntime.experimental.fbs { } static createNode( - builder: flatbuffers.Builder, nameOffset: flatbuffers.Offset, docStringOffset: flatbuffers.Offset, - domainOffset: flatbuffers.Offset, sinceVersion: number, index: number, opTypeOffset: flatbuffers.Offset, - type: onnxruntime.experimental.fbs.NodeType, executionProviderTypeOffset: flatbuffers.Offset, - inputsOffset: flatbuffers.Offset, outputsOffset: flatbuffers.Offset, attributesOffset: flatbuffers.Offset, - inputArgCountsOffset: flatbuffers.Offset, implicitInputsOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + nameOffset: flatbuffers.Offset, + docStringOffset: flatbuffers.Offset, + domainOffset: flatbuffers.Offset, + sinceVersion: number, + index: number, + opTypeOffset: flatbuffers.Offset, + type: onnxruntime.experimental.fbs.NodeType, + executionProviderTypeOffset: flatbuffers.Offset, + inputsOffset: flatbuffers.Offset, + outputsOffset: flatbuffers.Offset, + attributesOffset: flatbuffers.Offset, + inputArgCountsOffset: flatbuffers.Offset, + implicitInputsOffset: flatbuffers.Offset, + ): flatbuffers.Offset { Node.startNode(builder); Node.addName(builder, nameOffset); Node.addDocString(builder, docStringOffset); @@ -1377,7 +1448,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class ValueInfo { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -1414,9 +1485,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - name(): string|null; - name(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - name(optionalEncoding?: any): string|Uint8Array|null { + name(): string | null; + name(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + name(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -1425,9 +1496,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - docString(): string|null; - docString(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - docString(optionalEncoding?: any): string|Uint8Array|null { + docString(): string | null; + docString(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + docString(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 6); return offset ? 
this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -1436,11 +1507,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.TypeInfo= obj * @returns onnxruntime.experimental.fbs.TypeInfo|null */ - type(obj?: onnxruntime.experimental.fbs.TypeInfo): onnxruntime.experimental.fbs.TypeInfo|null { + type(obj?: onnxruntime.experimental.fbs.TypeInfo): onnxruntime.experimental.fbs.TypeInfo | null { let offset = this.bb!.__offset(this.bb_pos, 8); - return offset ? (obj || new onnxruntime.experimental.fbs.TypeInfo()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.TypeInfo()).__init( + this.bb!.__indirect(this.bb_pos + offset), + this.bb!, + ) + : null; } /** @@ -1484,8 +1558,11 @@ export namespace onnxruntime.experimental.fbs { } static createValueInfo( - builder: flatbuffers.Builder, nameOffset: flatbuffers.Offset, docStringOffset: flatbuffers.Offset, - typeOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + nameOffset: flatbuffers.Offset, + docStringOffset: flatbuffers.Offset, + typeOffset: flatbuffers.Offset, + ): flatbuffers.Offset { ValueInfo.startValueInfo(builder); ValueInfo.addName(builder, nameOffset); ValueInfo.addDocString(builder, docStringOffset); @@ -1499,7 +1576,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class TypeInfo { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -1536,9 +1613,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - denotation(): string|null; - denotation(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - denotation(optionalEncoding?: any): string|Uint8Array|null { + denotation(): string | null; + denotation(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + denotation(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -1548,15 +1625,16 @@ export namespace onnxruntime.experimental.fbs { */ valueType(): onnxruntime.experimental.fbs.TypeInfoValue { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? /** */ (this.bb!.readUint8(this.bb_pos + offset)) : - onnxruntime.experimental.fbs.TypeInfoValue.NONE; + return offset + ? /** */ this.bb!.readUint8(this.bb_pos + offset) + : onnxruntime.experimental.fbs.TypeInfoValue.NONE; } /** * @param flatbuffers.Table obj * @returns ?flatbuffers.Table */ - value(obj: T): T|null { + value(obj: T): T | null { let offset = this.bb!.__offset(this.bb_pos, 8); return offset ? 
this.bb!.__union(obj, this.bb_pos + offset) : null; } @@ -1602,8 +1680,11 @@ export namespace onnxruntime.experimental.fbs { } static createTypeInfo( - builder: flatbuffers.Builder, denotationOffset: flatbuffers.Offset, - valueType: onnxruntime.experimental.fbs.TypeInfoValue, valueOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + denotationOffset: flatbuffers.Offset, + valueType: onnxruntime.experimental.fbs.TypeInfoValue, + valueOffset: flatbuffers.Offset, + ): flatbuffers.Offset { TypeInfo.startTypeInfo(builder); TypeInfo.addDenotation(builder, denotationOffset); TypeInfo.addValueType(builder, valueType); @@ -1617,7 +1698,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class OperatorSetId { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -1654,9 +1735,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - domain(): string|null; - domain(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - domain(optionalEncoding?: any): string|Uint8Array|null { + domain(): string | null; + domain(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + domain(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -1702,7 +1783,10 @@ export namespace onnxruntime.experimental.fbs { } static createOperatorSetId( - builder: flatbuffers.Builder, domainOffset: flatbuffers.Offset, version: flatbuffers.Long): flatbuffers.Offset { + builder: flatbuffers.Builder, + domainOffset: flatbuffers.Offset, + version: flatbuffers.Long, + ): flatbuffers.Offset { OperatorSetId.startOperatorSetId(builder); OperatorSetId.addDomain(builder, domainOffset); OperatorSetId.addVersion(builder, version); @@ -1715,7 +1799,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class Tensor { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -1752,9 +1836,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - name(): string|null; - name(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - name(optionalEncoding?: any): string|Uint8Array|null { + name(): string | null; + name(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + name(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -1763,9 +1847,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - docString(): string|null; - docString(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - docString(optionalEncoding?: any): string|Uint8Array|null { + docString(): string | null; + docString(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + docString(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 6); return offset ? 
this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -1774,10 +1858,11 @@ export namespace onnxruntime.experimental.fbs { * @param number index * @returns flatbuffers.Long */ - dims(index: number): flatbuffers.Long|null { + dims(index: number): flatbuffers.Long | null { let offset = this.bb!.__offset(this.bb_pos, 8); - return offset ? this.bb!.readInt64(this.bb!.__vector(this.bb_pos + offset) + index * 8) : - this.bb!.createLong(0, 0); + return offset + ? this.bb!.readInt64(this.bb!.__vector(this.bb_pos + offset) + index * 8) + : this.bb!.createLong(0, 0); } /** @@ -1793,15 +1878,16 @@ export namespace onnxruntime.experimental.fbs { */ dataType(): onnxruntime.experimental.fbs.TensorDataType { let offset = this.bb!.__offset(this.bb_pos, 10); - return offset ? /** */ (this.bb!.readInt32(this.bb_pos + offset)) : - onnxruntime.experimental.fbs.TensorDataType.UNDEFINED; + return offset + ? /** */ this.bb!.readInt32(this.bb_pos + offset) + : onnxruntime.experimental.fbs.TensorDataType.UNDEFINED; } /** * @param number index * @returns number */ - rawData(index: number): number|null { + rawData(index: number): number | null { let offset = this.bb!.__offset(this.bb_pos, 12); return offset ? this.bb!.readUint8(this.bb!.__vector(this.bb_pos + offset) + index) : 0; } @@ -1817,13 +1903,15 @@ export namespace onnxruntime.experimental.fbs { /** * @returns Uint8Array */ - rawDataArray(): Uint8Array|null { + rawDataArray(): Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 12); - return offset ? - new Uint8Array( - this.bb!.bytes().buffer, this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), - this.bb!.__vector_len(this.bb_pos + offset)) : - null; + return offset + ? new Uint8Array( + this.bb!.bytes().buffer, + this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), + this.bb!.__vector_len(this.bb_pos + offset), + ) + : null; } /** @@ -1832,8 +1920,8 @@ export namespace onnxruntime.experimental.fbs { * @returns string|Uint8Array */ stringData(index: number): string; - stringData(index: number, optionalEncoding: flatbuffers.Encoding): string|Uint8Array; - stringData(index: number, optionalEncoding?: any): string|Uint8Array|null { + stringData(index: number, optionalEncoding: flatbuffers.Encoding): string | Uint8Array; + stringData(index: number, optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 14); return offset ? this.bb!.__string(this.bb!.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; } @@ -1919,7 +2007,7 @@ export namespace onnxruntime.experimental.fbs { * @param Array. 
data * @returns flatbuffers.Offset */ - static createRawDataVector(builder: flatbuffers.Builder, data: number[]|Uint8Array): flatbuffers.Offset { + static createRawDataVector(builder: flatbuffers.Builder, data: number[] | Uint8Array): flatbuffers.Offset { builder.startVector(1, data.length, 1); for (let i = data.length - 1; i >= 0; i--) { builder.addInt8(data[i]); @@ -1974,9 +2062,14 @@ export namespace onnxruntime.experimental.fbs { } static createTensor( - builder: flatbuffers.Builder, nameOffset: flatbuffers.Offset, docStringOffset: flatbuffers.Offset, - dimsOffset: flatbuffers.Offset, dataType: onnxruntime.experimental.fbs.TensorDataType, - rawDataOffset: flatbuffers.Offset, stringDataOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + nameOffset: flatbuffers.Offset, + docStringOffset: flatbuffers.Offset, + dimsOffset: flatbuffers.Offset, + dataType: onnxruntime.experimental.fbs.TensorDataType, + rawDataOffset: flatbuffers.Offset, + stringDataOffset: flatbuffers.Offset, + ): flatbuffers.Offset { Tensor.startTensor(builder); Tensor.addName(builder, nameOffset); Tensor.addDocString(builder, docStringOffset); @@ -1993,7 +2086,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class SparseTensor { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -2030,32 +2123,33 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Tensor= obj * @returns onnxruntime.experimental.fbs.Tensor|null */ - values(obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor|null { + values(obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor | null { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? (obj || new onnxruntime.experimental.fbs.Tensor()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Tensor()).__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) + : null; } /** * @param onnxruntime.experimental.fbs.Tensor= obj * @returns onnxruntime.experimental.fbs.Tensor|null */ - indices(obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor|null { + indices(obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.Tensor()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Tensor()).__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) + : null; } /** * @param number index * @returns flatbuffers.Long */ - dims(index: number): flatbuffers.Long|null { + dims(index: number): flatbuffers.Long | null { let offset = this.bb!.__offset(this.bb_pos, 8); - return offset ? this.bb!.readInt64(this.bb!.__vector(this.bb_pos + offset) + index * 8) : - this.bb!.createLong(0, 0); + return offset + ? 
this.bb!.readInt64(this.bb!.__vector(this.bb_pos + offset) + index * 8) + : this.bb!.createLong(0, 0); } /** @@ -2128,8 +2222,11 @@ export namespace onnxruntime.experimental.fbs { } static createSparseTensor( - builder: flatbuffers.Builder, valuesOffset: flatbuffers.Offset, indicesOffset: flatbuffers.Offset, - dimsOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + valuesOffset: flatbuffers.Offset, + indicesOffset: flatbuffers.Offset, + dimsOffset: flatbuffers.Offset, + ): flatbuffers.Offset { SparseTensor.startSparseTensor(builder); SparseTensor.addValues(builder, valuesOffset); SparseTensor.addIndices(builder, indicesOffset); @@ -2143,7 +2240,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class Attribute { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -2180,9 +2277,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - name(): string|null; - name(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - name(optionalEncoding?: any): string|Uint8Array|null { + name(): string | null; + name(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + name(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -2191,9 +2288,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - docString(): string|null; - docString(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - docString(optionalEncoding?: any): string|Uint8Array|null { + docString(): string | null; + docString(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + docString(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 6); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -2203,8 +2300,9 @@ export namespace onnxruntime.experimental.fbs { */ type(): onnxruntime.experimental.fbs.AttributeType { let offset = this.bb!.__offset(this.bb_pos, 8); - return offset ? /** */ (this.bb!.readInt32(this.bb_pos + offset)) : - onnxruntime.experimental.fbs.AttributeType.UNDEFINED; + return offset + ? /** */ this.bb!.readInt32(this.bb_pos + offset) + : onnxruntime.experimental.fbs.AttributeType.UNDEFINED; } /** @@ -2227,9 +2325,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - s(): string|null; - s(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - s(optionalEncoding?: any): string|Uint8Array|null { + s(): string | null; + s(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + s(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 14); return offset ? 
this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -2238,29 +2336,29 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Tensor= obj * @returns onnxruntime.experimental.fbs.Tensor|null */ - t(obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor|null { + t(obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor | null { let offset = this.bb!.__offset(this.bb_pos, 16); - return offset ? (obj || new onnxruntime.experimental.fbs.Tensor()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Tensor()).__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) + : null; } /** * @param onnxruntime.experimental.fbs.Graph= obj * @returns onnxruntime.experimental.fbs.Graph|null */ - g(obj?: onnxruntime.experimental.fbs.Graph): onnxruntime.experimental.fbs.Graph|null { + g(obj?: onnxruntime.experimental.fbs.Graph): onnxruntime.experimental.fbs.Graph | null { let offset = this.bb!.__offset(this.bb_pos, 18); - return offset ? (obj || new onnxruntime.experimental.fbs.Graph()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Graph()).__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) + : null; } /** * @param number index * @returns number */ - floats(index: number): number|null { + floats(index: number): number | null { let offset = this.bb!.__offset(this.bb_pos, 20); return offset ? this.bb!.readFloat32(this.bb!.__vector(this.bb_pos + offset) + index * 4) : 0; } @@ -2276,23 +2374,26 @@ export namespace onnxruntime.experimental.fbs { /** * @returns Float32Array */ - floatsArray(): Float32Array|null { + floatsArray(): Float32Array | null { let offset = this.bb!.__offset(this.bb_pos, 20); - return offset ? - new Float32Array( - this.bb!.bytes().buffer, this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), - this.bb!.__vector_len(this.bb_pos + offset)) : - null; + return offset + ? new Float32Array( + this.bb!.bytes().buffer, + this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), + this.bb!.__vector_len(this.bb_pos + offset), + ) + : null; } /** * @param number index * @returns flatbuffers.Long */ - ints(index: number): flatbuffers.Long|null { + ints(index: number): flatbuffers.Long | null { let offset = this.bb!.__offset(this.bb_pos, 22); - return offset ? this.bb!.readInt64(this.bb!.__vector(this.bb_pos + offset) + index * 8) : - this.bb!.createLong(0, 0); + return offset + ? this.bb!.readInt64(this.bb!.__vector(this.bb_pos + offset) + index * 8) + : this.bb!.createLong(0, 0); } /** @@ -2309,8 +2410,8 @@ export namespace onnxruntime.experimental.fbs { * @returns string|Uint8Array */ strings(index: number): string; - strings(index: number, optionalEncoding: flatbuffers.Encoding): string|Uint8Array; - strings(index: number, optionalEncoding?: any): string|Uint8Array|null { + strings(index: number, optionalEncoding: flatbuffers.Encoding): string | Uint8Array; + strings(index: number, optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 24); return offset ? 
this.bb!.__string(this.bb!.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; } @@ -2328,11 +2429,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Tensor= obj * @returns onnxruntime.experimental.fbs.Tensor */ - tensors(index: number, obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor|null { + tensors(index: number, obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor | null { let offset = this.bb!.__offset(this.bb_pos, 26); - return offset ? (obj || new onnxruntime.experimental.fbs.Tensor()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Tensor()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -2348,11 +2452,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Graph= obj * @returns onnxruntime.experimental.fbs.Graph */ - graphs(index: number, obj?: onnxruntime.experimental.fbs.Graph): onnxruntime.experimental.fbs.Graph|null { + graphs(index: number, obj?: onnxruntime.experimental.fbs.Graph): onnxruntime.experimental.fbs.Graph | null { let offset = this.bb!.__offset(this.bb_pos, 28); - return offset ? (obj || new onnxruntime.experimental.fbs.Graph()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Graph()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -2447,7 +2554,7 @@ export namespace onnxruntime.experimental.fbs { * @param Array. 
data * @returns flatbuffers.Offset */ - static createFloatsVector(builder: flatbuffers.Builder, data: number[]|Uint8Array): flatbuffers.Offset { + static createFloatsVector(builder: flatbuffers.Builder, data: number[] | Uint8Array): flatbuffers.Offset { builder.startVector(4, data.length, 4); for (let i = data.length - 1; i >= 0; i--) { builder.addFloat32(data[i]); @@ -2589,11 +2696,21 @@ export namespace onnxruntime.experimental.fbs { } static createAttribute( - builder: flatbuffers.Builder, nameOffset: flatbuffers.Offset, docStringOffset: flatbuffers.Offset, - type: onnxruntime.experimental.fbs.AttributeType, f: number, i: flatbuffers.Long, sOffset: flatbuffers.Offset, - tOffset: flatbuffers.Offset, gOffset: flatbuffers.Offset, floatsOffset: flatbuffers.Offset, - intsOffset: flatbuffers.Offset, stringsOffset: flatbuffers.Offset, tensorsOffset: flatbuffers.Offset, - graphsOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + nameOffset: flatbuffers.Offset, + docStringOffset: flatbuffers.Offset, + type: onnxruntime.experimental.fbs.AttributeType, + f: number, + i: flatbuffers.Long, + sOffset: flatbuffers.Offset, + tOffset: flatbuffers.Offset, + gOffset: flatbuffers.Offset, + floatsOffset: flatbuffers.Offset, + intsOffset: flatbuffers.Offset, + stringsOffset: flatbuffers.Offset, + tensorsOffset: flatbuffers.Offset, + graphsOffset: flatbuffers.Offset, + ): flatbuffers.Offset { Attribute.startAttribute(builder); Attribute.addName(builder, nameOffset); Attribute.addDocString(builder, docStringOffset); @@ -2617,7 +2734,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class Graph { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -2655,11 +2772,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Tensor= obj * @returns onnxruntime.experimental.fbs.Tensor */ - initializers(index: number, obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor|null { + initializers(index: number, obj?: onnxruntime.experimental.fbs.Tensor): onnxruntime.experimental.fbs.Tensor | null { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? (obj || new onnxruntime.experimental.fbs.Tensor()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Tensor()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -2675,11 +2795,17 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.ValueInfo= obj * @returns onnxruntime.experimental.fbs.ValueInfo */ - nodeArgs(index: number, obj?: onnxruntime.experimental.fbs.ValueInfo): onnxruntime.experimental.fbs.ValueInfo|null { + nodeArgs( + index: number, + obj?: onnxruntime.experimental.fbs.ValueInfo, + ): onnxruntime.experimental.fbs.ValueInfo | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.ValueInfo()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? 
(obj || new onnxruntime.experimental.fbs.ValueInfo()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -2695,11 +2821,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Node= obj * @returns onnxruntime.experimental.fbs.Node */ - nodes(index: number, obj?: onnxruntime.experimental.fbs.Node): onnxruntime.experimental.fbs.Node|null { + nodes(index: number, obj?: onnxruntime.experimental.fbs.Node): onnxruntime.experimental.fbs.Node | null { let offset = this.bb!.__offset(this.bb_pos, 8); - return offset ? (obj || new onnxruntime.experimental.fbs.Node()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Node()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -2723,11 +2852,17 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.NodeEdge= obj * @returns onnxruntime.experimental.fbs.NodeEdge */ - nodeEdges(index: number, obj?: onnxruntime.experimental.fbs.NodeEdge): onnxruntime.experimental.fbs.NodeEdge|null { + nodeEdges( + index: number, + obj?: onnxruntime.experimental.fbs.NodeEdge, + ): onnxruntime.experimental.fbs.NodeEdge | null { let offset = this.bb!.__offset(this.bb_pos, 12); - return offset ? (obj || new onnxruntime.experimental.fbs.NodeEdge()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.NodeEdge()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -2744,8 +2879,8 @@ export namespace onnxruntime.experimental.fbs { * @returns string|Uint8Array */ inputs(index: number): string; - inputs(index: number, optionalEncoding: flatbuffers.Encoding): string|Uint8Array; - inputs(index: number, optionalEncoding?: any): string|Uint8Array|null { + inputs(index: number, optionalEncoding: flatbuffers.Encoding): string | Uint8Array; + inputs(index: number, optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 14); return offset ? this.bb!.__string(this.bb!.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; } @@ -2764,8 +2899,8 @@ export namespace onnxruntime.experimental.fbs { * @returns string|Uint8Array */ outputs(index: number): string; - outputs(index: number, optionalEncoding: flatbuffers.Encoding): string|Uint8Array; - outputs(index: number, optionalEncoding?: any): string|Uint8Array|null { + outputs(index: number, optionalEncoding: flatbuffers.Encoding): string | Uint8Array; + outputs(index: number, optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 16); return offset ? 
this.bb!.__string(this.bb!.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null; } @@ -2783,12 +2918,17 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.SparseTensor= obj * @returns onnxruntime.experimental.fbs.SparseTensor */ - sparseInitializers(index: number, obj?: onnxruntime.experimental.fbs.SparseTensor): - onnxruntime.experimental.fbs.SparseTensor|null { + sparseInitializers( + index: number, + obj?: onnxruntime.experimental.fbs.SparseTensor, + ): onnxruntime.experimental.fbs.SparseTensor | null { let offset = this.bb!.__offset(this.bb_pos, 18); - return offset ? (obj || new onnxruntime.experimental.fbs.SparseTensor()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.SparseTensor()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -3001,8 +3141,10 @@ export namespace onnxruntime.experimental.fbs { * @param Array. data * @returns flatbuffers.Offset */ - static createSparseInitializersVector(builder: flatbuffers.Builder, data: flatbuffers.Offset[]): - flatbuffers.Offset { + static createSparseInitializersVector( + builder: flatbuffers.Builder, + data: flatbuffers.Offset[], + ): flatbuffers.Offset { builder.startVector(4, data.length, 4); for (let i = data.length - 1; i >= 0; i--) { builder.addOffset(data[i]); @@ -3028,10 +3170,16 @@ export namespace onnxruntime.experimental.fbs { } static createGraph( - builder: flatbuffers.Builder, initializersOffset: flatbuffers.Offset, nodeArgsOffset: flatbuffers.Offset, - nodesOffset: flatbuffers.Offset, maxNodeIndex: number, nodeEdgesOffset: flatbuffers.Offset, - inputsOffset: flatbuffers.Offset, outputsOffset: flatbuffers.Offset, - sparseInitializersOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + initializersOffset: flatbuffers.Offset, + nodeArgsOffset: flatbuffers.Offset, + nodesOffset: flatbuffers.Offset, + maxNodeIndex: number, + nodeEdgesOffset: flatbuffers.Offset, + inputsOffset: flatbuffers.Offset, + outputsOffset: flatbuffers.Offset, + sparseInitializersOffset: flatbuffers.Offset, + ): flatbuffers.Offset { Graph.startGraph(builder); Graph.addInitializers(builder, initializersOffset); Graph.addNodeArgs(builder, nodeArgsOffset); @@ -3050,7 +3198,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class Model { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -3096,12 +3244,17 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.OperatorSetId= obj * @returns onnxruntime.experimental.fbs.OperatorSetId */ - opsetImport(index: number, obj?: onnxruntime.experimental.fbs.OperatorSetId): - onnxruntime.experimental.fbs.OperatorSetId|null { + opsetImport( + index: number, + obj?: onnxruntime.experimental.fbs.OperatorSetId, + ): onnxruntime.experimental.fbs.OperatorSetId | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.OperatorSetId()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? 
(obj || new onnxruntime.experimental.fbs.OperatorSetId()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -3116,9 +3269,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - producerName(): string|null; - producerName(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - producerName(optionalEncoding?: any): string|Uint8Array|null { + producerName(): string | null; + producerName(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + producerName(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 8); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -3127,9 +3280,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - producerVersion(): string|null; - producerVersion(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - producerVersion(optionalEncoding?: any): string|Uint8Array|null { + producerVersion(): string | null; + producerVersion(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + producerVersion(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 10); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -3138,9 +3291,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - domain(): string|null; - domain(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - domain(optionalEncoding?: any): string|Uint8Array|null { + domain(): string | null; + domain(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + domain(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 12); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -3157,9 +3310,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - docString(): string|null; - docString(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - docString(optionalEncoding?: any): string|Uint8Array|null { + docString(): string | null; + docString(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + docString(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 16); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -3168,20 +3321,20 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Graph= obj * @returns onnxruntime.experimental.fbs.Graph|null */ - graph(obj?: onnxruntime.experimental.fbs.Graph): onnxruntime.experimental.fbs.Graph|null { + graph(obj?: onnxruntime.experimental.fbs.Graph): onnxruntime.experimental.fbs.Graph | null { let offset = this.bb!.__offset(this.bb_pos, 18); - return offset ? (obj || new onnxruntime.experimental.fbs.Graph()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Graph()).__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) 
+ : null; } /** * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - graphDocString(): string|null; - graphDocString(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - graphDocString(optionalEncoding?: any): string|Uint8Array|null { + graphDocString(): string | null; + graphDocString(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + graphDocString(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 20); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -3296,10 +3449,17 @@ export namespace onnxruntime.experimental.fbs { } static createModel( - builder: flatbuffers.Builder, irVersion: flatbuffers.Long, opsetImportOffset: flatbuffers.Offset, - producerNameOffset: flatbuffers.Offset, producerVersionOffset: flatbuffers.Offset, - domainOffset: flatbuffers.Offset, modelVersion: flatbuffers.Long, docStringOffset: flatbuffers.Offset, - graphOffset: flatbuffers.Offset, graphDocStringOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + irVersion: flatbuffers.Long, + opsetImportOffset: flatbuffers.Offset, + producerNameOffset: flatbuffers.Offset, + producerVersionOffset: flatbuffers.Offset, + domainOffset: flatbuffers.Offset, + modelVersion: flatbuffers.Long, + docStringOffset: flatbuffers.Offset, + graphOffset: flatbuffers.Offset, + graphDocStringOffset: flatbuffers.Offset, + ): flatbuffers.Offset { Model.startModel(builder); Model.addIrVersion(builder, irVersion); Model.addOpsetImport(builder, opsetImportOffset); @@ -3319,7 +3479,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class KernelCreateInfos { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -3347,8 +3507,10 @@ export namespace onnxruntime.experimental.fbs { * @param KernelCreateInfos= obj * @returns KernelCreateInfos */ - static getSizePrefixedRootAsKernelCreateInfos(bb: flatbuffers.ByteBuffer, obj?: KernelCreateInfos): - KernelCreateInfos { + static getSizePrefixedRootAsKernelCreateInfos( + bb: flatbuffers.ByteBuffer, + obj?: KernelCreateInfos, + ): KernelCreateInfos { bb.setPosition(bb.position() + flatbuffers.SIZE_PREFIX_LENGTH); return (obj || new KernelCreateInfos()).__init(bb.readInt32(bb.position()) + bb.position(), bb); } @@ -3357,7 +3519,7 @@ export namespace onnxruntime.experimental.fbs { * @param number index * @returns number */ - nodeIndices(index: number): number|null { + nodeIndices(index: number): number | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.readUint32(this.bb!.__vector(this.bb_pos + offset) + index * 4) : 0; } @@ -3373,23 +3535,26 @@ export namespace onnxruntime.experimental.fbs { /** * @returns Uint32Array */ - nodeIndicesArray(): Uint32Array|null { + nodeIndicesArray(): Uint32Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? - new Uint32Array( - this.bb!.bytes().buffer, this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), - this.bb!.__vector_len(this.bb_pos + offset)) : - null; + return offset + ? 
new Uint32Array( + this.bb!.bytes().buffer, + this.bb!.bytes().byteOffset + this.bb!.__vector(this.bb_pos + offset), + this.bb!.__vector_len(this.bb_pos + offset), + ) + : null; } /** * @param number index * @returns flatbuffers.Long */ - kernelDefHashes(index: number): flatbuffers.Long|null { + kernelDefHashes(index: number): flatbuffers.Long | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? this.bb!.readUint64(this.bb!.__vector(this.bb_pos + offset) + index * 8) : - this.bb!.createLong(0, 0); + return offset + ? this.bb!.readUint64(this.bb!.__vector(this.bb_pos + offset) + index * 8) + : this.bb!.createLong(0, 0); } /** @@ -3420,7 +3585,7 @@ export namespace onnxruntime.experimental.fbs { * @param Array. data * @returns flatbuffers.Offset */ - static createNodeIndicesVector(builder: flatbuffers.Builder, data: number[]|Uint8Array): flatbuffers.Offset { + static createNodeIndicesVector(builder: flatbuffers.Builder, data: number[] | Uint8Array): flatbuffers.Offset { builder.startVector(4, data.length, 4); for (let i = data.length - 1; i >= 0; i--) { builder.addInt32(data[i]); @@ -3475,8 +3640,10 @@ export namespace onnxruntime.experimental.fbs { } static createKernelCreateInfos( - builder: flatbuffers.Builder, nodeIndicesOffset: flatbuffers.Offset, - kernelDefHashesOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + nodeIndicesOffset: flatbuffers.Offset, + kernelDefHashesOffset: flatbuffers.Offset, + ): flatbuffers.Offset { KernelCreateInfos.startKernelCreateInfos(builder); KernelCreateInfos.addNodeIndices(builder, nodeIndicesOffset); KernelCreateInfos.addKernelDefHashes(builder, kernelDefHashesOffset); @@ -3489,7 +3656,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class SubGraphSessionState { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -3517,8 +3684,10 @@ export namespace onnxruntime.experimental.fbs { * @param SubGraphSessionState= obj * @returns SubGraphSessionState */ - static getSizePrefixedRootAsSubGraphSessionState(bb: flatbuffers.ByteBuffer, obj?: SubGraphSessionState): - SubGraphSessionState { + static getSizePrefixedRootAsSubGraphSessionState( + bb: flatbuffers.ByteBuffer, + obj?: SubGraphSessionState, + ): SubGraphSessionState { bb.setPosition(bb.position() + flatbuffers.SIZE_PREFIX_LENGTH); return (obj || new SubGraphSessionState()).__init(bb.readInt32(bb.position()) + bb.position(), bb); } @@ -3527,9 +3696,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - graphId(): string|null; - graphId(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - graphId(optionalEncoding?: any): string|Uint8Array|null { + graphId(): string | null; + graphId(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + graphId(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? 
this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -3538,11 +3707,14 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.SessionState= obj * @returns onnxruntime.experimental.fbs.SessionState|null */ - sessionState(obj?: onnxruntime.experimental.fbs.SessionState): onnxruntime.experimental.fbs.SessionState|null { + sessionState(obj?: onnxruntime.experimental.fbs.SessionState): onnxruntime.experimental.fbs.SessionState | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.SessionState()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.SessionState()).__init( + this.bb!.__indirect(this.bb_pos + offset), + this.bb!, + ) + : null; } /** @@ -3574,13 +3746,15 @@ export namespace onnxruntime.experimental.fbs { */ static endSubGraphSessionState(builder: flatbuffers.Builder): flatbuffers.Offset { let offset = builder.endObject(); - builder.requiredField(offset, 4); // graph_id + builder.requiredField(offset, 4); // graph_id return offset; } static createSubGraphSessionState( - builder: flatbuffers.Builder, graphIdOffset: flatbuffers.Offset, - sessionStateOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + graphIdOffset: flatbuffers.Offset, + sessionStateOffset: flatbuffers.Offset, + ): flatbuffers.Offset { SubGraphSessionState.startSubGraphSessionState(builder); SubGraphSessionState.addGraphId(builder, graphIdOffset); SubGraphSessionState.addSessionState(builder, sessionStateOffset); @@ -3593,7 +3767,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class SessionState { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -3630,11 +3804,16 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.KernelCreateInfos= obj * @returns onnxruntime.experimental.fbs.KernelCreateInfos|null */ - kernels(obj?: onnxruntime.experimental.fbs.KernelCreateInfos): onnxruntime.experimental.fbs.KernelCreateInfos|null { + kernels( + obj?: onnxruntime.experimental.fbs.KernelCreateInfos, + ): onnxruntime.experimental.fbs.KernelCreateInfos | null { let offset = this.bb!.__offset(this.bb_pos, 4); - return offset ? (obj || new onnxruntime.experimental.fbs.KernelCreateInfos()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.KernelCreateInfos()).__init( + this.bb!.__indirect(this.bb_pos + offset), + this.bb!, + ) + : null; } /** @@ -3642,12 +3821,17 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.SubGraphSessionState= obj * @returns onnxruntime.experimental.fbs.SubGraphSessionState */ - subGraphSessionStates(index: number, obj?: onnxruntime.experimental.fbs.SubGraphSessionState): - onnxruntime.experimental.fbs.SubGraphSessionState|null { + subGraphSessionStates( + index: number, + obj?: onnxruntime.experimental.fbs.SubGraphSessionState, + ): onnxruntime.experimental.fbs.SubGraphSessionState | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.SubGraphSessionState()) - .__init(this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), this.bb!) : - null; + return offset + ? 
(obj || new onnxruntime.experimental.fbs.SubGraphSessionState()).__init( + this.bb!.__indirect(this.bb!.__vector(this.bb_pos + offset) + index * 4), + this.bb!, + ) + : null; } /** @@ -3686,8 +3870,10 @@ export namespace onnxruntime.experimental.fbs { * @param Array. data * @returns flatbuffers.Offset */ - static createSubGraphSessionStatesVector(builder: flatbuffers.Builder, data: flatbuffers.Offset[]): - flatbuffers.Offset { + static createSubGraphSessionStatesVector( + builder: flatbuffers.Builder, + data: flatbuffers.Offset[], + ): flatbuffers.Offset { builder.startVector(4, data.length, 4); for (let i = data.length - 1; i >= 0; i--) { builder.addOffset(data[i]); @@ -3713,8 +3899,10 @@ export namespace onnxruntime.experimental.fbs { } static createSessionState( - builder: flatbuffers.Builder, kernelsOffset: flatbuffers.Offset, - subGraphSessionStatesOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + kernelsOffset: flatbuffers.Offset, + subGraphSessionStatesOffset: flatbuffers.Offset, + ): flatbuffers.Offset { SessionState.startSessionState(builder); SessionState.addKernels(builder, kernelsOffset); SessionState.addSubGraphSessionStates(builder, subGraphSessionStatesOffset); @@ -3727,7 +3915,7 @@ export namespace onnxruntime.experimental.fbs { */ export namespace onnxruntime.experimental.fbs { export class InferenceSession { - bb: flatbuffers.ByteBuffer|null = null; + bb: flatbuffers.ByteBuffer | null = null; bb_pos = 0; /** @@ -3772,9 +3960,9 @@ export namespace onnxruntime.experimental.fbs { * @param flatbuffers.Encoding= optionalEncoding * @returns string|Uint8Array|null */ - ortVersion(): string|null; - ortVersion(optionalEncoding: flatbuffers.Encoding): string|Uint8Array|null; - ortVersion(optionalEncoding?: any): string|Uint8Array|null { + ortVersion(): string | null; + ortVersion(optionalEncoding: flatbuffers.Encoding): string | Uint8Array | null; + ortVersion(optionalEncoding?: any): string | Uint8Array | null { let offset = this.bb!.__offset(this.bb_pos, 4); return offset ? this.bb!.__string(this.bb_pos + offset, optionalEncoding) : null; } @@ -3783,22 +3971,25 @@ export namespace onnxruntime.experimental.fbs { * @param onnxruntime.experimental.fbs.Model= obj * @returns onnxruntime.experimental.fbs.Model|null */ - model(obj?: onnxruntime.experimental.fbs.Model): onnxruntime.experimental.fbs.Model|null { + model(obj?: onnxruntime.experimental.fbs.Model): onnxruntime.experimental.fbs.Model | null { let offset = this.bb!.__offset(this.bb_pos, 6); - return offset ? (obj || new onnxruntime.experimental.fbs.Model()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? (obj || new onnxruntime.experimental.fbs.Model()).__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) + : null; } /** * @param onnxruntime.experimental.fbs.SessionState= obj * @returns onnxruntime.experimental.fbs.SessionState|null */ - sessionState(obj?: onnxruntime.experimental.fbs.SessionState): onnxruntime.experimental.fbs.SessionState|null { + sessionState(obj?: onnxruntime.experimental.fbs.SessionState): onnxruntime.experimental.fbs.SessionState | null { let offset = this.bb!.__offset(this.bb_pos, 8); - return offset ? (obj || new onnxruntime.experimental.fbs.SessionState()) - .__init(this.bb!.__indirect(this.bb_pos + offset), this.bb!) : - null; + return offset + ? 
(obj || new onnxruntime.experimental.fbs.SessionState()).__init( + this.bb!.__indirect(this.bb_pos + offset), + this.bb!, + ) + : null; } /** @@ -3858,8 +4049,11 @@ export namespace onnxruntime.experimental.fbs { } static createInferenceSession( - builder: flatbuffers.Builder, ortVersionOffset: flatbuffers.Offset, modelOffset: flatbuffers.Offset, - sessionStateOffset: flatbuffers.Offset): flatbuffers.Offset { + builder: flatbuffers.Builder, + ortVersionOffset: flatbuffers.Offset, + modelOffset: flatbuffers.Offset, + sessionStateOffset: flatbuffers.Offset, + ): flatbuffers.Offset { InferenceSession.startInferenceSession(builder); InferenceSession.addOrtVersion(builder, ortVersionOffset); InferenceSession.addModel(builder, modelOffset); diff --git a/js/web/lib/onnxjs/ort-schema/protobuf/README.md b/js/web/lib/onnxjs/ort-schema/protobuf/README.md index f5f52c602f1ad..35f61310db9aa 100644 --- a/js/web/lib/onnxjs/ort-schema/protobuf/README.md +++ b/js/web/lib/onnxjs/ort-schema/protobuf/README.md @@ -12,10 +12,10 @@ The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the versio - type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. - in the generated typescript declaration file 'onnx.d.ts', the following line: ```ts - import Long = require("long"); + import Long = require('long'); ``` need to be replaced to fix type import error: ```ts - import Long from "long"; + import Long from 'long'; ``` this replacement is done and code format is also applied to file 'onnx.d.ts'. diff --git a/js/web/lib/onnxjs/ort-schema/protobuf/onnx.js b/js/web/lib/onnxjs/ort-schema/protobuf/onnx.js index 681855132d4e8..24ccb627acff7 100644 --- a/js/web/lib/onnxjs/ort-schema/protobuf/onnx.js +++ b/js/web/lib/onnxjs/ort-schema/protobuf/onnx.js @@ -1,7658 +1,7391 @@ /*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ -"use strict"; +'use strict'; -var $protobuf = require("protobufjs/minimal"); +var $protobuf = require('protobufjs/minimal'); // Common aliases -var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; +var $Reader = $protobuf.Reader, + $Writer = $protobuf.Writer, + $util = $protobuf.util; // Exported root namespace -var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); +var $root = $protobuf.roots['default'] || ($protobuf.roots['default'] = {}); + +$root.onnx = (function () { + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. 
+ * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = '_START_VERSION')] = 0; + values[(valuesById[1] = 'IR_VERSION_2017_10_10')] = 1; + values[(valuesById[2] = 'IR_VERSION_2017_10_30')] = 2; + values[(valuesById[3] = 'IR_VERSION_2017_11_3')] = 3; + values[(valuesById[4] = 'IR_VERSION_2019_1_22')] = 4; + values[(valuesById[5] = 'IR_VERSION_2019_3_18')] = 5; + values[(valuesById[6] = 'IR_VERSION_2019_9_19')] = 6; + values[(valuesById[7] = 'IR_VERSION_2020_5_8')] = 7; + values[(valuesById[8] = 'IR_VERSION_2021_7_30')] = 8; + values[(valuesById[9] = 'IR_VERSION')] = 9; + return values; + })(); + + onnx.AttributeProto = (function () { + /** + * Properties of an AttributeProto. + * @memberof onnx + * @interface IAttributeProto + * @property {string|null} [name] AttributeProto name + * @property {string|null} [refAttrName] AttributeProto refAttrName + * @property {string|null} [docString] AttributeProto docString + * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type + * @property {number|null} [f] AttributeProto f + * @property {number|Long|null} [i] AttributeProto i + * @property {Uint8Array|null} [s] AttributeProto s + * @property {onnx.ITensorProto|null} [t] AttributeProto t + * @property {onnx.IGraphProto|null} [g] AttributeProto g + * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor + * @property {onnx.ITypeProto|null} [tp] AttributeProto tp + * @property {Array.|null} [floats] AttributeProto floats + * @property {Array.|null} [ints] AttributeProto ints + * @property {Array.|null} [strings] AttributeProto strings + * @property {Array.|null} [tensors] AttributeProto tensors + * @property {Array.|null} [graphs] AttributeProto graphs + * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors + * @property {Array.|null} [typeProtos] AttributeProto typeProtos + */ + + /** + * Constructs a new AttributeProto. + * @memberof onnx + * @classdesc Represents an AttributeProto. + * @implements IAttributeProto + * @constructor + * @param {onnx.IAttributeProto=} [properties] Properties to set + */ + function AttributeProto(properties) { + this.floats = []; + this.ints = []; + this.strings = []; + this.tensors = []; + this.graphs = []; + this.sparseTensors = []; + this.typeProtos = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * AttributeProto name. + * @member {string} name + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.name = ''; + + /** + * AttributeProto refAttrName. 
+ * @member {string} refAttrName + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.refAttrName = ''; + + /** + * AttributeProto docString. + * @member {string} docString + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.docString = ''; + + /** + * AttributeProto type. + * @member {onnx.AttributeProto.AttributeType} type + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.type = 0; + + /** + * AttributeProto f. + * @member {number} f + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.f = 0; + + /** + * AttributeProto i. + * @member {number|Long} i + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * AttributeProto s. + * @member {Uint8Array} s + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.s = $util.newBuffer([]); + + /** + * AttributeProto t. + * @member {onnx.ITensorProto|null|undefined} t + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.t = null; + + /** + * AttributeProto g. + * @member {onnx.IGraphProto|null|undefined} g + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.g = null; + + /** + * AttributeProto sparseTensor. + * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensor = null; + + /** + * AttributeProto tp. + * @member {onnx.ITypeProto|null|undefined} tp + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tp = null; + + /** + * AttributeProto floats. + * @member {Array.} floats + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.floats = $util.emptyArray; + + /** + * AttributeProto ints. + * @member {Array.} ints + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.ints = $util.emptyArray; + + /** + * AttributeProto strings. + * @member {Array.} strings + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.strings = $util.emptyArray; + + /** + * AttributeProto tensors. + * @member {Array.} tensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tensors = $util.emptyArray; + + /** + * AttributeProto graphs. + * @member {Array.} graphs + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.graphs = $util.emptyArray; + + /** + * AttributeProto sparseTensors. + * @member {Array.} sparseTensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensors = $util.emptyArray; + + /** + * AttributeProto typeProtos. + * @member {Array.} typeProtos + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.typeProtos = $util.emptyArray; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @function create + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto=} [properties] Properties to set + * @returns {onnx.AttributeProto} AttributeProto instance + */ + AttributeProto.create = function create(properties) { + return new AttributeProto(properties); + }; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, 'f')) + writer.uint32(/* id 2, wireType 5 =*/ 21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, 'i')) + writer.uint32(/* id 3, wireType 0 =*/ 24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, 's')) + writer.uint32(/* id 4, wireType 2 =*/ 34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, 't')) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/ 42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, 'g')) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/ 50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/ 58).fork(); + for (var i = 0; i < message.floats.length; ++i) writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/ 66).fork(); + for (var i = 0; i < message.ints.length; ++i) writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/ 74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/ 82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/ 90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 13, wireType 2 =*/ 106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, 'tp')) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/ 114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode( + message.typeProtos[i], + writer.uint32(/* id 15, wireType 2 =*/ 122).fork(), + ).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, 'type')) + writer.uint32(/* id 20, wireType 0 =*/ 160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, 'refAttrName')) + writer.uint32(/* id 21, wireType 2 =*/ 170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, 'sparseTensor')) + $root.onnx.SparseTensorProto.encode( + message.sparseTensor, + writer.uint32(/* id 22, wireType 2 =*/ 178).fork(), + ).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode( + 
message.sparseTensors[i], + writer.uint32(/* id 23, wireType 2 =*/ 186).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.floats.push(reader.float()); + } else message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.ints.push(reader.int64()); + } else message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if 
(!(message.typeProtos && message.typeProtos.length)) message.typeProtos = []; + message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an AttributeProto message. + * @function verify + * @memberof onnx.AttributeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + AttributeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.refAttrName != null && message.hasOwnProperty('refAttrName')) + if (!$util.isString(message.refAttrName)) return 'refAttrName: string expected'; + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.type != null && message.hasOwnProperty('type')) + switch (message.type) { + default: + return 'type: enum value expected'; + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 11: + case 13: + case 6: + case 7: + case 8: + case 9: + case 10: + case 12: + case 14: + break; + } + if (message.f != null && message.hasOwnProperty('f')) + if (typeof message.f !== 'number') return 'f: number expected'; + if (message.i != null && message.hasOwnProperty('i')) + if ( + !$util.isInteger(message.i) && + !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high)) + ) + return 'i: integer|Long expected'; + if (message.s != null && message.hasOwnProperty('s')) + if (!((message.s && typeof message.s.length === 'number') || $util.isString(message.s))) + return 's: buffer expected'; + if (message.t != null && message.hasOwnProperty('t')) { + var error = $root.onnx.TensorProto.verify(message.t); + if (error) return 't.' + error; + } + if (message.g != null && message.hasOwnProperty('g')) { + var error = $root.onnx.GraphProto.verify(message.g); + if (error) return 'g.' + error; + } + if (message.sparseTensor != null && message.hasOwnProperty('sparseTensor')) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); + if (error) return 'sparseTensor.' + error; + } + if (message.tp != null && message.hasOwnProperty('tp')) { + var error = $root.onnx.TypeProto.verify(message.tp); + if (error) return 'tp.' 
+ error; + } + if (message.floats != null && message.hasOwnProperty('floats')) { + if (!Array.isArray(message.floats)) return 'floats: array expected'; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== 'number') return 'floats: number[] expected'; + } + if (message.ints != null && message.hasOwnProperty('ints')) { + if (!Array.isArray(message.ints)) return 'ints: array expected'; + for (var i = 0; i < message.ints.length; ++i) + if ( + !$util.isInteger(message.ints[i]) && + !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high)) + ) + return 'ints: integer|Long[] expected'; + } + if (message.strings != null && message.hasOwnProperty('strings')) { + if (!Array.isArray(message.strings)) return 'strings: array expected'; + for (var i = 0; i < message.strings.length; ++i) + if ( + !( + (message.strings[i] && typeof message.strings[i].length === 'number') || + $util.isString(message.strings[i]) + ) + ) + return 'strings: buffer[] expected'; + } + if (message.tensors != null && message.hasOwnProperty('tensors')) { + if (!Array.isArray(message.tensors)) return 'tensors: array expected'; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) return 'tensors.' + error; + } + } + if (message.graphs != null && message.hasOwnProperty('graphs')) { + if (!Array.isArray(message.graphs)) return 'graphs: array expected'; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) return 'graphs.' + error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty('sparseTensors')) { + if (!Array.isArray(message.sparseTensors)) return 'sparseTensors: array expected'; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) return 'sparseTensors.' + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty('typeProtos')) { + if (!Array.isArray(message.typeProtos)) return 'typeProtos: array expected'; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) return 'typeProtos.' + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.AttributeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.AttributeProto} AttributeProto + */ + AttributeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.AttributeProto) return object; + var message = new $root.onnx.AttributeProto(); + if (object.name != null) message.name = String(object.name); + if (object.refAttrName != null) message.refAttrName = String(object.refAttrName); + if (object.docString != null) message.docString = String(object.docString); + switch (object.type) { + default: + if (typeof object.type === 'number') { + message.type = object.type; + break; + } + break; + case 'UNDEFINED': + case 0: + message.type = 0; + break; + case 'FLOAT': + case 1: + message.type = 1; + break; + case 'INT': + case 2: + message.type = 2; + break; + case 'STRING': + case 3: + message.type = 3; + break; + case 'TENSOR': + case 4: + message.type = 4; + break; + case 'GRAPH': + case 5: + message.type = 5; + break; + case 'SPARSE_TENSOR': + case 11: + message.type = 11; + break; + case 'TYPE_PROTO': + case 13: + message.type = 13; + break; + case 'FLOATS': + case 6: + message.type = 6; + break; + case 'INTS': + case 7: + message.type = 7; + break; + case 'STRINGS': + case 8: + message.type = 8; + break; + case 'TENSORS': + case 9: + message.type = 9; + break; + case 'GRAPHS': + case 10: + message.type = 10; + break; + case 'SPARSE_TENSORS': + case 12: + message.type = 12; + break; + case 'TYPE_PROTOS': + case 14: + message.type = 14; + break; + } + if (object.f != null) message.f = Number(object.f); + if (object.i != null) + if ($util.Long) (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === 'string') message.i = parseInt(object.i, 10); + else if (typeof object.i === 'number') message.i = object.i; + else if (typeof object.i === 'object') + message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if (object.s != null) + if (typeof object.s === 'string') + $util.base64.decode(object.s, (message.s = $util.newBuffer($util.base64.length(object.s))), 0); + else if (object.s.length >= 0) message.s = object.s; + if (object.t != null) { + if (typeof object.t !== 'object') throw TypeError('.onnx.AttributeProto.t: object expected'); + message.t = $root.onnx.TensorProto.fromObject(object.t); + } + if (object.g != null) { + if (typeof object.g !== 'object') throw TypeError('.onnx.AttributeProto.g: object expected'); + message.g = $root.onnx.GraphProto.fromObject(object.g); + } + if (object.sparseTensor != null) { + if (typeof object.sparseTensor !== 'object') + throw TypeError('.onnx.AttributeProto.sparseTensor: object expected'); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + } + if (object.tp != null) { + if (typeof object.tp !== 'object') throw TypeError('.onnx.AttributeProto.tp: object expected'); + message.tp = $root.onnx.TypeProto.fromObject(object.tp); + } + if (object.floats) { + if (!Array.isArray(object.floats)) throw TypeError('.onnx.AttributeProto.floats: array expected'); + message.floats = []; + for (var i = 0; i < object.floats.length; ++i) message.floats[i] = Number(object.floats[i]); + } + if (object.ints) { + if (!Array.isArray(object.ints)) throw TypeError('.onnx.AttributeProto.ints: array expected'); + message.ints = []; + for (var i = 0; i < object.ints.length; ++i) + if ($util.Long) (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false; + else if 
(typeof object.ints[i] === 'string') message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === 'number') message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === 'object') + message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + } + if (object.strings) { + if (!Array.isArray(object.strings)) throw TypeError('.onnx.AttributeProto.strings: array expected'); + message.strings = []; + for (var i = 0; i < object.strings.length; ++i) + if (typeof object.strings[i] === 'string') + $util.base64.decode( + object.strings[i], + (message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i]))), + 0, + ); + else if (object.strings[i].length >= 0) message.strings[i] = object.strings[i]; + } + if (object.tensors) { + if (!Array.isArray(object.tensors)) throw TypeError('.onnx.AttributeProto.tensors: array expected'); + message.tensors = []; + for (var i = 0; i < object.tensors.length; ++i) { + if (typeof object.tensors[i] !== 'object') throw TypeError('.onnx.AttributeProto.tensors: object expected'); + message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + } + } + if (object.graphs) { + if (!Array.isArray(object.graphs)) throw TypeError('.onnx.AttributeProto.graphs: array expected'); + message.graphs = []; + for (var i = 0; i < object.graphs.length; ++i) { + if (typeof object.graphs[i] !== 'object') throw TypeError('.onnx.AttributeProto.graphs: object expected'); + message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); + } + } + if (object.sparseTensors) { + if (!Array.isArray(object.sparseTensors)) throw TypeError('.onnx.AttributeProto.sparseTensors: array expected'); + message.sparseTensors = []; + for (var i = 0; i < object.sparseTensors.length; ++i) { + if (typeof object.sparseTensors[i] !== 'object') + throw TypeError('.onnx.AttributeProto.sparseTensors: object expected'); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + } + } + if (object.typeProtos) { + if (!Array.isArray(object.typeProtos)) throw TypeError('.onnx.AttributeProto.typeProtos: array expected'); + message.typeProtos = []; + for (var i = 0; i < object.typeProtos.length; ++i) { + if (typeof object.typeProtos[i] !== 'object') + throw TypeError('.onnx.AttributeProto.typeProtos: object expected'); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + } + } + return message; + }; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.AttributeProto + * @static + * @param {onnx.AttributeProto} message AttributeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + AttributeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.floats = []; + object.ints = []; + object.strings = []; + object.tensors = []; + object.graphs = []; + object.typeProtos = []; + object.sparseTensors = []; + } + if (options.defaults) { + object.name = ''; + object.f = 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.i = options.longs === String ? 
'0' : 0; + if (options.bytes === String) object.s = ''; + else { + object.s = []; + if (options.bytes !== Array) object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ''; + object.tp = null; + object.type = options.enums === String ? 'UNDEFINED' : 0; + object.refAttrName = ''; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.f != null && message.hasOwnProperty('f')) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty('i')) + if (typeof message.i === 'number') object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = + options.longs === String + ? $util.Long.prototype.toString.call(message.i) + : options.longs === Number + ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() + : message.i; + if (message.s != null && message.hasOwnProperty('s')) + object.s = + options.bytes === String + ? $util.base64.encode(message.s, 0, message.s.length) + : options.bytes === Array + ? Array.prototype.slice.call(message.s) + : message.s; + if (message.t != null && message.hasOwnProperty('t')) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty('g')) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = + options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === 'number') + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.ints[j]) + : options.longs === Number + ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() + : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = + options.bytes === String + ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) + : options.bytes === Array + ? Array.prototype.slice.call(message.strings[j]) + : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty('tp')) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty('type')) + object.type = + options.enums === String + ? 
$root.onnx.AttributeProto.AttributeType[message.type] === undefined + ? message.type + : $root.onnx.AttributeProto.AttributeType[message.type] + : message.type; + if (message.refAttrName != null && message.hasOwnProperty('refAttrName')) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty('sparseTensor')) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.AttributeProto'; + }; + + /** + * AttributeType enum. + * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = 'UNDEFINED')] = 0; + values[(valuesById[1] = 'FLOAT')] = 1; + values[(valuesById[2] = 'INT')] = 2; + values[(valuesById[3] = 'STRING')] = 3; + values[(valuesById[4] = 'TENSOR')] = 4; + values[(valuesById[5] = 'GRAPH')] = 5; + values[(valuesById[11] = 'SPARSE_TENSOR')] = 11; + values[(valuesById[13] = 'TYPE_PROTO')] = 13; + values[(valuesById[6] = 'FLOATS')] = 6; + values[(valuesById[7] = 'INTS')] = 7; + values[(valuesById[8] = 'STRINGS')] = 8; + values[(valuesById[9] = 'TENSORS')] = 9; + values[(valuesById[10] = 'GRAPHS')] = 10; + values[(valuesById[12] = 'SPARSE_TENSORS')] = 12; + values[(valuesById[14] = 'TYPE_PROTOS')] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function () { + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. 
+ * @memberof onnx + * @classdesc Represents a ValueInfoProto. + * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ''; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ''; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, 'type')) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. + * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.type != null && message.hasOwnProperty('type')) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) return 'type.' + error; + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== 'object') throw TypeError('.onnx.ValueInfoProto.type: object expected'); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.name = ''; + object.type = null; + object.docString = ''; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.type != null && message.hasOwnProperty('type')) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.ValueInfoProto'; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function () { + /** + * Properties of a NodeProto. + * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ''; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ''; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ''; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. 
+ * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ''; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, 'opType')) + writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode( + message.attribute[i], + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), + ).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.docString); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 7, wireType 2 =*/ 58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. + * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.input != null && message.hasOwnProperty('input')) { + if (!Array.isArray(message.input)) return 'input: array expected'; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) return 'input: string[] expected'; + } + if (message.output != null && message.hasOwnProperty('output')) { + if (!Array.isArray(message.output)) return 'output: array expected'; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) return 'output: string[] expected'; + } + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.opType != null && message.hasOwnProperty('opType')) + if (!$util.isString(message.opType)) return 'opType: string expected'; + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + if (message.attribute != null && message.hasOwnProperty('attribute')) { + if (!Array.isArray(message.attribute)) return 'attribute: array expected'; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) return 'attribute.' + error; + } + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) throw TypeError('.onnx.NodeProto.input: array expected'); + message.input = []; + for (var i = 0; i < object.input.length; ++i) message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) throw TypeError('.onnx.NodeProto.output: array expected'); + message.output = []; + for (var i = 0; i < object.output.length; ++i) message.output[i] = String(object.output[i]); + } + if (object.name != null) message.name = String(object.name); + if (object.opType != null) message.opType = String(object.opType); + if (object.domain != null) message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) throw TypeError('.onnx.NodeProto.attribute: array expected'); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== 'object') throw TypeError('.onnx.NodeProto.attribute: object expected'); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ''; + object.opType = ''; + object.docString = ''; + object.domain = ''; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.opType != null && message.hasOwnProperty('opType')) object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. 
+ * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.NodeProto'; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function () { + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. + * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, 'initialization')) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, 'algorithm')) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.initializationBinding[i], + writer.uint32(/* id 3, wireType 2 =*/ 26).fork(), + ).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.updateBinding[i], + writer.uint32(/* id 4, wireType 2 =*/ 34).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.initialization != null && message.hasOwnProperty('initialization')) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) return 'initialization.' + error; + } + if (message.algorithm != null && message.hasOwnProperty('algorithm')) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) return 'algorithm.' + error; + } + if (message.initializationBinding != null && message.hasOwnProperty('initializationBinding')) { + if (!Array.isArray(message.initializationBinding)) return 'initializationBinding: array expected'; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) return 'initializationBinding.' + error; + } + } + if (message.updateBinding != null && message.hasOwnProperty('updateBinding')) { + if (!Array.isArray(message.updateBinding)) return 'updateBinding: array expected'; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) return 'updateBinding.' + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== 'object') + throw TypeError('.onnx.TrainingInfoProto.initialization: object expected'); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== 'object') throw TypeError('.onnx.TrainingInfoProto.algorithm: object expected'); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError('.onnx.TrainingInfoProto.initializationBinding: array expected'); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== 'object') + throw TypeError('.onnx.TrainingInfoProto.initializationBinding: object expected'); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject( + object.initializationBinding[i], + ); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + throw TypeError('.onnx.TrainingInfoProto.updateBinding: array expected'); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== 'object') + throw TypeError('.onnx.TrainingInfoProto.updateBinding: object expected'); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty('initialization')) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty('algorithm')) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject( + message.initializationBinding[j], + options, + ); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. 
+ * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TrainingInfoProto'; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function () { + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. + * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ''; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ''; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ''; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ''; + + /** + * ModelProto graph. 
+ * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. + * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, 'irVersion')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, 'producerName')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, 'producerVersion')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.domain); + if (message.modelVersion != null && Object.hasOwnProperty.call(message, 'modelVersion')) + writer.uint32(/* id 5, wireType 0 =*/ 40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, 'graph')) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/ 58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode( + message.opsetImport[i], + writer.uint32(/* id 8, wireType 2 =*/ 66).fork(), + ).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.metadataProps[i], + writer.uint32(/* id 14, wireType 2 =*/ 114).fork(), + ).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode( + message.trainingInfo[i], + writer.uint32(/* id 20, wireType 2 =*/ 162).fork(), + ).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode( + 
message.functions[i], + writer.uint32(/* id 25, wireType 2 =*/ 202).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. 
+ * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.irVersion != null && message.hasOwnProperty('irVersion')) + if ( + !$util.isInteger(message.irVersion) && + !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high)) + ) + return 'irVersion: integer|Long expected'; + if (message.opsetImport != null && message.hasOwnProperty('opsetImport')) { + if (!Array.isArray(message.opsetImport)) return 'opsetImport: array expected'; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) return 'opsetImport.' + error; + } + } + if (message.producerName != null && message.hasOwnProperty('producerName')) + if (!$util.isString(message.producerName)) return 'producerName: string expected'; + if (message.producerVersion != null && message.hasOwnProperty('producerVersion')) + if (!$util.isString(message.producerVersion)) return 'producerVersion: string expected'; + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + if (message.modelVersion != null && message.hasOwnProperty('modelVersion')) + if ( + !$util.isInteger(message.modelVersion) && + !( + message.modelVersion && + $util.isInteger(message.modelVersion.low) && + $util.isInteger(message.modelVersion.high) + ) + ) + return 'modelVersion: integer|Long expected'; + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.graph != null && message.hasOwnProperty('graph')) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) return 'graph.' + error; + } + if (message.metadataProps != null && message.hasOwnProperty('metadataProps')) { + if (!Array.isArray(message.metadataProps)) return 'metadataProps: array expected'; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) return 'metadataProps.' + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty('trainingInfo')) { + if (!Array.isArray(message.trainingInfo)) return 'trainingInfo: array expected'; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) return 'trainingInfo.' + error; + } + } + if (message.functions != null && message.hasOwnProperty('functions')) { + if (!Array.isArray(message.functions)) return 'functions: array expected'; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) return 'functions.' + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === 'string') message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === 'number') message.irVersion = object.irVersion; + else if (typeof object.irVersion === 'object') + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) throw TypeError('.onnx.ModelProto.opsetImport: array expected'); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== 'object') + throw TypeError('.onnx.ModelProto.opsetImport: object expected'); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) message.producerName = String(object.producerName); + if (object.producerVersion != null) message.producerVersion = String(object.producerVersion); + if (object.domain != null) message.domain = String(object.domain); + if (object.modelVersion != null) + if ($util.Long) (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === 'string') message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === 'number') message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === 'object') + message.modelVersion = new $util.LongBits( + object.modelVersion.low >>> 0, + object.modelVersion.high >>> 0, + ).toNumber(); + if (object.docString != null) message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== 'object') throw TypeError('.onnx.ModelProto.graph: object expected'); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) throw TypeError('.onnx.ModelProto.metadataProps: array expected'); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== 'object') + throw TypeError('.onnx.ModelProto.metadataProps: object expected'); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) throw TypeError('.onnx.ModelProto.trainingInfo: array expected'); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== 'object') + throw TypeError('.onnx.ModelProto.trainingInfo: object expected'); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) throw TypeError('.onnx.ModelProto.functions: array expected'); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== 'object') throw TypeError('.onnx.ModelProto.functions: object expected'); + message.functions[i] = 
$root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.irVersion = options.longs === String ? '0' : 0; + object.producerName = ''; + object.producerVersion = ''; + object.domain = ''; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.modelVersion = options.longs === String ? '0' : 0; + object.docString = ''; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty('irVersion')) + if (typeof message.irVersion === 'number') + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = + options.longs === String + ? $util.Long.prototype.toString.call(message.irVersion) + : options.longs === Number + ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() + : message.irVersion; + if (message.producerName != null && message.hasOwnProperty('producerName')) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty('producerVersion')) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty('modelVersion')) + if (typeof message.modelVersion === 'number') + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = + options.longs === String + ? $util.Long.prototype.toString.call(message.modelVersion) + : options.longs === Number + ? 
new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() + : message.modelVersion; + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty('graph')) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.ModelProto'; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function () { + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ''; + + /** + * StringStringEntryProto value. + * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ''; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. 
+ * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, 'key')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, 'value')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a StringStringEntryProto message. + * @function verify + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + StringStringEntryProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.key != null && message.hasOwnProperty('key')) + if (!$util.isString(message.key)) return 'key: string expected'; + if (message.value != null && message.hasOwnProperty('value')) + if (!$util.isString(message.value)) return 'value: string expected'; + return null; + }; + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + */ + StringStringEntryProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.StringStringEntryProto) return object; + var message = new $root.onnx.StringStringEntryProto(); + if (object.key != null) message.key = String(object.key); + if (object.value != null) message.value = String(object.value); + return message; + }; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.StringStringEntryProto} message StringStringEntryProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + StringStringEntryProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.key = ''; + object.value = ''; + } + if (message.key != null && message.hasOwnProperty('key')) object.key = message.key; + if (message.value != null && message.hasOwnProperty('value')) object.value = message.value; + return object; + }; + + /** + * Converts this StringStringEntryProto to JSON. 
+ * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.StringStringEntryProto'; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function () { + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ''; + + /** + * TensorAnnotation quantParameterTensorNames. + * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, 'tensorName')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.quantParameterTensorNames[i], + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.tensorName != null && message.hasOwnProperty('tensorName')) + if (!$util.isString(message.tensorName)) return 'tensorName: string expected'; + if (message.quantParameterTensorNames != null && message.hasOwnProperty('quantParameterTensorNames')) { + if (!Array.isArray(message.quantParameterTensorNames)) return 'quantParameterTensorNames: array expected'; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) return 'quantParameterTensorNames.' + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError('.onnx.TensorAnnotation.quantParameterTensorNames: array expected'); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== 'object') + throw TypeError('.onnx.TensorAnnotation.quantParameterTensorNames: object expected'); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject( + object.quantParameterTensorNames[i], + ); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) object.quantParameterTensorNames = []; + if (options.defaults) object.tensorName = ''; + if (message.tensorName != null && message.hasOwnProperty('tensorName')) object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject( + message.quantParameterTensorNames[j], + options, + ); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorAnnotation'; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function () { + /** + * Properties of a GraphProto. 
+ * @memberof onnx + * @interface IGraphProto + * @property {Array.|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.|null} [initializer] GraphProto initializer + * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.|null} [input] GraphProto input + * @property {Array.|null} [output] GraphProto output + * @property {Array.|null} [valueInfo] GraphProto valueInfo + * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ''; + + /** + * GraphProto initializer. + * @member {Array.} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString. + * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ''; + + /** + * GraphProto input. + * @member {Array.} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; -$root.onnx = (function() { + /** + * GraphProto quantizationAnnotation. + * @member {Array.} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode( + message.initializer[i], + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), + ).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 10, wireType 2 =*/ 82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode( + message.input[i], + writer.uint32(/* id 11, wireType 2 =*/ 90).fork(), + ).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode( + message.output[i], + writer.uint32(/* id 12, wireType 2 =*/ 98).fork(), + ).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode( + message.valueInfo[i], + writer.uint32(/* id 13, wireType 2 =*/ 106).fork(), + ).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode( + message.quantizationAnnotation[i], + writer.uint32(/* id 14, wireType 2 =*/ 114).fork(), + ).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode( + message.sparseInitializer[i], + writer.uint32(/* id 15, wireType 2 =*/ 122).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.node != null && message.hasOwnProperty('node')) { + if (!Array.isArray(message.node)) return 'node: array expected'; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) return 'node.' + error; + } + } + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.initializer != null && message.hasOwnProperty('initializer')) { + if (!Array.isArray(message.initializer)) return 'initializer: array expected'; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) return 'initializer.' 
+ error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty('sparseInitializer')) { + if (!Array.isArray(message.sparseInitializer)) return 'sparseInitializer: array expected'; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) return 'sparseInitializer.' + error; + } + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.input != null && message.hasOwnProperty('input')) { + if (!Array.isArray(message.input)) return 'input: array expected'; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) return 'input.' + error; + } + } + if (message.output != null && message.hasOwnProperty('output')) { + if (!Array.isArray(message.output)) return 'output: array expected'; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) return 'output.' + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty('valueInfo')) { + if (!Array.isArray(message.valueInfo)) return 'valueInfo: array expected'; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) return 'valueInfo.' + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty('quantizationAnnotation')) { + if (!Array.isArray(message.quantizationAnnotation)) return 'quantizationAnnotation: array expected'; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) return 'quantizationAnnotation.' + error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) throw TypeError('.onnx.GraphProto.node: array expected'); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== 'object') throw TypeError('.onnx.GraphProto.node: object expected'); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) throw TypeError('.onnx.GraphProto.initializer: array expected'); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== 'object') + throw TypeError('.onnx.GraphProto.initializer: object expected'); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError('.onnx.GraphProto.sparseInitializer: array expected'); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== 'object') + throw TypeError('.onnx.GraphProto.sparseInitializer: object expected'); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) throw TypeError('.onnx.GraphProto.input: array expected'); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== 'object') throw TypeError('.onnx.GraphProto.input: object expected'); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) throw TypeError('.onnx.GraphProto.output: array expected'); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== 'object') throw TypeError('.onnx.GraphProto.output: object expected'); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) throw TypeError('.onnx.GraphProto.valueInfo: array expected'); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== 'object') throw TypeError('.onnx.GraphProto.valueInfo: object expected'); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError('.onnx.GraphProto.quantizationAnnotation: array expected'); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== 'object') + throw TypeError('.onnx.GraphProto.quantizationAnnotation: object expected'); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto message. 
Also converts values to other types if specified. + * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ''; + object.docString = ''; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject( + message.quantizationAnnotation[j], + options, + ); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.GraphProto'; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function () { + /** + * Properties of a TensorProto. 
+ * @memberof onnx + * @interface ITensorProto + * @property {Array.|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.|null} [floatData] TensorProto floatData + * @property {Array.|null} [int32Data] TensorProto int32Data + * @property {Array.|null} [stringData] TensorProto stringData + * @property {Array.|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.|null} [doubleData] TensorProto doubleData + * @property {Array.|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment. + * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ''; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ''; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. 
+ * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data. + * @member {Array.} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(); + for (var i = 0; i < message.dims.length; ++i) writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, 'dataType')) + writer.uint32(/* id 2, wireType 0 =*/ 16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, 'segment')) + $root.onnx.TensorProto.Segment.encode( + message.segment, + writer.uint32(/* id 3, wireType 2 =*/ 26).fork(), + ).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/ 34).fork(); + for (var i = 0; i < message.floatData.length; ++i) writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if (message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/ 50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/ 58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 8, wireType 2 =*/ 66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, 'rawData')) + writer.uint32(/* id 9, wireType 2 =*/ 74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/ 82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/ 90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != 
null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 12, wireType 2 =*/ 98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode( + message.externalData[i], + writer.uint32(/* id 13, wireType 2 =*/ 106).fork(), + ).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, 'dataLocation')) + writer.uint32(/* id 14, wireType 0 =*/ 112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.dims.push(reader.int64()); + } else message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.floatData.push(reader.float()); + } else message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.int32Data.push(reader.int32()); + } else message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.int64Data.push(reader.int64()); + } else message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } + 
case 13: { + if (!(message.externalData && message.externalData.length)) message.externalData = []; + message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.doubleData.push(reader.double()); + } else message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.uint64Data.push(reader.uint64()); + } else message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. + * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.dims != null && message.hasOwnProperty('dims')) { + if (!Array.isArray(message.dims)) return 'dims: array expected'; + for (var i = 0; i < message.dims.length; ++i) + if ( + !$util.isInteger(message.dims[i]) && + !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high)) + ) + return 'dims: integer|Long[] expected'; + } + if (message.dataType != null && message.hasOwnProperty('dataType')) + if (!$util.isInteger(message.dataType)) return 'dataType: integer expected'; + if (message.segment != null && message.hasOwnProperty('segment')) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) return 'segment.' 
+ error; + } + if (message.floatData != null && message.hasOwnProperty('floatData')) { + if (!Array.isArray(message.floatData)) return 'floatData: array expected'; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== 'number') return 'floatData: number[] expected'; + } + if (message.int32Data != null && message.hasOwnProperty('int32Data')) { + if (!Array.isArray(message.int32Data)) return 'int32Data: array expected'; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) return 'int32Data: integer[] expected'; + } + if (message.stringData != null && message.hasOwnProperty('stringData')) { + if (!Array.isArray(message.stringData)) return 'stringData: array expected'; + for (var i = 0; i < message.stringData.length; ++i) + if ( + !( + (message.stringData[i] && typeof message.stringData[i].length === 'number') || + $util.isString(message.stringData[i]) + ) + ) + return 'stringData: buffer[] expected'; + } + if (message.int64Data != null && message.hasOwnProperty('int64Data')) { + if (!Array.isArray(message.int64Data)) return 'int64Data: array expected'; + for (var i = 0; i < message.int64Data.length; ++i) + if ( + !$util.isInteger(message.int64Data[i]) && + !( + message.int64Data[i] && + $util.isInteger(message.int64Data[i].low) && + $util.isInteger(message.int64Data[i].high) + ) + ) + return 'int64Data: integer|Long[] expected'; + } + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.rawData != null && message.hasOwnProperty('rawData')) + if (!((message.rawData && typeof message.rawData.length === 'number') || $util.isString(message.rawData))) + return 'rawData: buffer expected'; + if (message.externalData != null && message.hasOwnProperty('externalData')) { + if (!Array.isArray(message.externalData)) return 'externalData: array expected'; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) return 'externalData.' + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty('dataLocation')) + switch (message.dataLocation) { + default: + return 'dataLocation: enum value expected'; + case 0: + case 1: + break; + } + if (message.doubleData != null && message.hasOwnProperty('doubleData')) { + if (!Array.isArray(message.doubleData)) return 'doubleData: array expected'; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== 'number') return 'doubleData: number[] expected'; + } + if (message.uint64Data != null && message.hasOwnProperty('uint64Data')) { + if (!Array.isArray(message.uint64Data)) return 'uint64Data: array expected'; + for (var i = 0; i < message.uint64Data.length; ++i) + if ( + !$util.isInteger(message.uint64Data[i]) && + !( + message.uint64Data[i] && + $util.isInteger(message.uint64Data[i].low) && + $util.isInteger(message.uint64Data[i].high) + ) + ) + return 'uint64Data: integer|Long[] expected'; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto} TensorProto + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) throw TypeError('.onnx.TensorProto.dims: array expected'); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === 'string') message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === 'number') message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === 'object') + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== 'object') throw TypeError('.onnx.TensorProto.segment: object expected'); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) throw TypeError('.onnx.TensorProto.floatData: array expected'); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) throw TypeError('.onnx.TensorProto.int32Data: array expected'); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) throw TypeError('.onnx.TensorProto.stringData: array expected'); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof object.stringData[i] === 'string') + $util.base64.decode( + object.stringData[i], + (message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i]))), + 0, + ); + else if (object.stringData[i].length >= 0) message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) throw TypeError('.onnx.TensorProto.int64Data: array expected'); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === 'string') message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === 'number') message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === 'object') + message.int64Data[i] = new $util.LongBits( + object.int64Data[i].low >>> 0, + object.int64Data[i].high >>> 0, + ).toNumber(); + } + if (object.name != null) message.name = String(object.name); + if (object.docString != null) message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === 'string') + $util.base64.decode( + object.rawData, + (message.rawData = $util.newBuffer($util.base64.length(object.rawData))), + 0, + ); + else if (object.rawData.length >= 0) message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) throw TypeError('.onnx.TensorProto.externalData: array expected'); + message.externalData = []; + for (var i = 0; i < 
object.externalData.length; ++i) { + if (typeof object.externalData[i] !== 'object') + throw TypeError('.onnx.TensorProto.externalData: object expected'); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: + if (typeof object.dataLocation === 'number') { + message.dataLocation = object.dataLocation; + break; + } + break; + case 'DEFAULT': + case 0: + message.dataLocation = 0; + break; + case 'EXTERNAL': + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) throw TypeError('.onnx.TensorProto.doubleData: array expected'); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) throw TypeError('.onnx.TensorProto.uint64Data: array expected'); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === 'string') message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === 'number') message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === 'object') + message.uint64Data[i] = new $util.LongBits( + object.uint64Data[i].low >>> 0, + object.uint64Data[i].high >>> 0, + ).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ''; + if (options.bytes === String) object.rawData = ''; + else { + object.rawData = []; + if (options.bytes !== Array) object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ''; + object.dataLocation = options.enums === String ? 'DEFAULT' : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === 'number') + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.dims[j]) + : options.longs === Number + ? 
new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() + : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty('dataType')) object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty('segment')) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = + options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) + object.stringData[j] = + options.bytes === String + ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) + : options.bytes === Array + ? Array.prototype.slice.call(message.stringData[j]) + : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === 'number') + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.int64Data[j]) + : options.longs === Number + ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() + : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.rawData != null && message.hasOwnProperty('rawData')) + object.rawData = + options.bytes === String + ? $util.base64.encode(message.rawData, 0, message.rawData.length) + : options.bytes === Array + ? Array.prototype.slice.call(message.rawData) + : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = + options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === 'number') + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.uint64Data[j]) + : options.longs === Number + ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) + : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty('dataLocation')) + object.dataLocation = + options.enums === String + ? 
$root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined + ? message.dataLocation + : $root.onnx.TensorProto.DataLocation[message.dataLocation] + : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. + * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; /** - * Namespace onnx. - * @exports onnx - * @namespace + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url */ - var onnx = {}; + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorProto'; + }; /** - * Version enum. - * @name onnx.Version + * DataType enum. + * @name onnx.TensorProto.DataType * @enum {number} - * @property {number} _START_VERSION=0 _START_VERSION value - * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value - * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value - * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value - * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value - * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value - * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value - * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value - * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value - * @property {number} IR_VERSION=9 IR_VERSION value - */ - onnx.Version = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "_START_VERSION"] = 0; - values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; - values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; - values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; - values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; - values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; - values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; - values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; - values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; - values[valuesById[9] = "IR_VERSION"] = 9; - return values; + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function () { + var valuesById = {}, + 
values = Object.create(valuesById); + values[(valuesById[0] = 'UNDEFINED')] = 0; + values[(valuesById[1] = 'FLOAT')] = 1; + values[(valuesById[2] = 'UINT8')] = 2; + values[(valuesById[3] = 'INT8')] = 3; + values[(valuesById[4] = 'UINT16')] = 4; + values[(valuesById[5] = 'INT16')] = 5; + values[(valuesById[6] = 'INT32')] = 6; + values[(valuesById[7] = 'INT64')] = 7; + values[(valuesById[8] = 'STRING')] = 8; + values[(valuesById[9] = 'BOOL')] = 9; + values[(valuesById[10] = 'FLOAT16')] = 10; + values[(valuesById[11] = 'DOUBLE')] = 11; + values[(valuesById[12] = 'UINT32')] = 12; + values[(valuesById[13] = 'UINT64')] = 13; + values[(valuesById[14] = 'COMPLEX64')] = 14; + values[(valuesById[15] = 'COMPLEX128')] = 15; + values[(valuesById[16] = 'BFLOAT16')] = 16; + values[(valuesById[17] = 'FLOAT8E4M3FN')] = 17; + values[(valuesById[18] = 'FLOAT8E4M3FNUZ')] = 18; + values[(valuesById[19] = 'FLOAT8E5M2')] = 19; + values[(valuesById[20] = 'FLOAT8E5M2FNUZ')] = 20; + return values; })(); - onnx.AttributeProto = (function() { - - /** - * Properties of an AttributeProto. - * @memberof onnx - * @interface IAttributeProto - * @property {string|null} [name] AttributeProto name - * @property {string|null} [refAttrName] AttributeProto refAttrName - * @property {string|null} [docString] AttributeProto docString - * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type - * @property {number|null} [f] AttributeProto f - * @property {number|Long|null} [i] AttributeProto i - * @property {Uint8Array|null} [s] AttributeProto s - * @property {onnx.ITensorProto|null} [t] AttributeProto t - * @property {onnx.IGraphProto|null} [g] AttributeProto g - * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor - * @property {onnx.ITypeProto|null} [tp] AttributeProto tp - * @property {Array.|null} [floats] AttributeProto floats - * @property {Array.|null} [ints] AttributeProto ints - * @property {Array.|null} [strings] AttributeProto strings - * @property {Array.|null} [tensors] AttributeProto tensors - * @property {Array.|null} [graphs] AttributeProto graphs - * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors - * @property {Array.|null} [typeProtos] AttributeProto typeProtos - */ - - /** - * Constructs a new AttributeProto. - * @memberof onnx - * @classdesc Represents an AttributeProto. - * @implements IAttributeProto - * @constructor - * @param {onnx.IAttributeProto=} [properties] Properties to set - */ - function AttributeProto(properties) { - this.floats = []; - this.ints = []; - this.strings = []; - this.tensors = []; - this.graphs = []; - this.sparseTensors = []; - this.typeProtos = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } - - /** - * AttributeProto name. - * @member {string} name - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.name = ""; - - /** - * AttributeProto refAttrName. - * @member {string} refAttrName - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.refAttrName = ""; - - /** - * AttributeProto docString. - * @member {string} docString - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.docString = ""; - - /** - * AttributeProto type. - * @member {onnx.AttributeProto.AttributeType} type - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.type = 0; - - /** - * AttributeProto f. 
- * @member {number} f - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.f = 0; - - /** - * AttributeProto i. - * @member {number|Long} i - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * AttributeProto s. - * @member {Uint8Array} s - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.s = $util.newBuffer([]); - - /** - * AttributeProto t. - * @member {onnx.ITensorProto|null|undefined} t - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.t = null; - - /** - * AttributeProto g. - * @member {onnx.IGraphProto|null|undefined} g - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.g = null; - - /** - * AttributeProto sparseTensor. - * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.sparseTensor = null; - - /** - * AttributeProto tp. - * @member {onnx.ITypeProto|null|undefined} tp - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.tp = null; - - /** - * AttributeProto floats. - * @member {Array.} floats - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.floats = $util.emptyArray; - - /** - * AttributeProto ints. - * @member {Array.} ints - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.ints = $util.emptyArray; - - /** - * AttributeProto strings. - * @member {Array.} strings - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.strings = $util.emptyArray; - - /** - * AttributeProto tensors. - * @member {Array.} tensors - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.tensors = $util.emptyArray; - - /** - * AttributeProto graphs. - * @member {Array.} graphs - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.graphs = $util.emptyArray; - - /** - * AttributeProto sparseTensors. - * @member {Array.} sparseTensors - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.sparseTensors = $util.emptyArray; - - /** - * AttributeProto typeProtos. - * @member {Array.} typeProtos - * @memberof onnx.AttributeProto - * @instance - */ - AttributeProto.prototype.typeProtos = $util.emptyArray; - - /** - * Creates a new AttributeProto instance using the specified properties. - * @function create - * @memberof onnx.AttributeProto - * @static - * @param {onnx.IAttributeProto=} [properties] Properties to set - * @returns {onnx.AttributeProto} AttributeProto instance - */ - AttributeProto.create = function create(properties) { - return new AttributeProto(properties); - }; - - /** - * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.AttributeProto - * @static - * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - AttributeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); - if (message.f != null && Object.hasOwnProperty.call(message, "f")) - writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); - if (message.i != null && Object.hasOwnProperty.call(message, "i")) - writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); - if (message.s != null && Object.hasOwnProperty.call(message, "s")) - writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); - if (message.t != null && Object.hasOwnProperty.call(message, "t")) - $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.g != null && Object.hasOwnProperty.call(message, "g")) - $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); - if (message.floats != null && message.floats.length) { - writer.uint32(/* id 7, wireType 2 =*/58).fork(); - for (var i = 0; i < message.floats.length; ++i) - writer.float(message.floats[i]); - writer.ldelim(); - } - if (message.ints != null && message.ints.length) { - writer.uint32(/* id 8, wireType 2 =*/66).fork(); - for (var i = 0; i < message.ints.length; ++i) - writer.int64(message.ints[i]); - writer.ldelim(); - } - if (message.strings != null && message.strings.length) - for (var i = 0; i < message.strings.length; ++i) - writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); - if (message.tensors != null && message.tensors.length) - for (var i = 0; i < message.tensors.length; ++i) - $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); - if (message.graphs != null && message.graphs.length) - for (var i = 0; i < message.graphs.length; ++i) - $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); - if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) - $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); - if (message.typeProtos != null && message.typeProtos.length) - for (var i = 0; i < message.typeProtos.length; ++i) - $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); - if (message.type != null && Object.hasOwnProperty.call(message, "type")) - writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); - if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) - writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); - if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) - $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); - if (message.sparseTensors != null && message.sparseTensors.length) - for (var i = 0; i < message.sparseTensors.length; ++i) - $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], 
writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.AttributeProto - * @static - * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes an AttributeProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.AttributeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.AttributeProto} AttributeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - AttributeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.name = reader.string(); - break; - } - case 21: { - message.refAttrName = reader.string(); - break; - } - case 13: { - message.docString = reader.string(); - break; - } - case 20: { - message.type = reader.int32(); - break; - } - case 2: { - message.f = reader.float(); - break; - } - case 3: { - message.i = reader.int64(); - break; - } - case 4: { - message.s = reader.bytes(); - break; - } - case 5: { - message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); - break; - } - case 6: { - message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 22: { - message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); - break; - } - case 14: { - message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - case 7: { - if (!(message.floats && message.floats.length)) - message.floats = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.floats.push(reader.float()); - } else - message.floats.push(reader.float()); - break; - } - case 8: { - if (!(message.ints && message.ints.length)) - message.ints = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.ints.push(reader.int64()); - } else - message.ints.push(reader.int64()); - break; - } - case 9: { - if (!(message.strings && message.strings.length)) - message.strings = []; - message.strings.push(reader.bytes()); - break; - } - case 10: { - if (!(message.tensors && message.tensors.length)) - message.tensors = []; - message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); - break; - } - case 11: { - if (!(message.graphs && message.graphs.length)) - message.graphs = []; - message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); - break; - } - case 23: { - if (!(message.sparseTensors && message.sparseTensors.length)) - message.sparseTensors = []; - message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); - break; - } - case 15: { - if 
(!(message.typeProtos && message.typeProtos.length)) - message.typeProtos = []; - message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes an AttributeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.AttributeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.AttributeProto} AttributeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - AttributeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies an AttributeProto message. - * @function verify - * @memberof onnx.AttributeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - AttributeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) - if (!$util.isString(message.refAttrName)) - return "refAttrName: string expected"; - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.type != null && message.hasOwnProperty("type")) - switch (message.type) { - default: - return "type: enum value expected"; - case 0: - case 1: - case 2: - case 3: - case 4: - case 5: - case 11: - case 13: - case 6: - case 7: - case 8: - case 9: - case 10: - case 12: - case 14: - break; - } - if (message.f != null && message.hasOwnProperty("f")) - if (typeof message.f !== "number") - return "f: number expected"; - if (message.i != null && message.hasOwnProperty("i")) - if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high))) - return "i: integer|Long expected"; - if (message.s != null && message.hasOwnProperty("s")) - if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s))) - return "s: buffer expected"; - if (message.t != null && message.hasOwnProperty("t")) { - var error = $root.onnx.TensorProto.verify(message.t); - if (error) - return "t." + error; - } - if (message.g != null && message.hasOwnProperty("g")) { - var error = $root.onnx.GraphProto.verify(message.g); - if (error) - return "g." + error; - } - if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); - if (error) - return "sparseTensor." + error; - } - if (message.tp != null && message.hasOwnProperty("tp")) { - var error = $root.onnx.TypeProto.verify(message.tp); - if (error) - return "tp." 
+ error; - } - if (message.floats != null && message.hasOwnProperty("floats")) { - if (!Array.isArray(message.floats)) - return "floats: array expected"; - for (var i = 0; i < message.floats.length; ++i) - if (typeof message.floats[i] !== "number") - return "floats: number[] expected"; - } - if (message.ints != null && message.hasOwnProperty("ints")) { - if (!Array.isArray(message.ints)) - return "ints: array expected"; - for (var i = 0; i < message.ints.length; ++i) - if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) - return "ints: integer|Long[] expected"; - } - if (message.strings != null && message.hasOwnProperty("strings")) { - if (!Array.isArray(message.strings)) - return "strings: array expected"; - for (var i = 0; i < message.strings.length; ++i) - if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) - return "strings: buffer[] expected"; - } - if (message.tensors != null && message.hasOwnProperty("tensors")) { - if (!Array.isArray(message.tensors)) - return "tensors: array expected"; - for (var i = 0; i < message.tensors.length; ++i) { - var error = $root.onnx.TensorProto.verify(message.tensors[i]); - if (error) - return "tensors." + error; - } - } - if (message.graphs != null && message.hasOwnProperty("graphs")) { - if (!Array.isArray(message.graphs)) - return "graphs: array expected"; - for (var i = 0; i < message.graphs.length; ++i) { - var error = $root.onnx.GraphProto.verify(message.graphs[i]); - if (error) - return "graphs." + error; - } - } - if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { - if (!Array.isArray(message.sparseTensors)) - return "sparseTensors: array expected"; - for (var i = 0; i < message.sparseTensors.length; ++i) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); - if (error) - return "sparseTensors." + error; - } - } - if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { - if (!Array.isArray(message.typeProtos)) - return "typeProtos: array expected"; - for (var i = 0; i < message.typeProtos.length; ++i) { - var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); - if (error) - return "typeProtos." + error; - } + TensorProto.Segment = (function () { + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. + * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + + /** + * Creates a new Segment instance using the specified properties. 
+ * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, 'begin')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, 'end')) + writer.uint32(/* id 2, wireType 0 =*/ 16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; } - return null; - }; - - /** - * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.AttributeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.AttributeProto} AttributeProto - */ - AttributeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.AttributeProto) - return object; - var message = new $root.onnx.AttributeProto(); - if (object.name != null) - message.name = String(object.name); - if (object.refAttrName != null) - message.refAttrName = String(object.refAttrName); - if (object.docString != null) - message.docString = String(object.docString); - switch (object.type) { default: - if (typeof object.type === "number") { - message.type = object.type; - break; - } - break; - case "UNDEFINED": - case 0: - message.type = 0; - break; - case "FLOAT": - case 1: - message.type = 1; - break; - case "INT": - case 2: - message.type = 2; - break; - case "STRING": - case 3: - message.type = 3; - break; - case "TENSOR": - case 4: - message.type = 4; - break; - case "GRAPH": - case 5: - message.type = 5; - break; - case "SPARSE_TENSOR": - case 11: - message.type = 11; - break; - case "TYPE_PROTO": - case 13: - message.type = 13; - break; - case "FLOATS": - case 6: - message.type = 6; - break; - case "INTS": - case 7: - message.type = 7; - break; - case "STRINGS": - case 8: - message.type = 8; - break; - case "TENSORS": - case 9: - message.type = 9; - break; - case "GRAPHS": - case 10: - message.type = 10; - break; - case "SPARSE_TENSORS": - case 12: - message.type = 12; - break; - case "TYPE_PROTOS": - case 14: - message.type = 14; - break; - } - if (object.f != null) - message.f = Number(object.f); - if (object.i != null) - if ($util.Long) - (message.i = $util.Long.fromValue(object.i)).unsigned = false; - else if (typeof object.i === "string") - message.i = parseInt(object.i, 10); - else if (typeof object.i === "number") - message.i = object.i; - else if (typeof object.i === "object") - message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); - if (object.s != null) - if (typeof object.s === "string") - $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0); - else if (object.s.length >= 0) - message.s = object.s; - if (object.t != null) { - if (typeof object.t !== "object") - throw TypeError(".onnx.AttributeProto.t: object expected"); - message.t = $root.onnx.TensorProto.fromObject(object.t); - } - if (object.g != null) { - if (typeof object.g !== "object") - throw TypeError(".onnx.AttributeProto.g: object expected"); - message.g = $root.onnx.GraphProto.fromObject(object.g); - } - if (object.sparseTensor != null) { - if (typeof object.sparseTensor !== "object") - throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); - message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); - } - if (object.tp != null) { - if (typeof object.tp !== "object") - throw TypeError(".onnx.AttributeProto.tp: object expected"); - message.tp = $root.onnx.TypeProto.fromObject(object.tp); - } - if (object.floats) { - if (!Array.isArray(object.floats)) - throw TypeError(".onnx.AttributeProto.floats: array expected"); - message.floats = []; - for (var i = 0; i < object.floats.length; ++i) - message.floats[i] = Number(object.floats[i]); - } - if (object.ints) { - if (!Array.isArray(object.ints)) - throw TypeError(".onnx.AttributeProto.ints: array expected"); - message.ints = []; - for (var i = 0; i < object.ints.length; ++i) - if ($util.Long) - (message.ints[i] = 
$util.Long.fromValue(object.ints[i])).unsigned = false; - else if (typeof object.ints[i] === "string") - message.ints[i] = parseInt(object.ints[i], 10); - else if (typeof object.ints[i] === "number") - message.ints[i] = object.ints[i]; - else if (typeof object.ints[i] === "object") - message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); - } - if (object.strings) { - if (!Array.isArray(object.strings)) - throw TypeError(".onnx.AttributeProto.strings: array expected"); - message.strings = []; - for (var i = 0; i < object.strings.length; ++i) - if (typeof object.strings[i] === "string") - $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0); - else if (object.strings[i].length >= 0) - message.strings[i] = object.strings[i]; - } - if (object.tensors) { - if (!Array.isArray(object.tensors)) - throw TypeError(".onnx.AttributeProto.tensors: array expected"); - message.tensors = []; - for (var i = 0; i < object.tensors.length; ++i) { - if (typeof object.tensors[i] !== "object") - throw TypeError(".onnx.AttributeProto.tensors: object expected"); - message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); - } - } - if (object.graphs) { - if (!Array.isArray(object.graphs)) - throw TypeError(".onnx.AttributeProto.graphs: array expected"); - message.graphs = []; - for (var i = 0; i < object.graphs.length; ++i) { - if (typeof object.graphs[i] !== "object") - throw TypeError(".onnx.AttributeProto.graphs: object expected"); - message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); - } - } - if (object.sparseTensors) { - if (!Array.isArray(object.sparseTensors)) - throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); - message.sparseTensors = []; - for (var i = 0; i < object.sparseTensors.length; ++i) { - if (typeof object.sparseTensors[i] !== "object") - throw TypeError(".onnx.AttributeProto.sparseTensors: object expected"); - message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); - } - } - if (object.typeProtos) { - if (!Array.isArray(object.typeProtos)) - throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); - message.typeProtos = []; - for (var i = 0; i < object.typeProtos.length; ++i) { - if (typeof object.typeProtos[i] !== "object") - throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); - message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); - } - } - return message; - }; - - /** - * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.AttributeProto - * @static - * @param {onnx.AttributeProto} message AttributeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - AttributeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.floats = []; - object.ints = []; - object.strings = []; - object.tensors = []; - object.graphs = []; - object.typeProtos = []; - object.sparseTensors = []; - } - if (options.defaults) { - object.name = ""; - object.f = 0; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.i = options.longs === String ? 
"0" : 0; - if (options.bytes === String) - object.s = ""; - else { - object.s = []; - if (options.bytes !== Array) - object.s = $util.newBuffer(object.s); - } - object.t = null; - object.g = null; - object.docString = ""; - object.tp = null; - object.type = options.enums === String ? "UNDEFINED" : 0; - object.refAttrName = ""; - object.sparseTensor = null; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.f != null && message.hasOwnProperty("f")) - object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; - if (message.i != null && message.hasOwnProperty("i")) - if (typeof message.i === "number") - object.i = options.longs === String ? String(message.i) : message.i; - else - object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; - if (message.s != null && message.hasOwnProperty("s")) - object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; - if (message.t != null && message.hasOwnProperty("t")) - object.t = $root.onnx.TensorProto.toObject(message.t, options); - if (message.g != null && message.hasOwnProperty("g")) - object.g = $root.onnx.GraphProto.toObject(message.g, options); - if (message.floats && message.floats.length) { - object.floats = []; - for (var j = 0; j < message.floats.length; ++j) - object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; - } - if (message.ints && message.ints.length) { - object.ints = []; - for (var j = 0; j < message.ints.length; ++j) - if (typeof message.ints[j] === "number") - object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; - else - object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; - } - if (message.strings && message.strings.length) { - object.strings = []; - for (var j = 0; j < message.strings.length; ++j) - object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; - } - if (message.tensors && message.tensors.length) { - object.tensors = []; - for (var j = 0; j < message.tensors.length; ++j) - object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); - } - if (message.graphs && message.graphs.length) { - object.graphs = []; - for (var j = 0; j < message.graphs.length; ++j) - object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.tp != null && message.hasOwnProperty("tp")) - object.tp = $root.onnx.TypeProto.toObject(message.tp, options); - if (message.typeProtos && message.typeProtos.length) { - object.typeProtos = []; - for (var j = 0; j < message.typeProtos.length; ++j) - object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); - } - if (message.type != null && message.hasOwnProperty("type")) - object.type = options.enums === String ? 
$root.onnx.AttributeProto.AttributeType[message.type] === undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; - if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) - object.refAttrName = message.refAttrName; - if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) - object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); - if (message.sparseTensors && message.sparseTensors.length) { - object.sparseTensors = []; - for (var j = 0; j < message.sparseTensors.length; ++j) - object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); - } - return object; - }; - - /** - * Converts this AttributeProto to JSON. - * @function toJSON - * @memberof onnx.AttributeProto - * @instance - * @returns {Object.} JSON object - */ - AttributeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for AttributeProto - * @function getTypeUrl - * @memberof onnx.AttributeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.AttributeProto"; - }; - - /** - * AttributeType enum. - * @name onnx.AttributeProto.AttributeType - * @enum {number} - * @property {number} UNDEFINED=0 UNDEFINED value - * @property {number} FLOAT=1 FLOAT value - * @property {number} INT=2 INT value - * @property {number} STRING=3 STRING value - * @property {number} TENSOR=4 TENSOR value - * @property {number} GRAPH=5 GRAPH value - * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value - * @property {number} TYPE_PROTO=13 TYPE_PROTO value - * @property {number} FLOATS=6 FLOATS value - * @property {number} INTS=7 INTS value - * @property {number} STRINGS=8 STRINGS value - * @property {number} TENSORS=9 TENSORS value - * @property {number} GRAPHS=10 GRAPHS value - * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value - * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value - */ - AttributeProto.AttributeType = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "UNDEFINED"] = 0; - values[valuesById[1] = "FLOAT"] = 1; - values[valuesById[2] = "INT"] = 2; - values[valuesById[3] = "STRING"] = 3; - values[valuesById[4] = "TENSOR"] = 4; - values[valuesById[5] = "GRAPH"] = 5; - values[valuesById[11] = "SPARSE_TENSOR"] = 11; - values[valuesById[13] = "TYPE_PROTO"] = 13; - values[valuesById[6] = "FLOATS"] = 6; - values[valuesById[7] = "INTS"] = 7; - values[valuesById[8] = "STRINGS"] = 8; - values[valuesById[9] = "TENSORS"] = 9; - values[valuesById[10] = "GRAPHS"] = 10; - values[valuesById[12] = "SPARSE_TENSORS"] = 12; - values[valuesById[14] = "TYPE_PROTOS"] = 14; - return values; - })(); - - return AttributeProto; + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. 
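(For orientation while reviewing this regenerated file: a minimal, hypothetical round-trip through the AttributeProto static API whose encode/decode/fromObject/toObject bodies appear above. The field names and the AttributeType enum values are taken from the generated code; the require path and the way the exported namespace is destructured are assumptions about how this module is loaded, not something this patch establishes.)

    // Hypothetical usage sketch -- the module path is an assumption.
    const onnx = require('./onnx').onnx;

    // 'name', 'type', and 'f' are AttributeProto fields; FLOAT = 1 per the AttributeType enum.
    const attr = onnx.AttributeProto.create({
      name: 'alpha',
      type: onnx.AttributeProto.AttributeType.FLOAT,
      f: 0.5,
    });

    // encode() returns a protobufjs Writer; finish() yields the serialized Uint8Array.
    const bytes = onnx.AttributeProto.encode(attr).finish();

    // decode() accepts a Reader or a Uint8Array and rebuilds the message.
    const decoded = onnx.AttributeProto.decode(bytes);
    console.log(decoded.name, decoded.f); // 'alpha' 0.5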
+ * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. + * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.begin != null && message.hasOwnProperty('begin')) + if ( + !$util.isInteger(message.begin) && + !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high)) + ) + return 'begin: integer|Long expected'; + if (message.end != null && message.hasOwnProperty('end')) + if ( + !$util.isInteger(message.end) && + !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high)) + ) + return 'end: integer|Long expected'; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === 'string') message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === 'number') message.begin = object.begin; + else if (typeof object.begin === 'object') + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === 'string') message.end = parseInt(object.end, 10); + else if (typeof object.end === 'number') message.end = object.end; + else if (typeof object.end === 'object') + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.begin = options.longs === String ? '0' : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? 
long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.end = options.longs === String ? '0' : 0; + } + if (message.begin != null && message.hasOwnProperty('begin')) + if (typeof message.begin === 'number') + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = + options.longs === String + ? $util.Long.prototype.toString.call(message.begin) + : options.longs === Number + ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() + : message.begin; + if (message.end != null && message.hasOwnProperty('end')) + if (typeof message.end === 'number') + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = + options.longs === String + ? $util.Long.prototype.toString.call(message.end) + : options.longs === Number + ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() + : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. + * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorProto.Segment'; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = 'DEFAULT')] = 0; + values[(valuesById[1] = 'EXTERNAL')] = 1; + return values; })(); - onnx.ValueInfoProto = (function() { - - /** - * Properties of a ValueInfoProto. - * @memberof onnx - * @interface IValueInfoProto - * @property {string|null} [name] ValueInfoProto name - * @property {onnx.ITypeProto|null} [type] ValueInfoProto type - * @property {string|null} [docString] ValueInfoProto docString - */ - - /** - * Constructs a new ValueInfoProto. - * @memberof onnx - * @classdesc Represents a ValueInfoProto. - * @implements IValueInfoProto - * @constructor - * @param {onnx.IValueInfoProto=} [properties] Properties to set - */ - function ValueInfoProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + return TensorProto; + })(); + + onnx.SparseTensorProto = (function () { + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. 
+ * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. + * @member {Array.} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, 'values')) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, 'indices')) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/ 26).fork(); + for (var i = 0; i < message.dims.length; ++i) writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) message.dims.push(reader.int64()); + } else message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * ValueInfoProto name. - * @member {string} name - * @memberof onnx.ValueInfoProto - * @instance - */ - ValueInfoProto.prototype.name = ""; - - /** - * ValueInfoProto type. - * @member {onnx.ITypeProto|null|undefined} type - * @memberof onnx.ValueInfoProto - * @instance - */ - ValueInfoProto.prototype.type = null; - - /** - * ValueInfoProto docString. - * @member {string} docString - * @memberof onnx.ValueInfoProto - * @instance - */ - ValueInfoProto.prototype.docString = ""; - - /** - * Creates a new ValueInfoProto instance using the specified properties. - * @function create - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.IValueInfoProto=} [properties] Properties to set - * @returns {onnx.ValueInfoProto} ValueInfoProto instance - */ - ValueInfoProto.create = function create(properties) { - return new ValueInfoProto(properties); - }; - - /** - * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. - * @function encode - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ValueInfoProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); - if (message.type != null && Object.hasOwnProperty.call(message, "type")) - $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); - return writer; - }; - - /** - * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a ValueInfoProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.ValueInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.ValueInfoProto} ValueInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ValueInfoProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.name = reader.string(); - break; - } - case 2: { - message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - case 3: { - message.docString = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.ValueInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.ValueInfoProto} ValueInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a ValueInfoProto message. - * @function verify - * @memberof onnx.ValueInfoProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - ValueInfoProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.type != null && message.hasOwnProperty("type")) { - var error = $root.onnx.TypeProto.verify(message.type); - if (error) - return "type." + error; - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - return null; - }; - - /** - * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.ValueInfoProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.ValueInfoProto} ValueInfoProto - */ - ValueInfoProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.ValueInfoProto) - return object; - var message = new $root.onnx.ValueInfoProto(); - if (object.name != null) - message.name = String(object.name); - if (object.type != null) { - if (typeof object.type !== "object") - throw TypeError(".onnx.ValueInfoProto.type: object expected"); - message.type = $root.onnx.TypeProto.fromObject(object.type); - } - if (object.docString != null) - message.docString = String(object.docString); - return message; - }; - - /** - * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.ValueInfoProto - * @static - * @param {onnx.ValueInfoProto} message ValueInfoProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - ValueInfoProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.name = ""; - object.type = null; - object.docString = ""; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.type != null && message.hasOwnProperty("type")) - object.type = $root.onnx.TypeProto.toObject(message.type, options); - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - return object; - }; - - /** - * Converts this ValueInfoProto to JSON. - * @function toJSON - * @memberof onnx.ValueInfoProto - * @instance - * @returns {Object.} JSON object - */ - ValueInfoProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for ValueInfoProto - * @function getTypeUrl - * @memberof onnx.ValueInfoProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.ValueInfoProto"; - }; + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return ValueInfoProto; - })(); + /** + * Verifies a SparseTensorProto message. 
+ * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.values != null && message.hasOwnProperty('values')) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) return 'values.' + error; + } + if (message.indices != null && message.hasOwnProperty('indices')) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) return 'indices.' + error; + } + if (message.dims != null && message.hasOwnProperty('dims')) { + if (!Array.isArray(message.dims)) return 'dims: array expected'; + for (var i = 0; i < message.dims.length; ++i) + if ( + !$util.isInteger(message.dims[i]) && + !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high)) + ) + return 'dims: integer|Long[] expected'; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== 'object') throw TypeError('.onnx.SparseTensorProto.values: object expected'); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== 'object') throw TypeError('.onnx.SparseTensorProto.indices: object expected'); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) throw TypeError('.onnx.SparseTensorProto.dims: array expected'); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === 'string') message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === 'number') message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === 'object') + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty('values')) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty('indices')) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === 'number') + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = + options.longs === String + ? $util.Long.prototype.toString.call(message.dims[j]) + : options.longs === Number + ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() + : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. + * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.SparseTensorProto'; + }; - onnx.NodeProto = (function() { - - /** - * Properties of a NodeProto. - * @memberof onnx - * @interface INodeProto - * @property {Array.|null} [input] NodeProto input - * @property {Array.|null} [output] NodeProto output - * @property {string|null} [name] NodeProto name - * @property {string|null} [opType] NodeProto opType - * @property {string|null} [domain] NodeProto domain - * @property {Array.|null} [attribute] NodeProto attribute - * @property {string|null} [docString] NodeProto docString - */ - - /** - * Constructs a new NodeProto. - * @memberof onnx - * @classdesc Represents a NodeProto. - * @implements INodeProto - * @constructor - * @param {onnx.INodeProto=} [properties] Properties to set - */ - function NodeProto(properties) { - this.input = []; - this.output = []; - this.attribute = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function () { + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. 
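(A similarly hedged sketch for the SparseTensorProto helpers that end just above: fromObject accepts int64 dims as decimal strings, numbers, or Long-like objects, and toObject's conversion options control how longs, enums, and bytes fields are rendered. The option names are the standard protobufjs conversion options referenced throughout the generated toObject code; the surrounding variable names, and the `onnx` namespace object loaded as in the earlier sketch, are illustrative only.)

    // Hypothetical: build a SparseTensorProto from a plain object, then convert it back
    // to a JSON-friendly shape using the options handled by the generated toObject().
    const sparse = onnx.SparseTensorProto.fromObject({
      dims: ['3', '4'], // int64 values may be given as decimal strings
      values: {},       // nested TensorProto, passed through TensorProto.fromObject
      indices: {},      // nested TensorProto, passed through TensorProto.fromObject
    });

    const plain = onnx.SparseTensorProto.toObject(sparse, {
      longs: String,   // render int64 fields (e.g. dims) as decimal strings
      enums: String,   // render enum fields by name rather than number
      bytes: String,   // render bytes fields as base64
      defaults: true,  // include fields that still hold their default value
    });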
+ * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode( + message.dim[i], + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * NodeProto input. - * @member {Array.} input - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.input = $util.emptyArray; - - /** - * NodeProto output. 
- * @member {Array.} output - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.output = $util.emptyArray; - - /** - * NodeProto name. - * @member {string} name - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.name = ""; - - /** - * NodeProto opType. - * @member {string} opType - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.opType = ""; - - /** - * NodeProto domain. - * @member {string} domain - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.domain = ""; - - /** - * NodeProto attribute. - * @member {Array.} attribute - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.attribute = $util.emptyArray; - - /** - * NodeProto docString. - * @member {string} docString - * @memberof onnx.NodeProto - * @instance - */ - NodeProto.prototype.docString = ""; - - /** - * Creates a new NodeProto instance using the specified properties. - * @function create - * @memberof onnx.NodeProto - * @static - * @param {onnx.INodeProto=} [properties] Properties to set - * @returns {onnx.NodeProto} NodeProto instance - */ - NodeProto.create = function create(properties) { - return new NodeProto(properties); - }; - - /** - * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. - * @function encode - * @memberof onnx.NodeProto - * @static - * @param {onnx.INodeProto} message NodeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - NodeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.input != null && message.input.length) - for (var i = 0; i < message.input.length; ++i) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); - if (message.output != null && message.output.length) - for (var i = 0; i < message.output.length; ++i) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); - if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) - writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); - if (message.attribute != null && message.attribute.length) - for (var i = 0; i < message.attribute.length; ++i) - $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); - return writer; - }; - - /** - * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.NodeProto - * @static - * @param {onnx.INodeProto} message NodeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - NodeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a NodeProto message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.NodeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.NodeProto} NodeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - NodeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.input && message.input.length)) - message.input = []; - message.input.push(reader.string()); - break; - } - case 2: { - if (!(message.output && message.output.length)) - message.output = []; - message.output.push(reader.string()); - break; - } - case 3: { - message.name = reader.string(); - break; - } - case 4: { - message.opType = reader.string(); - break; - } - case 7: { - message.domain = reader.string(); - break; - } - case 5: { - if (!(message.attribute && message.attribute.length)) - message.attribute = []; - message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); - break; - } - case 6: { - message.docString = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a NodeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.NodeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.NodeProto} NodeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - NodeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a NodeProto message. 
- * @function verify - * @memberof onnx.NodeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - NodeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.input != null && message.hasOwnProperty("input")) { - if (!Array.isArray(message.input)) - return "input: array expected"; - for (var i = 0; i < message.input.length; ++i) - if (!$util.isString(message.input[i])) - return "input: string[] expected"; - } - if (message.output != null && message.hasOwnProperty("output")) { - if (!Array.isArray(message.output)) - return "output: array expected"; - for (var i = 0; i < message.output.length; ++i) - if (!$util.isString(message.output[i])) - return "output: string[] expected"; - } - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.opType != null && message.hasOwnProperty("opType")) - if (!$util.isString(message.opType)) - return "opType: string expected"; - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - if (message.attribute != null && message.hasOwnProperty("attribute")) { - if (!Array.isArray(message.attribute)) - return "attribute: array expected"; - for (var i = 0; i < message.attribute.length; ++i) { - var error = $root.onnx.AttributeProto.verify(message.attribute[i]); - if (error) - return "attribute." + error; - } - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - return null; - }; - - /** - * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.NodeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.NodeProto} NodeProto - */ - NodeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.NodeProto) - return object; - var message = new $root.onnx.NodeProto(); - if (object.input) { - if (!Array.isArray(object.input)) - throw TypeError(".onnx.NodeProto.input: array expected"); - message.input = []; - for (var i = 0; i < object.input.length; ++i) - message.input[i] = String(object.input[i]); - } - if (object.output) { - if (!Array.isArray(object.output)) - throw TypeError(".onnx.NodeProto.output: array expected"); - message.output = []; - for (var i = 0; i < object.output.length; ++i) - message.output[i] = String(object.output[i]); - } - if (object.name != null) - message.name = String(object.name); - if (object.opType != null) - message.opType = String(object.opType); - if (object.domain != null) - message.domain = String(object.domain); - if (object.attribute) { - if (!Array.isArray(object.attribute)) - throw TypeError(".onnx.NodeProto.attribute: array expected"); - message.attribute = []; - for (var i = 0; i < object.attribute.length; ++i) { - if (typeof object.attribute[i] !== "object") - throw TypeError(".onnx.NodeProto.attribute: object expected"); - message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); - } - } - if (object.docString != null) - message.docString = String(object.docString); - return message; - }; - - /** - * Creates a plain object from a NodeProto message. Also converts values to other types if specified. 
- * @function toObject - * @memberof onnx.NodeProto - * @static - * @param {onnx.NodeProto} message NodeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - NodeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.input = []; - object.output = []; - object.attribute = []; - } - if (options.defaults) { - object.name = ""; - object.opType = ""; - object.docString = ""; - object.domain = ""; - } - if (message.input && message.input.length) { - object.input = []; - for (var j = 0; j < message.input.length; ++j) - object.input[j] = message.input[j]; - } - if (message.output && message.output.length) { - object.output = []; - for (var j = 0; j < message.output.length; ++j) - object.output[j] = message.output[j]; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.opType != null && message.hasOwnProperty("opType")) - object.opType = message.opType; - if (message.attribute && message.attribute.length) { - object.attribute = []; - for (var j = 0; j < message.attribute.length; ++j) - object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - return object; - }; - - /** - * Converts this NodeProto to JSON. - * @function toJSON - * @memberof onnx.NodeProto - * @instance - * @returns {Object.} JSON object - */ - NodeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for NodeProto - * @function getTypeUrl - * @memberof onnx.NodeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.NodeProto"; - }; + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return NodeProto; - })(); + /** + * Verifies a TensorShapeProto message. 
+ * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.dim != null && message.hasOwnProperty('dim')) { + if (!Array.isArray(message.dim)) return 'dim: array expected'; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) return 'dim.' + error; + } + } + return null; + }; - onnx.TrainingInfoProto = (function() { - - /** - * Properties of a TrainingInfoProto. - * @memberof onnx - * @interface ITrainingInfoProto - * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization - * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm - * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding - * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding - */ - - /** - * Constructs a new TrainingInfoProto. - * @memberof onnx - * @classdesc Represents a TrainingInfoProto. - * @implements ITrainingInfoProto - * @constructor - * @param {onnx.ITrainingInfoProto=} [properties] Properties to set - */ - function TrainingInfoProto(properties) { - this.initializationBinding = []; - this.updateBinding = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) throw TypeError('.onnx.TensorShapeProto.dim: array expected'); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== 'object') throw TypeError('.onnx.TensorShapeProto.dim: object expected'); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); } + } + return message; + }; - /** - * TrainingInfoProto initialization. - * @member {onnx.IGraphProto|null|undefined} initialization - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.initialization = null; - - /** - * TrainingInfoProto algorithm. - * @member {onnx.IGraphProto|null|undefined} algorithm - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.algorithm = null; - - /** - * TrainingInfoProto initializationBinding. - * @member {Array.} initializationBinding - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; - - /** - * TrainingInfoProto updateBinding. - * @member {Array.} updateBinding - * @memberof onnx.TrainingInfoProto - * @instance - */ - TrainingInfoProto.prototype.updateBinding = $util.emptyArray; - - /** - * Creates a new TrainingInfoProto instance using the specified properties. 
- * @function create - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.ITrainingInfoProto=} [properties] Properties to set - * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance - */ - TrainingInfoProto.create = function create(properties) { - return new TrainingInfoProto(properties); - }; - - /** - * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. - * @function encode - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TrainingInfoProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) - $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) - $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - if (message.initializationBinding != null && message.initializationBinding.length) - for (var i = 0; i < message.initializationBinding.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); - if (message.updateBinding != null && message.updateBinding.length) - for (var i = 0; i < message.updateBinding.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TrainingInfoProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TrainingInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TrainingInfoProto} TrainingInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TrainingInfoProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 2: { - message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 3: { - if (!(message.initializationBinding && message.initializationBinding.length)) - message.initializationBinding = []; - message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - case 4: { - if (!(message.updateBinding && message.updateBinding.length)) - message.updateBinding = []; - message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TrainingInfoProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TrainingInfoProto} TrainingInfoProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TrainingInfoProto message. - * @function verify - * @memberof onnx.TrainingInfoProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TrainingInfoProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.initialization != null && message.hasOwnProperty("initialization")) { - var error = $root.onnx.GraphProto.verify(message.initialization); - if (error) - return "initialization." + error; - } - if (message.algorithm != null && message.hasOwnProperty("algorithm")) { - var error = $root.onnx.GraphProto.verify(message.algorithm); - if (error) - return "algorithm." + error; - } - if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { - if (!Array.isArray(message.initializationBinding)) - return "initializationBinding: array expected"; - for (var i = 0; i < message.initializationBinding.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); - if (error) - return "initializationBinding." + error; - } - } - if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { - if (!Array.isArray(message.updateBinding)) - return "updateBinding: array expected"; - for (var i = 0; i < message.updateBinding.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); - if (error) - return "updateBinding." + error; - } - } - return null; - }; - - /** - * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.TrainingInfoProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TrainingInfoProto} TrainingInfoProto - */ - TrainingInfoProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TrainingInfoProto) - return object; - var message = new $root.onnx.TrainingInfoProto(); - if (object.initialization != null) { - if (typeof object.initialization !== "object") - throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); - message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); - } - if (object.algorithm != null) { - if (typeof object.algorithm !== "object") - throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); - message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); - } - if (object.initializationBinding) { - if (!Array.isArray(object.initializationBinding)) - throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); - message.initializationBinding = []; - for (var i = 0; i < object.initializationBinding.length; ++i) { - if (typeof object.initializationBinding[i] !== "object") - throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); - message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); - } - } - if (object.updateBinding) { - if (!Array.isArray(object.updateBinding)) - throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); - message.updateBinding = []; - for (var i = 0; i < object.updateBinding.length; ++i) { - if (typeof object.updateBinding[i] !== "object") - throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); - message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TrainingInfoProto - * @static - * @param {onnx.TrainingInfoProto} message TrainingInfoProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TrainingInfoProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.initializationBinding = []; - object.updateBinding = []; - } - if (options.defaults) { - object.initialization = null; - object.algorithm = null; - } - if (message.initialization != null && message.hasOwnProperty("initialization")) - object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); - if (message.algorithm != null && message.hasOwnProperty("algorithm")) - object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); - if (message.initializationBinding && message.initializationBinding.length) { - object.initializationBinding = []; - for (var j = 0; j < message.initializationBinding.length; ++j) - object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); - } - if (message.updateBinding && message.updateBinding.length) { - object.updateBinding = []; - for (var j = 0; j < message.updateBinding.length; ++j) - object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); - } - return object; - }; - - /** - * Converts this TrainingInfoProto to JSON. 
- * @function toJSON - * @memberof onnx.TrainingInfoProto - * @instance - * @returns {Object.} JSON object - */ - TrainingInfoProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TrainingInfoProto - * @function getTypeUrl - * @memberof onnx.TrainingInfoProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorShapeProto'; + }; + + TensorShapeProto.Dimension = (function () { + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. + * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. 
+ * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ''; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, 'value', { + get: $util.oneOfGetter(($oneOfFields = ['dimValue', 'dimParam'])), + set: $util.oneOfSetter($oneOfFields), + }); + + /** + * Creates a new Dimension instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, 'dimValue')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, 'dimParam')) + writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, 'denotation')) + writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; } - return typeUrlPrefix + "/onnx.TrainingInfoProto"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. + * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty('dimValue')) { + properties.value = 1; + if ( + !$util.isInteger(message.dimValue) && + !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high)) + ) + return 'dimValue: integer|Long expected'; + } + if (message.dimParam != null && message.hasOwnProperty('dimParam')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + if (!$util.isString(message.dimParam)) return 'dimParam: string expected'; + } + if (message.denotation != null && message.hasOwnProperty('denotation')) + if (!$util.isString(message.denotation)) return 'denotation: string expected'; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === 'string') message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === 'number') message.dimValue = object.dimValue; + else if (typeof object.dimValue === 'object') + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) message.dimParam = String(object.dimParam); + if (object.denotation != null) message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.denotation = ''; + if (message.dimValue != null && message.hasOwnProperty('dimValue')) { + if (typeof message.dimValue === 'number') + object.dimValue = options.longs === String ? String(message.dimValue) : message.dimValue; + else + object.dimValue = + options.longs === String + ? $util.Long.prototype.toString.call(message.dimValue) + : options.longs === Number + ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() + : message.dimValue; + if (options.oneofs) object.value = 'dimValue'; + } + if (message.dimParam != null && message.hasOwnProperty('dimParam')) { + object.dimParam = message.dimParam; + if (options.oneofs) object.value = 'dimParam'; + } + if (message.denotation != null && message.hasOwnProperty('denotation')) object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TensorShapeProto.Dimension'; + }; - return TrainingInfoProto; + return Dimension; })(); - onnx.ModelProto = (function() { - - /** - * Properties of a ModelProto. - * @memberof onnx - * @interface IModelProto - * @property {number|Long|null} [irVersion] ModelProto irVersion - * @property {Array.|null} [opsetImport] ModelProto opsetImport - * @property {string|null} [producerName] ModelProto producerName - * @property {string|null} [producerVersion] ModelProto producerVersion - * @property {string|null} [domain] ModelProto domain - * @property {number|Long|null} [modelVersion] ModelProto modelVersion - * @property {string|null} [docString] ModelProto docString - * @property {onnx.IGraphProto|null} [graph] ModelProto graph - * @property {Array.|null} [metadataProps] ModelProto metadataProps - * @property {Array.|null} [trainingInfo] ModelProto trainingInfo - * @property {Array.|null} [functions] ModelProto functions - */ - - /** - * Constructs a new ModelProto. - * @memberof onnx - * @classdesc Represents a ModelProto. - * @implements IModelProto - * @constructor - * @param {onnx.IModelProto=} [properties] Properties to set - */ - function ModelProto(properties) { - this.opsetImport = []; - this.metadataProps = []; - this.trainingInfo = []; - this.functions = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + return TensorShapeProto; + })(); + + onnx.TypeProto = (function () { + /** + * Properties of a TypeProto. 
+ * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. + * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. + * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ''; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. + * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, 'value', { + get: $util.oneOfGetter( + ($oneOfFields = ['tensorType', 'sequenceType', 'mapType', 'optionalType', 'sparseTensorType']), + ), + set: $util.oneOfSetter($oneOfFields), + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, 'tensorType')) + $root.onnx.TypeProto.Tensor.encode( + message.tensorType, + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, 'sequenceType')) + $root.onnx.TypeProto.Sequence.encode( + message.sequenceType, + writer.uint32(/* id 4, wireType 2 =*/ 34).fork(), + ).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, 'mapType')) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/ 42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, 'denotation')) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.denotation); + if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, 'sparseTensorType')) + $root.onnx.TypeProto.SparseTensor.encode( + message.sparseTensorType, + writer.uint32(/* id 8, wireType 2 =*/ 66).fork(), + ).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, 'optionalType')) + $root.onnx.TypeProto.Optional.encode( + message.optionalType, + writer.uint32(/* id 9, wireType 2 =*/ 74).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * ModelProto irVersion. - * @member {number|Long} irVersion - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * ModelProto opsetImport. - * @member {Array.} opsetImport - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.opsetImport = $util.emptyArray; - - /** - * ModelProto producerName. - * @member {string} producerName - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.producerName = ""; - - /** - * ModelProto producerVersion. - * @member {string} producerVersion - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.producerVersion = ""; - - /** - * ModelProto domain. - * @member {string} domain - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.domain = ""; - - /** - * ModelProto modelVersion. - * @member {number|Long} modelVersion - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * ModelProto docString. - * @member {string} docString - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.docString = ""; - - /** - * ModelProto graph. - * @member {onnx.IGraphProto|null|undefined} graph - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.graph = null; - - /** - * ModelProto metadataProps. - * @member {Array.} metadataProps - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.metadataProps = $util.emptyArray; - - /** - * ModelProto trainingInfo. - * @member {Array.} trainingInfo - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.trainingInfo = $util.emptyArray; - - /** - * ModelProto functions. - * @member {Array.} functions - * @memberof onnx.ModelProto - * @instance - */ - ModelProto.prototype.functions = $util.emptyArray; - - /** - * Creates a new ModelProto instance using the specified properties. - * @function create - * @memberof onnx.ModelProto - * @static - * @param {onnx.IModelProto=} [properties] Properties to set - * @returns {onnx.ModelProto} ModelProto instance - */ - ModelProto.create = function create(properties) { - return new ModelProto(properties); - }; - - /** - * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.ModelProto - * @static - * @param {onnx.IModelProto} message ModelProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ModelProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) - writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); - if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); - if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); - if (message.modelVersion != null && Object.hasOwnProperty.call(message, "modelVersion")) - writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); - if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) - $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); - if (message.opsetImport != null && message.opsetImport.length) - for (var i = 0; i < message.opsetImport.length; ++i) - $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); - if (message.metadataProps != null && message.metadataProps.length) - for (var i = 0; i < message.metadataProps.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); - if (message.trainingInfo != null && message.trainingInfo.length) - for (var i = 0; i < message.trainingInfo.length; ++i) - $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); - if (message.functions != null && message.functions.length) - for (var i = 0; i < message.functions.length; ++i) - $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* id 25, wireType 2 =*/202).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.ModelProto - * @static - * @param {onnx.IModelProto} message ModelProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - ModelProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a ModelProto message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.ModelProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.ModelProto} ModelProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ModelProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.irVersion = reader.int64(); - break; - } - case 8: { - if (!(message.opsetImport && message.opsetImport.length)) - message.opsetImport = []; - message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); - break; - } - case 2: { - message.producerName = reader.string(); - break; - } - case 3: { - message.producerVersion = reader.string(); - break; - } - case 4: { - message.domain = reader.string(); - break; - } - case 5: { - message.modelVersion = reader.int64(); - break; - } - case 6: { - message.docString = reader.string(); - break; - } - case 7: { - message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); - break; - } - case 14: { - if (!(message.metadataProps && message.metadataProps.length)) - message.metadataProps = []; - message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - case 20: { - if (!(message.trainingInfo && message.trainingInfo.length)) - message.trainingInfo = []; - message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); - break; - } - case 25: { - if (!(message.functions && message.functions.length)) - message.functions = []; - message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a ModelProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.ModelProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.ModelProto} ModelProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - ModelProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a ModelProto message. 
- * @function verify - * @memberof onnx.ModelProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - ModelProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.irVersion != null && message.hasOwnProperty("irVersion")) - if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) - return "irVersion: integer|Long expected"; - if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { - if (!Array.isArray(message.opsetImport)) - return "opsetImport: array expected"; - for (var i = 0; i < message.opsetImport.length; ++i) { - var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); - if (error) - return "opsetImport." + error; - } - } - if (message.producerName != null && message.hasOwnProperty("producerName")) - if (!$util.isString(message.producerName)) - return "producerName: string expected"; - if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) - if (!$util.isString(message.producerVersion)) - return "producerVersion: string expected"; - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) - if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) - return "modelVersion: integer|Long expected"; - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.graph != null && message.hasOwnProperty("graph")) { - var error = $root.onnx.GraphProto.verify(message.graph); - if (error) - return "graph." + error; - } - if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { - if (!Array.isArray(message.metadataProps)) - return "metadataProps: array expected"; - for (var i = 0; i < message.metadataProps.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); - if (error) - return "metadataProps." + error; - } - } - if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { - if (!Array.isArray(message.trainingInfo)) - return "trainingInfo: array expected"; - for (var i = 0; i < message.trainingInfo.length; ++i) { - var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); - if (error) - return "trainingInfo." + error; - } - } - if (message.functions != null && message.hasOwnProperty("functions")) { - if (!Array.isArray(message.functions)) - return "functions: array expected"; - for (var i = 0; i < message.functions.length; ++i) { - var error = $root.onnx.FunctionProto.verify(message.functions[i]); - if (error) - return "functions." + error; - } - } - return null; - }; - - /** - * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.ModelProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.ModelProto} ModelProto - */ - ModelProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.ModelProto) - return object; - var message = new $root.onnx.ModelProto(); - if (object.irVersion != null) - if ($util.Long) - (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; - else if (typeof object.irVersion === "string") - message.irVersion = parseInt(object.irVersion, 10); - else if (typeof object.irVersion === "number") - message.irVersion = object.irVersion; - else if (typeof object.irVersion === "object") - message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); - if (object.opsetImport) { - if (!Array.isArray(object.opsetImport)) - throw TypeError(".onnx.ModelProto.opsetImport: array expected"); - message.opsetImport = []; - for (var i = 0; i < object.opsetImport.length; ++i) { - if (typeof object.opsetImport[i] !== "object") - throw TypeError(".onnx.ModelProto.opsetImport: object expected"); - message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); - } - } - if (object.producerName != null) - message.producerName = String(object.producerName); - if (object.producerVersion != null) - message.producerVersion = String(object.producerVersion); - if (object.domain != null) - message.domain = String(object.domain); - if (object.modelVersion != null) - if ($util.Long) - (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; - else if (typeof object.modelVersion === "string") - message.modelVersion = parseInt(object.modelVersion, 10); - else if (typeof object.modelVersion === "number") - message.modelVersion = object.modelVersion; - else if (typeof object.modelVersion === "object") - message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); - if (object.docString != null) - message.docString = String(object.docString); - if (object.graph != null) { - if (typeof object.graph !== "object") - throw TypeError(".onnx.ModelProto.graph: object expected"); - message.graph = $root.onnx.GraphProto.fromObject(object.graph); - } - if (object.metadataProps) { - if (!Array.isArray(object.metadataProps)) - throw TypeError(".onnx.ModelProto.metadataProps: array expected"); - message.metadataProps = []; - for (var i = 0; i < object.metadataProps.length; ++i) { - if (typeof object.metadataProps[i] !== "object") - throw TypeError(".onnx.ModelProto.metadataProps: object expected"); - message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); - } - } - if (object.trainingInfo) { - if (!Array.isArray(object.trainingInfo)) - throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); - message.trainingInfo = []; - for (var i = 0; i < object.trainingInfo.length; ++i) { - if (typeof object.trainingInfo[i] !== "object") - throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); - message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); - } - } - if (object.functions) { - if (!Array.isArray(object.functions)) - throw TypeError(".onnx.ModelProto.functions: array expected"); - message.functions = []; - for (var i = 0; i < object.functions.length; ++i) { - if (typeof object.functions[i] !== "object") - throw TypeError(".onnx.ModelProto.functions: object expected"); - message.functions[i] 
= $root.onnx.FunctionProto.fromObject(object.functions[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a ModelProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.ModelProto - * @static - * @param {onnx.ModelProto} message ModelProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - ModelProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.opsetImport = []; - object.metadataProps = []; - object.trainingInfo = []; - object.functions = []; - } - if (options.defaults) { - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.irVersion = options.longs === String ? "0" : 0; - object.producerName = ""; - object.producerVersion = ""; - object.domain = ""; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.modelVersion = options.longs === String ? "0" : 0; - object.docString = ""; - object.graph = null; - } - if (message.irVersion != null && message.hasOwnProperty("irVersion")) - if (typeof message.irVersion === "number") - object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; - else - object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; - if (message.producerName != null && message.hasOwnProperty("producerName")) - object.producerName = message.producerName; - if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) - object.producerVersion = message.producerVersion; - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) - if (typeof message.modelVersion === "number") - object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; - else - object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? 
new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.graph != null && message.hasOwnProperty("graph")) - object.graph = $root.onnx.GraphProto.toObject(message.graph, options); - if (message.opsetImport && message.opsetImport.length) { - object.opsetImport = []; - for (var j = 0; j < message.opsetImport.length; ++j) - object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); - } - if (message.metadataProps && message.metadataProps.length) { - object.metadataProps = []; - for (var j = 0; j < message.metadataProps.length; ++j) - object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); - } - if (message.trainingInfo && message.trainingInfo.length) { - object.trainingInfo = []; - for (var j = 0; j < message.trainingInfo.length; ++j) - object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); - } - if (message.functions && message.functions.length) { - object.functions = []; - for (var j = 0; j < message.functions.length; ++j) - object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. + * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty('tensorType')) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) return 'tensorType.' + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty('sequenceType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) return 'sequenceType.' + error; + } + } + if (message.mapType != null && message.hasOwnProperty('mapType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) return 'mapType.' + error; + } + } + if (message.optionalType != null && message.hasOwnProperty('optionalType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) return 'optionalType.' 
+ error; + } + } + if (message.sparseTensorType != null && message.hasOwnProperty('sparseTensorType')) { + if (properties.value === 1) return 'value: multiple values'; + properties.value = 1; + { + var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); + if (error) return 'sparseTensorType.' + error; + } + } + if (message.denotation != null && message.hasOwnProperty('denotation')) + if (!$util.isString(message.denotation)) return 'denotation: string expected'; + return null; + }; + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto} TypeProto + */ + TypeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto) return object; + var message = new $root.onnx.TypeProto(); + if (object.tensorType != null) { + if (typeof object.tensorType !== 'object') throw TypeError('.onnx.TypeProto.tensorType: object expected'); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + } + if (object.sequenceType != null) { + if (typeof object.sequenceType !== 'object') throw TypeError('.onnx.TypeProto.sequenceType: object expected'); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + } + if (object.mapType != null) { + if (typeof object.mapType !== 'object') throw TypeError('.onnx.TypeProto.mapType: object expected'); + message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); + } + if (object.optionalType != null) { + if (typeof object.optionalType !== 'object') throw TypeError('.onnx.TypeProto.optionalType: object expected'); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + } + if (object.sparseTensorType != null) { + if (typeof object.sparseTensorType !== 'object') + throw TypeError('.onnx.TypeProto.sparseTensorType: object expected'); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); + } + if (object.denotation != null) message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.denotation = ''; + if (message.tensorType != null && message.hasOwnProperty('tensorType')) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) object.value = 'tensorType'; + } + if (message.sequenceType != null && message.hasOwnProperty('sequenceType')) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) object.value = 'sequenceType'; + } + if (message.mapType != null && message.hasOwnProperty('mapType')) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) object.value = 'mapType'; + } + if (message.denotation != null && message.hasOwnProperty('denotation')) object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty('sparseTensorType')) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) object.value = 'sparseTensorType'; + } + if (message.optionalType != null && message.hasOwnProperty('optionalType')) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) object.value = 'optionalType'; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. + * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto'; + }; + + TypeProto.Tensor = (function () { + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. 
+ * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, 'shape')) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; } - return object; - }; - - /** - * Converts this ModelProto to JSON. - * @function toJSON - * @memberof onnx.ModelProto - * @instance - * @returns {Object.} JSON object - */ - ModelProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for ModelProto - * @function getTypeUrl - * @memberof onnx.ModelProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. + * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) + if (!$util.isInteger(message.elemType)) return 'elemType: integer expected'; + if (message.shape != null && message.hasOwnProperty('shape')) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) return 'shape.' + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== 'object') throw TypeError('.onnx.TypeProto.Tensor.shape: object expected'); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty('elemType')) object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty('shape')) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Tensor'; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function () { + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. + * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; } - return typeUrlPrefix + "/onnx.ModelProto"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) return 'elemType.' + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== 'object') + throw TypeError('.onnx.TypeProto.Sequence.elemType: object expected'); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.elemType = null; + if (message.elemType != null && message.hasOwnProperty('elemType')) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Sequence'; + }; - return ModelProto; + return Sequence; })(); - onnx.StringStringEntryProto = (function() { - - /** - * Properties of a StringStringEntryProto. - * @memberof onnx - * @interface IStringStringEntryProto - * @property {string|null} [key] StringStringEntryProto key - * @property {string|null} [value] StringStringEntryProto value - */ - - /** - * Constructs a new StringStringEntryProto. - * @memberof onnx - * @classdesc Represents a StringStringEntryProto. - * @implements IStringStringEntryProto - * @constructor - * @param {onnx.IStringStringEntryProto=} [properties] Properties to set - */ - function StringStringEntryProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } - - /** - * StringStringEntryProto key. - * @member {string} key - * @memberof onnx.StringStringEntryProto - * @instance - */ - StringStringEntryProto.prototype.key = ""; - - /** - * StringStringEntryProto value. - * @member {string} value - * @memberof onnx.StringStringEntryProto - * @instance - */ - StringStringEntryProto.prototype.value = ""; - - /** - * Creates a new StringStringEntryProto instance using the specified properties. - * @function create - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.IStringStringEntryProto=} [properties] Properties to set - * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance - */ - StringStringEntryProto.create = function create(properties) { - return new StringStringEntryProto(properties); - }; - - /** - * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - StringStringEntryProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.key != null && Object.hasOwnProperty.call(message, "key")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); - if (message.value != null && Object.hasOwnProperty.call(message, "value")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); - return writer; - }; - - /** - * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a StringStringEntryProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.StringStringEntryProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.StringStringEntryProto} StringStringEntryProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - StringStringEntryProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.key = reader.string(); - break; - } - case 2: { - message.value = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } + TypeProto.Map = (function () { + /** + * Properties of a Map. + * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. + * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. 
+ * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, 'keyType')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, 'valueType')) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; } - return message; - }; - - /** - * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.StringStringEntryProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.StringStringEntryProto} StringStringEntryProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a StringStringEntryProto message. 
- * @function verify - * @memberof onnx.StringStringEntryProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - StringStringEntryProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.key != null && message.hasOwnProperty("key")) - if (!$util.isString(message.key)) - return "key: string expected"; - if (message.value != null && message.hasOwnProperty("value")) - if (!$util.isString(message.value)) - return "value: string expected"; - return null; - }; - - /** - * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.StringStringEntryProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.StringStringEntryProto} StringStringEntryProto - */ - StringStringEntryProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.StringStringEntryProto) - return object; - var message = new $root.onnx.StringStringEntryProto(); - if (object.key != null) - message.key = String(object.key); - if (object.value != null) - message.value = String(object.value); - return message; - }; - - /** - * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.StringStringEntryProto - * @static - * @param {onnx.StringStringEntryProto} message StringStringEntryProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - StringStringEntryProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.key = ""; - object.value = ""; - } - if (message.key != null && message.hasOwnProperty("key")) - object.key = message.key; - if (message.value != null && message.hasOwnProperty("value")) - object.value = message.value; - return object; - }; - - /** - * Converts this StringStringEntryProto to JSON. - * @function toJSON - * @memberof onnx.StringStringEntryProto - * @instance - * @returns {Object.} JSON object - */ - StringStringEntryProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for StringStringEntryProto - * @function getTypeUrl - * @memberof onnx.StringStringEntryProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.StringStringEntryProto"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.keyType != null && message.hasOwnProperty('keyType')) + if (!$util.isInteger(message.keyType)) return 'keyType: integer expected'; + if (message.valueType != null && message.hasOwnProperty('valueType')) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) return 'valueType.' + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== 'object') throw TypeError('.onnx.TypeProto.Map.valueType: object expected'); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty('keyType')) object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty('valueType')) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Map'; + }; - return StringStringEntryProto; + return Map; })(); - onnx.TensorAnnotation = (function() { - - /** - * Properties of a TensorAnnotation. - * @memberof onnx - * @interface ITensorAnnotation - * @property {string|null} [tensorName] TensorAnnotation tensorName - * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames - */ - - /** - * Constructs a new TensorAnnotation. - * @memberof onnx - * @classdesc Represents a TensorAnnotation. - * @implements ITensorAnnotation - * @constructor - * @param {onnx.ITensorAnnotation=} [properties] Properties to set - */ - function TensorAnnotation(properties) { - this.quantParameterTensorNames = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + TypeProto.Optional = (function () { + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) return 'elemType.' + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== 'object') + throw TypeError('.onnx.TypeProto.Optional.elemType: object expected'); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) object.elemType = null; + if (message.elemType != null && message.hasOwnProperty('elemType')) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.Optional'; + }; - /** - * TensorAnnotation tensorName. - * @member {string} tensorName - * @memberof onnx.TensorAnnotation - * @instance - */ - TensorAnnotation.prototype.tensorName = ""; - - /** - * TensorAnnotation quantParameterTensorNames. - * @member {Array.} quantParameterTensorNames - * @memberof onnx.TensorAnnotation - * @instance - */ - TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; - - /** - * Creates a new TensorAnnotation instance using the specified properties. - * @function create - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.ITensorAnnotation=} [properties] Properties to set - * @returns {onnx.TensorAnnotation} TensorAnnotation instance - */ - TensorAnnotation.create = function create(properties) { - return new TensorAnnotation(properties); - }; - - /** - * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. - * @function encode - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorAnnotation.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); - if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) - for (var i = 0; i < message.quantParameterTensorNames.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TensorAnnotation message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorAnnotation - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorAnnotation} TensorAnnotation - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorAnnotation.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.tensorName = reader.string(); - break; - } - case 2: { - if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) - message.quantParameterTensorNames = []; - message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorAnnotation - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorAnnotation} TensorAnnotation - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TensorAnnotation message. - * @function verify - * @memberof onnx.TensorAnnotation - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TensorAnnotation.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.tensorName != null && message.hasOwnProperty("tensorName")) - if (!$util.isString(message.tensorName)) - return "tensorName: string expected"; - if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { - if (!Array.isArray(message.quantParameterTensorNames)) - return "quantParameterTensorNames: array expected"; - for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); - if (error) - return "quantParameterTensorNames." + error; - } - } - return null; - }; - - /** - * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.TensorAnnotation - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorAnnotation} TensorAnnotation - */ - TensorAnnotation.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorAnnotation) - return object; - var message = new $root.onnx.TensorAnnotation(); - if (object.tensorName != null) - message.tensorName = String(object.tensorName); - if (object.quantParameterTensorNames) { - if (!Array.isArray(object.quantParameterTensorNames)) - throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); - message.quantParameterTensorNames = []; - for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { - if (typeof object.quantParameterTensorNames[i] !== "object") - throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); - message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorAnnotation - * @static - * @param {onnx.TensorAnnotation} message TensorAnnotation - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TensorAnnotation.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) - object.quantParameterTensorNames = []; - if (options.defaults) - object.tensorName = ""; - if (message.tensorName != null && message.hasOwnProperty("tensorName")) - object.tensorName = message.tensorName; - if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { - object.quantParameterTensorNames = []; - for (var j = 0; j < message.quantParameterTensorNames.length; ++j) - object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); - } - return object; - }; - - /** - * Converts this TensorAnnotation to JSON. - * @function toJSON - * @memberof onnx.TensorAnnotation - * @instance - * @returns {Object.} JSON object - */ - TensorAnnotation.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TensorAnnotation - * @function getTypeUrl - * @memberof onnx.TensorAnnotation - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; + return Optional; + })(); + + TypeProto.SparseTensor = (function () { + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. 
+ * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, 'shape')) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, + message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; } - return typeUrlPrefix + "/onnx.TensorAnnotation"; - }; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. + * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.elemType != null && message.hasOwnProperty('elemType')) + if (!$util.isInteger(message.elemType)) return 'elemType: integer expected'; + if (message.shape != null && message.hasOwnProperty('shape')) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) return 'shape.' + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== 'object') throw TypeError('.onnx.TypeProto.SparseTensor.shape: object expected'); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty('elemType')) object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty('shape')) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.TypeProto.SparseTensor'; + }; - return TensorAnnotation; + return SparseTensor; })(); - onnx.GraphProto = (function() { - - /** - * Properties of a GraphProto. - * @memberof onnx - * @interface IGraphProto - * @property {Array.|null} [node] GraphProto node - * @property {string|null} [name] GraphProto name - * @property {Array.|null} [initializer] GraphProto initializer - * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer - * @property {string|null} [docString] GraphProto docString - * @property {Array.|null} [input] GraphProto input - * @property {Array.|null} [output] GraphProto output - * @property {Array.|null} [valueInfo] GraphProto valueInfo - * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation - */ - - /** - * Constructs a new GraphProto. - * @memberof onnx - * @classdesc Represents a GraphProto. - * @implements IGraphProto - * @constructor - * @param {onnx.IGraphProto=} [properties] Properties to set - */ - function GraphProto(properties) { - this.node = []; - this.initializer = []; - this.sparseInitializer = []; - this.input = []; - this.output = []; - this.valueInfo = []; - this.quantizationAnnotation = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + return TypeProto; + })(); - /** - * GraphProto node. - * @member {Array.} node - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.node = $util.emptyArray; - - /** - * GraphProto name. - * @member {string} name - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.name = ""; - - /** - * GraphProto initializer. - * @member {Array.} initializer - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.initializer = $util.emptyArray; - - /** - * GraphProto sparseInitializer. - * @member {Array.} sparseInitializer - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.sparseInitializer = $util.emptyArray; - - /** - * GraphProto docString. - * @member {string} docString - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.docString = ""; - - /** - * GraphProto input. - * @member {Array.} input - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.input = $util.emptyArray; - - /** - * GraphProto output. - * @member {Array.} output - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.output = $util.emptyArray; - - /** - * GraphProto valueInfo. - * @member {Array.} valueInfo - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.valueInfo = $util.emptyArray; - - /** - * GraphProto quantizationAnnotation. 
- * @member {Array.} quantizationAnnotation - * @memberof onnx.GraphProto - * @instance - */ - GraphProto.prototype.quantizationAnnotation = $util.emptyArray; - - /** - * Creates a new GraphProto instance using the specified properties. - * @function create - * @memberof onnx.GraphProto - * @static - * @param {onnx.IGraphProto=} [properties] Properties to set - * @returns {onnx.GraphProto} GraphProto instance - */ - GraphProto.create = function create(properties) { - return new GraphProto(properties); - }; - - /** - * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. - * @function encode - * @memberof onnx.GraphProto - * @static - * @param {onnx.IGraphProto} message GraphProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - GraphProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.node != null && message.node.length) - for (var i = 0; i < message.node.length; ++i) - $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); - if (message.initializer != null && message.initializer.length) - for (var i = 0; i < message.initializer.length; ++i) - $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); - if (message.input != null && message.input.length) - for (var i = 0; i < message.input.length; ++i) - $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); - if (message.output != null && message.output.length) - for (var i = 0; i < message.output.length; ++i) - $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); - if (message.valueInfo != null && message.valueInfo.length) - for (var i = 0; i < message.valueInfo.length; ++i) - $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); - if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) - for (var i = 0; i < message.quantizationAnnotation.length; ++i) - $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); - if (message.sparseInitializer != null && message.sparseInitializer.length) - for (var i = 0; i < message.sparseInitializer.length; ++i) - $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.GraphProto - * @static - * @param {onnx.IGraphProto} message GraphProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - GraphProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a GraphProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.GraphProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.GraphProto} GraphProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - GraphProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.node && message.node.length)) - message.node = []; - message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); - break; - } - case 2: { - message.name = reader.string(); - break; - } - case 5: { - if (!(message.initializer && message.initializer.length)) - message.initializer = []; - message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); - break; - } - case 15: { - if (!(message.sparseInitializer && message.sparseInitializer.length)) - message.sparseInitializer = []; - message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); - break; - } - case 10: { - message.docString = reader.string(); - break; - } - case 11: { - if (!(message.input && message.input.length)) - message.input = []; - message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); - break; - } - case 12: { - if (!(message.output && message.output.length)) - message.output = []; - message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); - break; - } - case 13: { - if (!(message.valueInfo && message.valueInfo.length)) - message.valueInfo = []; - message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); - break; - } - case 14: { - if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) - message.quantizationAnnotation = []; - message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a GraphProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.GraphProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.GraphProto} GraphProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - GraphProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a GraphProto message. 
- * @function verify - * @memberof onnx.GraphProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - GraphProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.node != null && message.hasOwnProperty("node")) { - if (!Array.isArray(message.node)) - return "node: array expected"; - for (var i = 0; i < message.node.length; ++i) { - var error = $root.onnx.NodeProto.verify(message.node[i]); - if (error) - return "node." + error; - } - } - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.initializer != null && message.hasOwnProperty("initializer")) { - if (!Array.isArray(message.initializer)) - return "initializer: array expected"; - for (var i = 0; i < message.initializer.length; ++i) { - var error = $root.onnx.TensorProto.verify(message.initializer[i]); - if (error) - return "initializer." + error; - } - } - if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { - if (!Array.isArray(message.sparseInitializer)) - return "sparseInitializer: array expected"; - for (var i = 0; i < message.sparseInitializer.length; ++i) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); - if (error) - return "sparseInitializer." + error; - } - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.input != null && message.hasOwnProperty("input")) { - if (!Array.isArray(message.input)) - return "input: array expected"; - for (var i = 0; i < message.input.length; ++i) { - var error = $root.onnx.ValueInfoProto.verify(message.input[i]); - if (error) - return "input." + error; - } - } - if (message.output != null && message.hasOwnProperty("output")) { - if (!Array.isArray(message.output)) - return "output: array expected"; - for (var i = 0; i < message.output.length; ++i) { - var error = $root.onnx.ValueInfoProto.verify(message.output[i]); - if (error) - return "output." + error; - } - } - if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { - if (!Array.isArray(message.valueInfo)) - return "valueInfo: array expected"; - for (var i = 0; i < message.valueInfo.length; ++i) { - var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); - if (error) - return "valueInfo." + error; - } - } - if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { - if (!Array.isArray(message.quantizationAnnotation)) - return "quantizationAnnotation: array expected"; - for (var i = 0; i < message.quantizationAnnotation.length; ++i) { - var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); - if (error) - return "quantizationAnnotation." + error; - } - } - return null; - }; - - /** - * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.GraphProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.GraphProto} GraphProto - */ - GraphProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.GraphProto) - return object; - var message = new $root.onnx.GraphProto(); - if (object.node) { - if (!Array.isArray(object.node)) - throw TypeError(".onnx.GraphProto.node: array expected"); - message.node = []; - for (var i = 0; i < object.node.length; ++i) { - if (typeof object.node[i] !== "object") - throw TypeError(".onnx.GraphProto.node: object expected"); - message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); - } - } - if (object.name != null) - message.name = String(object.name); - if (object.initializer) { - if (!Array.isArray(object.initializer)) - throw TypeError(".onnx.GraphProto.initializer: array expected"); - message.initializer = []; - for (var i = 0; i < object.initializer.length; ++i) { - if (typeof object.initializer[i] !== "object") - throw TypeError(".onnx.GraphProto.initializer: object expected"); - message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); - } - } - if (object.sparseInitializer) { - if (!Array.isArray(object.sparseInitializer)) - throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); - message.sparseInitializer = []; - for (var i = 0; i < object.sparseInitializer.length; ++i) { - if (typeof object.sparseInitializer[i] !== "object") - throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); - message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); - } - } - if (object.docString != null) - message.docString = String(object.docString); - if (object.input) { - if (!Array.isArray(object.input)) - throw TypeError(".onnx.GraphProto.input: array expected"); - message.input = []; - for (var i = 0; i < object.input.length; ++i) { - if (typeof object.input[i] !== "object") - throw TypeError(".onnx.GraphProto.input: object expected"); - message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); - } - } - if (object.output) { - if (!Array.isArray(object.output)) - throw TypeError(".onnx.GraphProto.output: array expected"); - message.output = []; - for (var i = 0; i < object.output.length; ++i) { - if (typeof object.output[i] !== "object") - throw TypeError(".onnx.GraphProto.output: object expected"); - message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); - } - } - if (object.valueInfo) { - if (!Array.isArray(object.valueInfo)) - throw TypeError(".onnx.GraphProto.valueInfo: array expected"); - message.valueInfo = []; - for (var i = 0; i < object.valueInfo.length; ++i) { - if (typeof object.valueInfo[i] !== "object") - throw TypeError(".onnx.GraphProto.valueInfo: object expected"); - message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); - } - } - if (object.quantizationAnnotation) { - if (!Array.isArray(object.quantizationAnnotation)) - throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); - message.quantizationAnnotation = []; - for (var i = 0; i < object.quantizationAnnotation.length; ++i) { - if (typeof object.quantizationAnnotation[i] !== "object") - throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); - message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a GraphProto 
message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.GraphProto - * @static - * @param {onnx.GraphProto} message GraphProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - GraphProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.node = []; - object.initializer = []; - object.input = []; - object.output = []; - object.valueInfo = []; - object.quantizationAnnotation = []; - object.sparseInitializer = []; - } - if (options.defaults) { - object.name = ""; - object.docString = ""; - } - if (message.node && message.node.length) { - object.node = []; - for (var j = 0; j < message.node.length; ++j) - object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.initializer && message.initializer.length) { - object.initializer = []; - for (var j = 0; j < message.initializer.length; ++j) - object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.input && message.input.length) { - object.input = []; - for (var j = 0; j < message.input.length; ++j) - object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); - } - if (message.output && message.output.length) { - object.output = []; - for (var j = 0; j < message.output.length; ++j) - object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); - } - if (message.valueInfo && message.valueInfo.length) { - object.valueInfo = []; - for (var j = 0; j < message.valueInfo.length; ++j) - object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); - } - if (message.quantizationAnnotation && message.quantizationAnnotation.length) { - object.quantizationAnnotation = []; - for (var j = 0; j < message.quantizationAnnotation.length; ++j) - object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); - } - if (message.sparseInitializer && message.sparseInitializer.length) { - object.sparseInitializer = []; - for (var j = 0; j < message.sparseInitializer.length; ++j) - object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); - } - return object; - }; - - /** - * Converts this GraphProto to JSON. - * @function toJSON - * @memberof onnx.GraphProto - * @instance - * @returns {Object.} JSON object - */ - GraphProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for GraphProto - * @function getTypeUrl - * @memberof onnx.GraphProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.GraphProto"; - }; + onnx.OperatorSetIdProto = (function () { + /** + * Properties of an OperatorSetIdProto. 
+ * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ - return GraphProto; - })(); + /** + * Constructs a new OperatorSetIdProto. + * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. + * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } - onnx.TensorProto = (function() { - - /** - * Properties of a TensorProto. - * @memberof onnx - * @interface ITensorProto - * @property {Array.|null} [dims] TensorProto dims - * @property {number|null} [dataType] TensorProto dataType - * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment - * @property {Array.|null} [floatData] TensorProto floatData - * @property {Array.|null} [int32Data] TensorProto int32Data - * @property {Array.|null} [stringData] TensorProto stringData - * @property {Array.|null} [int64Data] TensorProto int64Data - * @property {string|null} [name] TensorProto name - * @property {string|null} [docString] TensorProto docString - * @property {Uint8Array|null} [rawData] TensorProto rawData - * @property {Array.|null} [externalData] TensorProto externalData - * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation - * @property {Array.|null} [doubleData] TensorProto doubleData - * @property {Array.|null} [uint64Data] TensorProto uint64Data - */ - - /** - * Constructs a new TensorProto. - * @memberof onnx - * @classdesc Represents a TensorProto. - * @implements ITensorProto - * @constructor - * @param {onnx.ITensorProto=} [properties] Properties to set - */ - function TensorProto(properties) { - this.dims = []; - this.floatData = []; - this.int32Data = []; - this.stringData = []; - this.int64Data = []; - this.externalData = []; - this.doubleData = []; - this.uint64Data = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ''; - /** - * TensorProto dims. - * @member {Array.} dims - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.dims = $util.emptyArray; - - /** - * TensorProto dataType. - * @member {number} dataType - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.dataType = 0; - - /** - * TensorProto segment. - * @member {onnx.TensorProto.ISegment|null|undefined} segment - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.segment = null; - - /** - * TensorProto floatData. - * @member {Array.} floatData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.floatData = $util.emptyArray; - - /** - * TensorProto int32Data. - * @member {Array.} int32Data - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.int32Data = $util.emptyArray; - - /** - * TensorProto stringData. - * @member {Array.} stringData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.stringData = $util.emptyArray; - - /** - * TensorProto int64Data. 
- * @member {Array.} int64Data - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.int64Data = $util.emptyArray; - - /** - * TensorProto name. - * @member {string} name - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.name = ""; - - /** - * TensorProto docString. - * @member {string} docString - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.docString = ""; - - /** - * TensorProto rawData. - * @member {Uint8Array} rawData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.rawData = $util.newBuffer([]); - - /** - * TensorProto externalData. - * @member {Array.} externalData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.externalData = $util.emptyArray; - - /** - * TensorProto dataLocation. - * @member {onnx.TensorProto.DataLocation} dataLocation - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.dataLocation = 0; - - /** - * TensorProto doubleData. - * @member {Array.} doubleData - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.doubleData = $util.emptyArray; - - /** - * TensorProto uint64Data. - * @member {Array.} uint64Data - * @memberof onnx.TensorProto - * @instance - */ - TensorProto.prototype.uint64Data = $util.emptyArray; - - /** - * Creates a new TensorProto instance using the specified properties. - * @function create - * @memberof onnx.TensorProto - * @static - * @param {onnx.ITensorProto=} [properties] Properties to set - * @returns {onnx.TensorProto} TensorProto instance - */ - TensorProto.create = function create(properties) { - return new TensorProto(properties); - }; - - /** - * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.TensorProto - * @static - * @param {onnx.ITensorProto} message TensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.dims != null && message.dims.length) { - writer.uint32(/* id 1, wireType 2 =*/10).fork(); - for (var i = 0; i < message.dims.length; ++i) - writer.int64(message.dims[i]); - writer.ldelim(); - } - if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) - writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); - if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) - $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); - if (message.floatData != null && message.floatData.length) { - writer.uint32(/* id 4, wireType 2 =*/34).fork(); - for (var i = 0; i < message.floatData.length; ++i) - writer.float(message.floatData[i]); - writer.ldelim(); - } - if (message.int32Data != null && message.int32Data.length) { - writer.uint32(/* id 5, wireType 2 =*/42).fork(); - for (var i = 0; i < message.int32Data.length; ++i) - writer.int32(message.int32Data[i]); - writer.ldelim(); - } - if (message.stringData != null && message.stringData.length) - for (var i = 0; i < message.stringData.length; ++i) - writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); - if (message.int64Data != null && message.int64Data.length) { - writer.uint32(/* id 7, wireType 2 =*/58).fork(); - for (var i = 0; i < message.int64Data.length; ++i) - writer.int64(message.int64Data[i]); - writer.ldelim(); - } - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); - if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) - writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); - if (message.doubleData != null && message.doubleData.length) { - writer.uint32(/* id 10, wireType 2 =*/82).fork(); - for (var i = 0; i < message.doubleData.length; ++i) - writer.double(message.doubleData[i]); - writer.ldelim(); - } - if (message.uint64Data != null && message.uint64Data.length) { - writer.uint32(/* id 11, wireType 2 =*/90).fork(); - for (var i = 0; i < message.uint64Data.length; ++i) - writer.uint64(message.uint64Data[i]); - writer.ldelim(); - } - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); - if (message.externalData != null && message.externalData.length) - for (var i = 0; i < message.externalData.length; ++i) - $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); - if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) - writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); - return writer; - }; - - /** - * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.TensorProto - * @static - * @param {onnx.ITensorProto} message TensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TensorProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorProto} TensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.dims && message.dims.length)) - message.dims = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.dims.push(reader.int64()); - } else - message.dims.push(reader.int64()); - break; - } - case 2: { - message.dataType = reader.int32(); - break; - } - case 3: { - message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); - break; - } - case 4: { - if (!(message.floatData && message.floatData.length)) - message.floatData = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.floatData.push(reader.float()); - } else - message.floatData.push(reader.float()); - break; - } - case 5: { - if (!(message.int32Data && message.int32Data.length)) - message.int32Data = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.int32Data.push(reader.int32()); - } else - message.int32Data.push(reader.int32()); - break; - } - case 6: { - if (!(message.stringData && message.stringData.length)) - message.stringData = []; - message.stringData.push(reader.bytes()); - break; - } - case 7: { - if (!(message.int64Data && message.int64Data.length)) - message.int64Data = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.int64Data.push(reader.int64()); - } else - message.int64Data.push(reader.int64()); - break; - } - case 8: { - message.name = reader.string(); - break; - } - case 12: { - message.docString = reader.string(); - break; - } - case 9: { - message.rawData = reader.bytes(); - break; - } - case 13: { - if (!(message.externalData && message.externalData.length)) - message.externalData = []; - message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); - break; - } - case 14: { - message.dataLocation = reader.int32(); - break; - } - case 10: { - if (!(message.doubleData && message.doubleData.length)) - message.doubleData = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.doubleData.push(reader.double()); - } else - message.doubleData.push(reader.double()); - break; - } - case 11: { - if (!(message.uint64Data && message.uint64Data.length)) - message.uint64Data = []; - if ((tag & 7) === 
2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.uint64Data.push(reader.uint64()); - } else - message.uint64Data.push(reader.uint64()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TensorProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorProto} TensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TensorProto message. - * @function verify - * @memberof onnx.TensorProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TensorProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.dims != null && message.hasOwnProperty("dims")) { - if (!Array.isArray(message.dims)) - return "dims: array expected"; - for (var i = 0; i < message.dims.length; ++i) - if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) - return "dims: integer|Long[] expected"; - } - if (message.dataType != null && message.hasOwnProperty("dataType")) - if (!$util.isInteger(message.dataType)) - return "dataType: integer expected"; - if (message.segment != null && message.hasOwnProperty("segment")) { - var error = $root.onnx.TensorProto.Segment.verify(message.segment); - if (error) - return "segment." 
+ error; - } - if (message.floatData != null && message.hasOwnProperty("floatData")) { - if (!Array.isArray(message.floatData)) - return "floatData: array expected"; - for (var i = 0; i < message.floatData.length; ++i) - if (typeof message.floatData[i] !== "number") - return "floatData: number[] expected"; - } - if (message.int32Data != null && message.hasOwnProperty("int32Data")) { - if (!Array.isArray(message.int32Data)) - return "int32Data: array expected"; - for (var i = 0; i < message.int32Data.length; ++i) - if (!$util.isInteger(message.int32Data[i])) - return "int32Data: integer[] expected"; - } - if (message.stringData != null && message.hasOwnProperty("stringData")) { - if (!Array.isArray(message.stringData)) - return "stringData: array expected"; - for (var i = 0; i < message.stringData.length; ++i) - if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) - return "stringData: buffer[] expected"; - } - if (message.int64Data != null && message.hasOwnProperty("int64Data")) { - if (!Array.isArray(message.int64Data)) - return "int64Data: array expected"; - for (var i = 0; i < message.int64Data.length; ++i) - if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) - return "int64Data: integer|Long[] expected"; - } - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.rawData != null && message.hasOwnProperty("rawData")) - if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) - return "rawData: buffer expected"; - if (message.externalData != null && message.hasOwnProperty("externalData")) { - if (!Array.isArray(message.externalData)) - return "externalData: array expected"; - for (var i = 0; i < message.externalData.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); - if (error) - return "externalData." + error; - } - } - if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) - switch (message.dataLocation) { - default: - return "dataLocation: enum value expected"; - case 0: - case 1: - break; - } - if (message.doubleData != null && message.hasOwnProperty("doubleData")) { - if (!Array.isArray(message.doubleData)) - return "doubleData: array expected"; - for (var i = 0; i < message.doubleData.length; ++i) - if (typeof message.doubleData[i] !== "number") - return "doubleData: number[] expected"; - } - if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { - if (!Array.isArray(message.uint64Data)) - return "uint64Data: array expected"; - for (var i = 0; i < message.uint64Data.length; ++i) - if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) - return "uint64Data: integer|Long[] expected"; - } - return null; - }; - - /** - * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.TensorProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorProto} TensorProto - */ - TensorProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorProto) - return object; - var message = new $root.onnx.TensorProto(); - if (object.dims) { - if (!Array.isArray(object.dims)) - throw TypeError(".onnx.TensorProto.dims: array expected"); - message.dims = []; - for (var i = 0; i < object.dims.length; ++i) - if ($util.Long) - (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; - else if (typeof object.dims[i] === "string") - message.dims[i] = parseInt(object.dims[i], 10); - else if (typeof object.dims[i] === "number") - message.dims[i] = object.dims[i]; - else if (typeof object.dims[i] === "object") - message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); - } - if (object.dataType != null) - message.dataType = object.dataType | 0; - if (object.segment != null) { - if (typeof object.segment !== "object") - throw TypeError(".onnx.TensorProto.segment: object expected"); - message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); - } - if (object.floatData) { - if (!Array.isArray(object.floatData)) - throw TypeError(".onnx.TensorProto.floatData: array expected"); - message.floatData = []; - for (var i = 0; i < object.floatData.length; ++i) - message.floatData[i] = Number(object.floatData[i]); - } - if (object.int32Data) { - if (!Array.isArray(object.int32Data)) - throw TypeError(".onnx.TensorProto.int32Data: array expected"); - message.int32Data = []; - for (var i = 0; i < object.int32Data.length; ++i) - message.int32Data[i] = object.int32Data[i] | 0; - } - if (object.stringData) { - if (!Array.isArray(object.stringData)) - throw TypeError(".onnx.TensorProto.stringData: array expected"); - message.stringData = []; - for (var i = 0; i < object.stringData.length; ++i) - if (typeof object.stringData[i] === "string") - $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); - else if (object.stringData[i].length >= 0) - message.stringData[i] = object.stringData[i]; - } - if (object.int64Data) { - if (!Array.isArray(object.int64Data)) - throw TypeError(".onnx.TensorProto.int64Data: array expected"); - message.int64Data = []; - for (var i = 0; i < object.int64Data.length; ++i) - if ($util.Long) - (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; - else if (typeof object.int64Data[i] === "string") - message.int64Data[i] = parseInt(object.int64Data[i], 10); - else if (typeof object.int64Data[i] === "number") - message.int64Data[i] = object.int64Data[i]; - else if (typeof object.int64Data[i] === "object") - message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); - } - if (object.name != null) - message.name = String(object.name); - if (object.docString != null) - message.docString = String(object.docString); - if (object.rawData != null) - if (typeof object.rawData === "string") - $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); - else if (object.rawData.length >= 0) - message.rawData = object.rawData; - if (object.externalData) { - if (!Array.isArray(object.externalData)) - throw TypeError(".onnx.TensorProto.externalData: array expected"); - message.externalData = []; - for (var i = 0; 
i < object.externalData.length; ++i) { - if (typeof object.externalData[i] !== "object") - throw TypeError(".onnx.TensorProto.externalData: object expected"); - message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); - } - } - switch (object.dataLocation) { - default: - if (typeof object.dataLocation === "number") { - message.dataLocation = object.dataLocation; - break; - } - break; - case "DEFAULT": - case 0: - message.dataLocation = 0; - break; - case "EXTERNAL": - case 1: - message.dataLocation = 1; - break; - } - if (object.doubleData) { - if (!Array.isArray(object.doubleData)) - throw TypeError(".onnx.TensorProto.doubleData: array expected"); - message.doubleData = []; - for (var i = 0; i < object.doubleData.length; ++i) - message.doubleData[i] = Number(object.doubleData[i]); - } - if (object.uint64Data) { - if (!Array.isArray(object.uint64Data)) - throw TypeError(".onnx.TensorProto.uint64Data: array expected"); - message.uint64Data = []; - for (var i = 0; i < object.uint64Data.length; ++i) - if ($util.Long) - (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; - else if (typeof object.uint64Data[i] === "string") - message.uint64Data[i] = parseInt(object.uint64Data[i], 10); - else if (typeof object.uint64Data[i] === "number") - message.uint64Data[i] = object.uint64Data[i]; - else if (typeof object.uint64Data[i] === "object") - message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); - } - return message; - }; - - /** - * Creates a plain object from a TensorProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorProto - * @static - * @param {onnx.TensorProto} message TensorProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TensorProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.dims = []; - object.floatData = []; - object.int32Data = []; - object.stringData = []; - object.int64Data = []; - object.doubleData = []; - object.uint64Data = []; - object.externalData = []; - } - if (options.defaults) { - object.dataType = 0; - object.segment = null; - object.name = ""; - if (options.bytes === String) - object.rawData = ""; - else { - object.rawData = []; - if (options.bytes !== Array) - object.rawData = $util.newBuffer(object.rawData); - } - object.docString = ""; - object.dataLocation = options.enums === String ? "DEFAULT" : 0; - } - if (message.dims && message.dims.length) { - object.dims = []; - for (var j = 0; j < message.dims.length; ++j) - if (typeof message.dims[j] === "number") - object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; - else - object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? 
new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; - } - if (message.dataType != null && message.hasOwnProperty("dataType")) - object.dataType = message.dataType; - if (message.segment != null && message.hasOwnProperty("segment")) - object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); - if (message.floatData && message.floatData.length) { - object.floatData = []; - for (var j = 0; j < message.floatData.length; ++j) - object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; - } - if (message.int32Data && message.int32Data.length) { - object.int32Data = []; - for (var j = 0; j < message.int32Data.length; ++j) - object.int32Data[j] = message.int32Data[j]; - } - if (message.stringData && message.stringData.length) { - object.stringData = []; - for (var j = 0; j < message.stringData.length; ++j) - object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; - } - if (message.int64Data && message.int64Data.length) { - object.int64Data = []; - for (var j = 0; j < message.int64Data.length; ++j) - if (typeof message.int64Data[j] === "number") - object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; - else - object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.rawData != null && message.hasOwnProperty("rawData")) - object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; - if (message.doubleData && message.doubleData.length) { - object.doubleData = []; - for (var j = 0; j < message.doubleData.length; ++j) - object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; - } - if (message.uint64Data && message.uint64Data.length) { - object.uint64Data = []; - for (var j = 0; j < message.uint64Data.length; ++j) - if (typeof message.uint64Data[j] === "number") - object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; - else - object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.externalData && message.externalData.length) { - object.externalData = []; - for (var j = 0; j < message.externalData.length; ++j) - object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); - } - if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) - object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? 
message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; - return object; - }; - - /** - * Converts this TensorProto to JSON. - * @function toJSON - * @memberof onnx.TensorProto - * @instance - * @returns {Object.} JSON object - */ - TensorProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TensorProto - * @function getTypeUrl - * @memberof onnx.TensorProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorProto"; - }; - - /** - * DataType enum. - * @name onnx.TensorProto.DataType - * @enum {number} - * @property {number} UNDEFINED=0 UNDEFINED value - * @property {number} FLOAT=1 FLOAT value - * @property {number} UINT8=2 UINT8 value - * @property {number} INT8=3 INT8 value - * @property {number} UINT16=4 UINT16 value - * @property {number} INT16=5 INT16 value - * @property {number} INT32=6 INT32 value - * @property {number} INT64=7 INT64 value - * @property {number} STRING=8 STRING value - * @property {number} BOOL=9 BOOL value - * @property {number} FLOAT16=10 FLOAT16 value - * @property {number} DOUBLE=11 DOUBLE value - * @property {number} UINT32=12 UINT32 value - * @property {number} UINT64=13 UINT64 value - * @property {number} COMPLEX64=14 COMPLEX64 value - * @property {number} COMPLEX128=15 COMPLEX128 value - * @property {number} BFLOAT16=16 BFLOAT16 value - * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value - * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value - * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value - * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value - */ - TensorProto.DataType = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "UNDEFINED"] = 0; - values[valuesById[1] = "FLOAT"] = 1; - values[valuesById[2] = "UINT8"] = 2; - values[valuesById[3] = "INT8"] = 3; - values[valuesById[4] = "UINT16"] = 4; - values[valuesById[5] = "INT16"] = 5; - values[valuesById[6] = "INT32"] = 6; - values[valuesById[7] = "INT64"] = 7; - values[valuesById[8] = "STRING"] = 8; - values[valuesById[9] = "BOOL"] = 9; - values[valuesById[10] = "FLOAT16"] = 10; - values[valuesById[11] = "DOUBLE"] = 11; - values[valuesById[12] = "UINT32"] = 12; - values[valuesById[13] = "UINT64"] = 13; - values[valuesById[14] = "COMPLEX64"] = 14; - values[valuesById[15] = "COMPLEX128"] = 15; - values[valuesById[16] = "BFLOAT16"] = 16; - values[valuesById[17] = "FLOAT8E4M3FN"] = 17; - values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; - values[valuesById[19] = "FLOAT8E5M2"] = 19; - values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; - return values; - })(); - - TensorProto.Segment = (function() { - - /** - * Properties of a Segment. - * @memberof onnx.TensorProto - * @interface ISegment - * @property {number|Long|null} [begin] Segment begin - * @property {number|Long|null} [end] Segment end - */ - - /** - * Constructs a new Segment. - * @memberof onnx.TensorProto - * @classdesc Represents a Segment. 
- * @implements ISegment - * @constructor - * @param {onnx.TensorProto.ISegment=} [properties] Properties to set - */ - function Segment(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; - /** - * Segment begin. - * @member {number|Long} begin - * @memberof onnx.TensorProto.Segment - * @instance - */ - Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * Segment end. - * @member {number|Long} end - * @memberof onnx.TensorProto.Segment - * @instance - */ - Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * Creates a new Segment instance using the specified properties. - * @function create - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.ISegment=} [properties] Properties to set - * @returns {onnx.TensorProto.Segment} Segment instance - */ - Segment.create = function create(properties) { - return new Segment(properties); - }; - - /** - * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. - * @function encode - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Segment.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) - writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); - if (message.end != null && Object.hasOwnProperty.call(message, "end")) - writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); - return writer; - }; - - /** - * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Segment.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Segment message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorProto.Segment - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorProto.Segment} Segment - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Segment.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.begin = reader.int64(); - break; - } - case 2: { - message.end = reader.int64(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Segment message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorProto.Segment - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorProto.Segment} Segment - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Segment.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Segment message. - * @function verify - * @memberof onnx.TensorProto.Segment - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Segment.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.begin != null && message.hasOwnProperty("begin")) - if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) - return "begin: integer|Long expected"; - if (message.end != null && message.hasOwnProperty("end")) - if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) - return "end: integer|Long expected"; - return null; - }; - - /** - * Creates a Segment message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TensorProto.Segment - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorProto.Segment} Segment - */ - Segment.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorProto.Segment) - return object; - var message = new $root.onnx.TensorProto.Segment(); - if (object.begin != null) - if ($util.Long) - (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; - else if (typeof object.begin === "string") - message.begin = parseInt(object.begin, 10); - else if (typeof object.begin === "number") - message.begin = object.begin; - else if (typeof object.begin === "object") - message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); - if (object.end != null) - if ($util.Long) - (message.end = $util.Long.fromValue(object.end)).unsigned = false; - else if (typeof object.end === "string") - message.end = parseInt(object.end, 10); - else if (typeof object.end === "number") - message.end = object.end; - else if (typeof object.end === "object") - message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); - return message; - }; - - /** - * Creates a plain object from a Segment message. Also converts values to other types if specified. 
- * @function toObject - * @memberof onnx.TensorProto.Segment - * @static - * @param {onnx.TensorProto.Segment} message Segment - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Segment.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.begin = options.longs === String ? "0" : 0; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.end = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.end = options.longs === String ? "0" : 0; - } - if (message.begin != null && message.hasOwnProperty("begin")) - if (typeof message.begin === "number") - object.begin = options.longs === String ? String(message.begin) : message.begin; - else - object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; - if (message.end != null && message.hasOwnProperty("end")) - if (typeof message.end === "number") - object.end = options.longs === String ? String(message.end) : message.end; - else - object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; - return object; - }; - - /** - * Converts this Segment to JSON. - * @function toJSON - * @memberof onnx.TensorProto.Segment - * @instance - * @returns {Object.} JSON object - */ - Segment.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Segment - * @function getTypeUrl - * @memberof onnx.TensorProto.Segment - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorProto.Segment"; - }; - - return Segment; - })(); - - /** - * DataLocation enum. - * @name onnx.TensorProto.DataLocation - * @enum {number} - * @property {number} DEFAULT=0 DEFAULT value - * @property {number} EXTERNAL=1 EXTERNAL value - */ - TensorProto.DataLocation = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "DEFAULT"] = 0; - values[valuesById[1] = "EXTERNAL"] = 1; - return values; - })(); - - return TensorProto; - })(); + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, 'version')) + writer.uint32(/* id 2, wireType 0 =*/ 16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; - onnx.SparseTensorProto = (function() { - - /** - * Properties of a SparseTensorProto. - * @memberof onnx - * @interface ISparseTensorProto - * @property {onnx.ITensorProto|null} [values] SparseTensorProto values - * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices - * @property {Array.|null} [dims] SparseTensorProto dims - */ - - /** - * Constructs a new SparseTensorProto. - * @memberof onnx - * @classdesc Represents a SparseTensorProto. - * @implements ISparseTensorProto - * @constructor - * @param {onnx.ISparseTensorProto=} [properties] Properties to set - */ - function SparseTensorProto(properties) { - this.dims = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * SparseTensorProto values. - * @member {onnx.ITensorProto|null|undefined} values - * @memberof onnx.SparseTensorProto - * @instance - */ - SparseTensorProto.prototype.values = null; - - /** - * SparseTensorProto indices. - * @member {onnx.ITensorProto|null|undefined} indices - * @memberof onnx.SparseTensorProto - * @instance - */ - SparseTensorProto.prototype.indices = null; - - /** - * SparseTensorProto dims. 
- * @member {Array.} dims - * @memberof onnx.SparseTensorProto - * @instance - */ - SparseTensorProto.prototype.dims = $util.emptyArray; - - /** - * Creates a new SparseTensorProto instance using the specified properties. - * @function create - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.ISparseTensorProto=} [properties] Properties to set - * @returns {onnx.SparseTensorProto} SparseTensorProto instance - */ - SparseTensorProto.create = function create(properties) { - return new SparseTensorProto(properties); - }; - - /** - * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. - * @function encode - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensorProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.values != null && Object.hasOwnProperty.call(message, "values")) - $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) - $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - if (message.dims != null && message.dims.length) { - writer.uint32(/* id 3, wireType 2 =*/26).fork(); - for (var i = 0; i < message.dims.length; ++i) - writer.int64(message.dims[i]); - writer.ldelim(); - } - return writer; - }; - - /** - * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a SparseTensorProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.SparseTensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.SparseTensorProto} SparseTensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensorProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); - break; - } - case 2: { - message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); - break; - } - case 3: { - if (!(message.dims && message.dims.length)) - message.dims = []; - if ((tag & 7) === 2) { - var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) - message.dims.push(reader.int64()); - } else - message.dims.push(reader.int64()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.SparseTensorProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.SparseTensorProto} SparseTensorProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a SparseTensorProto message. - * @function verify - * @memberof onnx.SparseTensorProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - SparseTensorProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.values != null && message.hasOwnProperty("values")) { - var error = $root.onnx.TensorProto.verify(message.values); - if (error) - return "values." + error; - } - if (message.indices != null && message.hasOwnProperty("indices")) { - var error = $root.onnx.TensorProto.verify(message.indices); - if (error) - return "indices." + error; - } - if (message.dims != null && message.hasOwnProperty("dims")) { - if (!Array.isArray(message.dims)) - return "dims: array expected"; - for (var i = 0; i < message.dims.length; ++i) - if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) - return "dims: integer|Long[] expected"; - } - return null; - }; - - /** - * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.SparseTensorProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.SparseTensorProto} SparseTensorProto - */ - SparseTensorProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.SparseTensorProto) - return object; - var message = new $root.onnx.SparseTensorProto(); - if (object.values != null) { - if (typeof object.values !== "object") - throw TypeError(".onnx.SparseTensorProto.values: object expected"); - message.values = $root.onnx.TensorProto.fromObject(object.values); - } - if (object.indices != null) { - if (typeof object.indices !== "object") - throw TypeError(".onnx.SparseTensorProto.indices: object expected"); - message.indices = $root.onnx.TensorProto.fromObject(object.indices); - } - if (object.dims) { - if (!Array.isArray(object.dims)) - throw TypeError(".onnx.SparseTensorProto.dims: array expected"); - message.dims = []; - for (var i = 0; i < object.dims.length; ++i) - if ($util.Long) - (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; - else if (typeof object.dims[i] === "string") - message.dims[i] = parseInt(object.dims[i], 10); - else if (typeof object.dims[i] === "number") - message.dims[i] = object.dims[i]; - else if (typeof object.dims[i] === "object") - message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); - } - return message; - }; - - /** - * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.SparseTensorProto - * @static - * @param {onnx.SparseTensorProto} message SparseTensorProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - SparseTensorProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) - object.dims = []; - if (options.defaults) { - object.values = null; - object.indices = null; - } - if (message.values != null && message.hasOwnProperty("values")) - object.values = $root.onnx.TensorProto.toObject(message.values, options); - if (message.indices != null && message.hasOwnProperty("indices")) - object.indices = $root.onnx.TensorProto.toObject(message.indices, options); - if (message.dims && message.dims.length) { - object.dims = []; - for (var j = 0; j < message.dims.length; ++j) - if (typeof message.dims[j] === "number") - object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; - else - object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; - } - return object; - }; - - /** - * Converts this SparseTensorProto to JSON. 
- * @function toJSON - * @memberof onnx.SparseTensorProto - * @instance - * @returns {Object.} JSON object - */ - SparseTensorProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for SparseTensorProto - * @function getTypeUrl - * @memberof onnx.SparseTensorProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.SparseTensorProto"; - }; + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return SparseTensorProto; - })(); + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + if (message.version != null && message.hasOwnProperty('version')) + if ( + !$util.isInteger(message.version) && + !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high)) + ) + return 'version: integer|Long expected'; + return null; + }; - onnx.TensorShapeProto = (function() { - - /** - * Properties of a TensorShapeProto. - * @memberof onnx - * @interface ITensorShapeProto - * @property {Array.|null} [dim] TensorShapeProto dim - */ - - /** - * Constructs a new TensorShapeProto. - * @memberof onnx - * @classdesc Represents a TensorShapeProto. - * @implements ITensorShapeProto - * @constructor - * @param {onnx.ITensorShapeProto=} [properties] Properties to set - */ - function TensorShapeProto(properties) { - this.dim = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === 'string') message.version = parseInt(object.version, 10); + else if (typeof object.version === 'number') message.version = object.version; + else if (typeof object.version === 'object') + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; - /** - * TensorShapeProto dim. - * @member {Array.} dim - * @memberof onnx.TensorShapeProto - * @instance - */ - TensorShapeProto.prototype.dim = $util.emptyArray; - - /** - * Creates a new TensorShapeProto instance using the specified properties. - * @function create - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.ITensorShapeProto=} [properties] Properties to set - * @returns {onnx.TensorShapeProto} TensorShapeProto instance - */ - TensorShapeProto.create = function create(properties) { - return new TensorShapeProto(properties); - }; - - /** - * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. - * @function encode - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorShapeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.dim != null && message.dim.length) - for (var i = 0; i < message.dim.length; ++i) - $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TensorShapeProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorShapeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorShapeProto} TensorShapeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorShapeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - if (!(message.dim && message.dim.length)) - message.dim = []; - message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TensorShapeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorShapeProto} TensorShapeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TensorShapeProto message. - * @function verify - * @memberof onnx.TensorShapeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TensorShapeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.dim != null && message.hasOwnProperty("dim")) { - if (!Array.isArray(message.dim)) - return "dim: array expected"; - for (var i = 0; i < message.dim.length; ++i) { - var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); - if (error) - return "dim." + error; - } - } - return null; - }; - - /** - * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TensorShapeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorShapeProto} TensorShapeProto - */ - TensorShapeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorShapeProto) - return object; - var message = new $root.onnx.TensorShapeProto(); - if (object.dim) { - if (!Array.isArray(object.dim)) - throw TypeError(".onnx.TensorShapeProto.dim: array expected"); - message.dim = []; - for (var i = 0; i < object.dim.length; ++i) { - if (typeof object.dim[i] !== "object") - throw TypeError(".onnx.TensorShapeProto.dim: object expected"); - message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); - } - } - return message; - }; - - /** - * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorShapeProto - * @static - * @param {onnx.TensorShapeProto} message TensorShapeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TensorShapeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) - object.dim = []; - if (message.dim && message.dim.length) { - object.dim = []; - for (var j = 0; j < message.dim.length; ++j) - object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); - } - return object; - }; - - /** - * Converts this TensorShapeProto to JSON. 
- * @function toJSON - * @memberof onnx.TensorShapeProto - * @instance - * @returns {Object.} JSON object - */ - TensorShapeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TensorShapeProto - * @function getTypeUrl - * @memberof onnx.TensorShapeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorShapeProto"; - }; - - TensorShapeProto.Dimension = (function() { - - /** - * Properties of a Dimension. - * @memberof onnx.TensorShapeProto - * @interface IDimension - * @property {number|Long|null} [dimValue] Dimension dimValue - * @property {string|null} [dimParam] Dimension dimParam - * @property {string|null} [denotation] Dimension denotation - */ - - /** - * Constructs a new Dimension. - * @memberof onnx.TensorShapeProto - * @classdesc Represents a Dimension. - * @implements IDimension - * @constructor - * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set - */ - function Dimension(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.defaults) { + object.domain = ''; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = + options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else object.version = options.longs === String ? '0' : 0; + } + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.version != null && message.hasOwnProperty('version')) + if (typeof message.version === 'number') + object.version = options.longs === String ? String(message.version) : message.version; + else + object.version = + options.longs === String + ? $util.Long.prototype.toString.call(message.version) + : options.longs === Number + ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() + : message.version; + return object; + }; - /** - * Dimension dimValue. - * @member {number|Long|null|undefined} dimValue - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Dimension.prototype.dimValue = null; - - /** - * Dimension dimParam. - * @member {string|null|undefined} dimParam - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Dimension.prototype.dimParam = null; - - /** - * Dimension denotation. - * @member {string} denotation - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Dimension.prototype.denotation = ""; - - // OneOf field names bound to virtual getters and setters - var $oneOfFields; - - /** - * Dimension value. 
- * @member {"dimValue"|"dimParam"|undefined} value - * @memberof onnx.TensorShapeProto.Dimension - * @instance - */ - Object.defineProperty(Dimension.prototype, "value", { - get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), - set: $util.oneOfSetter($oneOfFields) - }); - - /** - * Creates a new Dimension instance using the specified properties. - * @function create - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set - * @returns {onnx.TensorShapeProto.Dimension} Dimension instance - */ - Dimension.create = function create(properties) { - return new Dimension(properties); - }; - - /** - * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. - * @function encode - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Dimension.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) - writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); - if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) - writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); - if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) - writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); - return writer; - }; - - /** - * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Dimension.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Dimension message from the specified reader or buffer. - * @function decode - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TensorShapeProto.Dimension} Dimension - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Dimension.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.dimValue = reader.int64(); - break; - } - case 2: { - message.dimParam = reader.string(); - break; - } - case 3: { - message.denotation = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Dimension message from the specified reader or buffer, length delimited. 
- * @function decodeDelimited - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TensorShapeProto.Dimension} Dimension - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Dimension.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Dimension message. - * @function verify - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Dimension.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - var properties = {}; - if (message.dimValue != null && message.hasOwnProperty("dimValue")) { - properties.value = 1; - if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) - return "dimValue: integer|Long expected"; - } - if (message.dimParam != null && message.hasOwnProperty("dimParam")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - if (!$util.isString(message.dimParam)) - return "dimParam: string expected"; - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - if (!$util.isString(message.denotation)) - return "denotation: string expected"; - return null; - }; - - /** - * Creates a Dimension message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {Object.} object Plain object - * @returns {onnx.TensorShapeProto.Dimension} Dimension - */ - Dimension.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorShapeProto.Dimension) - return object; - var message = new $root.onnx.TensorShapeProto.Dimension(); - if (object.dimValue != null) - if ($util.Long) - (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; - else if (typeof object.dimValue === "string") - message.dimValue = parseInt(object.dimValue, 10); - else if (typeof object.dimValue === "number") - message.dimValue = object.dimValue; - else if (typeof object.dimValue === "object") - message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); - if (object.dimParam != null) - message.dimParam = String(object.dimParam); - if (object.denotation != null) - message.denotation = String(object.denotation); - return message; - }; - - /** - * Creates a plain object from a Dimension message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {onnx.TensorShapeProto.Dimension} message Dimension - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Dimension.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.denotation = ""; - if (message.dimValue != null && message.hasOwnProperty("dimValue")) { - if (typeof message.dimValue === "number") - object.dimValue = options.longs === String ? 
String(message.dimValue) : message.dimValue; - else - object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; - if (options.oneofs) - object.value = "dimValue"; - } - if (message.dimParam != null && message.hasOwnProperty("dimParam")) { - object.dimParam = message.dimParam; - if (options.oneofs) - object.value = "dimParam"; - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - object.denotation = message.denotation; - return object; - }; - - /** - * Converts this Dimension to JSON. - * @function toJSON - * @memberof onnx.TensorShapeProto.Dimension - * @instance - * @returns {Object.} JSON object - */ - Dimension.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Dimension - * @function getTypeUrl - * @memberof onnx.TensorShapeProto.Dimension - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; - }; - - return Dimension; - })(); - - return TensorShapeProto; - })(); + /** + * Converts this OperatorSetIdProto to JSON. + * @function toJSON + * @memberof onnx.OperatorSetIdProto + * @instance + * @returns {Object.} JSON object + */ + OperatorSetIdProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; - onnx.TypeProto = (function() { - - /** - * Properties of a TypeProto. - * @memberof onnx - * @interface ITypeProto - * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType - * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType - * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType - * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType - * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType - * @property {string|null} [denotation] TypeProto denotation - */ - - /** - * Constructs a new TypeProto. - * @memberof onnx - * @classdesc Represents a TypeProto. - * @implements ITypeProto - * @constructor - * @param {onnx.ITypeProto=} [properties] Properties to set - */ - function TypeProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Gets the default type url for OperatorSetIdProto + * @function getTypeUrl + * @memberof onnx.OperatorSetIdProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.OperatorSetIdProto'; + }; + + return OperatorSetIdProto; + })(); + + /** + * OperatorStatus enum. 
+ * @name onnx.OperatorStatus + * @enum {number} + * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value + * @property {number} STABLE=1 STABLE value + */ + onnx.OperatorStatus = (function () { + var valuesById = {}, + values = Object.create(valuesById); + values[(valuesById[0] = 'EXPERIMENTAL')] = 0; + values[(valuesById[1] = 'STABLE')] = 1; + return values; + })(); + + onnx.FunctionProto = (function () { + /** + * Properties of a FunctionProto. + * @memberof onnx + * @interface IFunctionProto + * @property {string|null} [name] FunctionProto name + * @property {Array.|null} [input] FunctionProto input + * @property {Array.|null} [output] FunctionProto output + * @property {Array.|null} [attribute] FunctionProto attribute + * @property {Array.|null} [attributeProto] FunctionProto attributeProto + * @property {Array.|null} [node] FunctionProto node + * @property {string|null} [docString] FunctionProto docString + * @property {Array.|null} [opsetImport] FunctionProto opsetImport + * @property {string|null} [domain] FunctionProto domain + */ - /** - * TypeProto tensorType. - * @member {onnx.TypeProto.ITensor|null|undefined} tensorType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.tensorType = null; - - /** - * TypeProto sequenceType. - * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.sequenceType = null; - - /** - * TypeProto mapType. - * @member {onnx.TypeProto.IMap|null|undefined} mapType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.mapType = null; - - /** - * TypeProto optionalType. - * @member {onnx.TypeProto.IOptional|null|undefined} optionalType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.optionalType = null; - - /** - * TypeProto sparseTensorType. - * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.sparseTensorType = null; - - /** - * TypeProto denotation. - * @member {string} denotation - * @memberof onnx.TypeProto - * @instance - */ - TypeProto.prototype.denotation = ""; - - // OneOf field names bound to virtual getters and setters - var $oneOfFields; - - /** - * TypeProto value. - * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value - * @memberof onnx.TypeProto - * @instance - */ - Object.defineProperty(TypeProto.prototype, "value", { - get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), - set: $util.oneOfSetter($oneOfFields) - }); - - /** - * Creates a new TypeProto instance using the specified properties. - * @function create - * @memberof onnx.TypeProto - * @static - * @param {onnx.ITypeProto=} [properties] Properties to set - * @returns {onnx.TypeProto} TypeProto instance - */ - TypeProto.create = function create(properties) { - return new TypeProto(properties); - }; - - /** - * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. 
- * @function encode - * @memberof onnx.TypeProto - * @static - * @param {onnx.ITypeProto} message TypeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TypeProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) - $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) - $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); - if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) - $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); - if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); - if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, "sparseTensorType")) - $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); - if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) - $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto - * @static - * @param {onnx.ITypeProto} message TypeProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - TypeProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a TypeProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto} TypeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TypeProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); - break; - } - case 4: { - message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); - break; - } - case 5: { - message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); - break; - } - case 9: { - message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); - break; - } - case 8: { - message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); - break; - } - case 6: { - message.denotation = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a TypeProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto} TypeProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - TypeProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a TypeProto message. - * @function verify - * @memberof onnx.TypeProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - TypeProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - var properties = {}; - if (message.tensorType != null && message.hasOwnProperty("tensorType")) { - properties.value = 1; - { - var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); - if (error) - return "tensorType." + error; - } - } - if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); - if (error) - return "sequenceType." + error; - } - } - if (message.mapType != null && message.hasOwnProperty("mapType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.Map.verify(message.mapType); - if (error) - return "mapType." + error; - } - } - if (message.optionalType != null && message.hasOwnProperty("optionalType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); - if (error) - return "optionalType." + error; - } - } - if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { - if (properties.value === 1) - return "value: multiple values"; - properties.value = 1; - { - var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); - if (error) - return "sparseTensorType." + error; - } - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - if (!$util.isString(message.denotation)) - return "denotation: string expected"; - return null; - }; - - /** - * Creates a TypeProto message from a plain object. 
Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto} TypeProto - */ - TypeProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto) - return object; - var message = new $root.onnx.TypeProto(); - if (object.tensorType != null) { - if (typeof object.tensorType !== "object") - throw TypeError(".onnx.TypeProto.tensorType: object expected"); - message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); - } - if (object.sequenceType != null) { - if (typeof object.sequenceType !== "object") - throw TypeError(".onnx.TypeProto.sequenceType: object expected"); - message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); - } - if (object.mapType != null) { - if (typeof object.mapType !== "object") - throw TypeError(".onnx.TypeProto.mapType: object expected"); - message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); - } - if (object.optionalType != null) { - if (typeof object.optionalType !== "object") - throw TypeError(".onnx.TypeProto.optionalType: object expected"); - message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); - } - if (object.sparseTensorType != null) { - if (typeof object.sparseTensorType !== "object") - throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); - message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); - } - if (object.denotation != null) - message.denotation = String(object.denotation); - return message; - }; - - /** - * Creates a plain object from a TypeProto message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto - * @static - * @param {onnx.TypeProto} message TypeProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - TypeProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.denotation = ""; - if (message.tensorType != null && message.hasOwnProperty("tensorType")) { - object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); - if (options.oneofs) - object.value = "tensorType"; - } - if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { - object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); - if (options.oneofs) - object.value = "sequenceType"; - } - if (message.mapType != null && message.hasOwnProperty("mapType")) { - object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); - if (options.oneofs) - object.value = "mapType"; - } - if (message.denotation != null && message.hasOwnProperty("denotation")) - object.denotation = message.denotation; - if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { - object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); - if (options.oneofs) - object.value = "sparseTensorType"; - } - if (message.optionalType != null && message.hasOwnProperty("optionalType")) { - object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); - if (options.oneofs) - object.value = "optionalType"; - } - return object; - }; - - /** - * Converts this TypeProto to JSON. 
- * @function toJSON - * @memberof onnx.TypeProto - * @instance - * @returns {Object.} JSON object - */ - TypeProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for TypeProto - * @function getTypeUrl - * @memberof onnx.TypeProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto"; - }; - - TypeProto.Tensor = (function() { - - /** - * Properties of a Tensor. - * @memberof onnx.TypeProto - * @interface ITensor - * @property {number|null} [elemType] Tensor elemType - * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape - */ - - /** - * Constructs a new Tensor. - * @memberof onnx.TypeProto - * @classdesc Represents a Tensor. - * @implements ITensor - * @constructor - * @param {onnx.TypeProto.ITensor=} [properties] Properties to set - */ - function Tensor(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * Constructs a new FunctionProto. + * @memberof onnx + * @classdesc Represents a FunctionProto. + * @implements IFunctionProto + * @constructor + * @param {onnx.IFunctionProto=} [properties] Properties to set + */ + function FunctionProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + this.attributeProto = []; + this.node = []; + this.opsetImport = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + } - /** - * Tensor elemType. - * @member {number} elemType - * @memberof onnx.TypeProto.Tensor - * @instance - */ - Tensor.prototype.elemType = 0; - - /** - * Tensor shape. - * @member {onnx.ITensorShapeProto|null|undefined} shape - * @memberof onnx.TypeProto.Tensor - * @instance - */ - Tensor.prototype.shape = null; - - /** - * Creates a new Tensor instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.ITensor=} [properties] Properties to set - * @returns {onnx.TypeProto.Tensor} Tensor instance - */ - Tensor.create = function create(properties) { - return new Tensor(properties); - }; - - /** - * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Tensor.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); - if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) - $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Tensor message, length delimited. 
Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Tensor.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Tensor message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.Tensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Tensor} Tensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Tensor.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = reader.int32(); - break; - } - case 2: { - message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Tensor message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.Tensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Tensor} Tensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Tensor.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Tensor message. - * @function verify - * @memberof onnx.TypeProto.Tensor - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Tensor.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) - if (!$util.isInteger(message.elemType)) - return "elemType: integer expected"; - if (message.shape != null && message.hasOwnProperty("shape")) { - var error = $root.onnx.TensorShapeProto.verify(message.shape); - if (error) - return "shape." + error; - } - return null; - }; - - /** - * Creates a Tensor message from a plain object. Also converts values to their respective internal types. 
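// Editor's note (usage sketch, not part of the generated file): round-tripping
// onnx.TypeProto.Tensor through the codec above. Field 1 (elemType) is written as
// a varint and field 2 (shape) as a length-delimited TensorShapeProto; elemType 7
// is assumed to be INT64.
const { onnx } = require('./onnx');
const bytes = onnx.TypeProto.Tensor.encode({ elemType: 7, shape: { dim: [{ dimValue: 2 }] } }).finish();
const tensor = onnx.TypeProto.Tensor.decode(bytes);
// tensor.elemType === 7; tensor.shape.dim.length === 1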
- * @function fromObject - * @memberof onnx.TypeProto.Tensor - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Tensor} Tensor - */ - Tensor.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Tensor) - return object; - var message = new $root.onnx.TypeProto.Tensor(); - if (object.elemType != null) - message.elemType = object.elemType | 0; - if (object.shape != null) { - if (typeof object.shape !== "object") - throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); - message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); - } - return message; - }; - - /** - * Creates a plain object from a Tensor message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.Tensor - * @static - * @param {onnx.TypeProto.Tensor} message Tensor - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Tensor.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.elemType = 0; - object.shape = null; - } - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = message.elemType; - if (message.shape != null && message.hasOwnProperty("shape")) - object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); - return object; - }; - - /** - * Converts this Tensor to JSON. - * @function toJSON - * @memberof onnx.TypeProto.Tensor - * @instance - * @returns {Object.} JSON object - */ - Tensor.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Tensor - * @function getTypeUrl - * @memberof onnx.TypeProto.Tensor - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Tensor"; - }; - - return Tensor; - })(); - - TypeProto.Sequence = (function() { - - /** - * Properties of a Sequence. - * @memberof onnx.TypeProto - * @interface ISequence - * @property {onnx.ITypeProto|null} [elemType] Sequence elemType - */ - - /** - * Constructs a new Sequence. - * @memberof onnx.TypeProto - * @classdesc Represents a Sequence. - * @implements ISequence - * @constructor - * @param {onnx.TypeProto.ISequence=} [properties] Properties to set - */ - function Sequence(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto name. + * @member {string} name + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.name = ''; - /** - * Sequence elemType. - * @member {onnx.ITypeProto|null|undefined} elemType - * @memberof onnx.TypeProto.Sequence - * @instance - */ - Sequence.prototype.elemType = null; - - /** - * Creates a new Sequence instance using the specified properties. 
- * @function create - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.ISequence=} [properties] Properties to set - * @returns {onnx.TypeProto.Sequence} Sequence instance - */ - Sequence.create = function create(properties) { - return new Sequence(properties); - }; - - /** - * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Sequence.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Sequence.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Sequence message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.Sequence - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Sequence} Sequence - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Sequence.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Sequence message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.Sequence - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Sequence} Sequence - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Sequence.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Sequence message. 
- * @function verify - * @memberof onnx.TypeProto.Sequence - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Sequence.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) { - var error = $root.onnx.TypeProto.verify(message.elemType); - if (error) - return "elemType." + error; - } - return null; - }; - - /** - * Creates a Sequence message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.Sequence - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Sequence} Sequence - */ - Sequence.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Sequence) - return object; - var message = new $root.onnx.TypeProto.Sequence(); - if (object.elemType != null) { - if (typeof object.elemType !== "object") - throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); - message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); - } - return message; - }; - - /** - * Creates a plain object from a Sequence message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.Sequence - * @static - * @param {onnx.TypeProto.Sequence} message Sequence - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Sequence.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.elemType = null; - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); - return object; - }; - - /** - * Converts this Sequence to JSON. - * @function toJSON - * @memberof onnx.TypeProto.Sequence - * @instance - * @returns {Object.} JSON object - */ - Sequence.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Sequence - * @function getTypeUrl - * @memberof onnx.TypeProto.Sequence - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Sequence"; - }; - - return Sequence; - })(); - - TypeProto.Map = (function() { - - /** - * Properties of a Map. - * @memberof onnx.TypeProto - * @interface IMap - * @property {number|null} [keyType] Map keyType - * @property {onnx.ITypeProto|null} [valueType] Map valueType - */ - - /** - * Constructs a new Map. - * @memberof onnx.TypeProto - * @classdesc Represents a Map. - * @implements IMap - * @constructor - * @param {onnx.TypeProto.IMap=} [properties] Properties to set - */ - function Map(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto input. + * @member {Array.} input + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.input = $util.emptyArray; - /** - * Map keyType. 
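// Editor's note (sketch only, require path illustrative): Sequence.fromObject above
// handles the recursive case where a TypeProto's element type is itself a TypeProto,
// e.g. a sequence of float tensors (elemType 1 assumed to be FLOAT).
const { onnx } = require('./onnx');
const seqType = onnx.TypeProto.fromObject({
  sequenceType: { elemType: { tensorType: { elemType: 1 } } },
});
// seqType.sequenceType.elemType instanceof onnx.TypeProto === true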
- * @member {number} keyType - * @memberof onnx.TypeProto.Map - * @instance - */ - Map.prototype.keyType = 0; - - /** - * Map valueType. - * @member {onnx.ITypeProto|null|undefined} valueType - * @memberof onnx.TypeProto.Map - * @instance - */ - Map.prototype.valueType = null; - - /** - * Creates a new Map instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.IMap=} [properties] Properties to set - * @returns {onnx.TypeProto.Map} Map instance - */ - Map.create = function create(properties) { - return new Map(properties); - }; - - /** - * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.IMap} message Map message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Map.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) - writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); - if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) - $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.IMap} message Map message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Map.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a Map message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.Map - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Map} Map - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Map.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.keyType = reader.int32(); - break; - } - case 2: { - message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a Map message from the specified reader or buffer, length delimited. 
- * @function decodeDelimited - * @memberof onnx.TypeProto.Map - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Map} Map - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Map.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a Map message. - * @function verify - * @memberof onnx.TypeProto.Map - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Map.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.keyType != null && message.hasOwnProperty("keyType")) - if (!$util.isInteger(message.keyType)) - return "keyType: integer expected"; - if (message.valueType != null && message.hasOwnProperty("valueType")) { - var error = $root.onnx.TypeProto.verify(message.valueType); - if (error) - return "valueType." + error; - } - return null; - }; - - /** - * Creates a Map message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.Map - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Map} Map - */ - Map.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Map) - return object; - var message = new $root.onnx.TypeProto.Map(); - if (object.keyType != null) - message.keyType = object.keyType | 0; - if (object.valueType != null) { - if (typeof object.valueType !== "object") - throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); - message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); - } - return message; - }; - - /** - * Creates a plain object from a Map message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.Map - * @static - * @param {onnx.TypeProto.Map} message Map - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Map.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.keyType = 0; - object.valueType = null; - } - if (message.keyType != null && message.hasOwnProperty("keyType")) - object.keyType = message.keyType; - if (message.valueType != null && message.hasOwnProperty("valueType")) - object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); - return object; - }; - - /** - * Converts this Map to JSON. 
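// Editor's note (sketch only, require path illustrative): onnx.TypeProto.Map pairs an
// integer keyType (a TensorProto.DataType value; 8 is assumed to be STRING) with a
// nested TypeProto valueType, and verify above returns null for a well-formed object.
const { onnx } = require('./onnx');
const mapType = onnx.TypeProto.Map.fromObject({
  keyType: 8,
  valueType: { tensorType: { elemType: 1 } },
});
// onnx.TypeProto.Map.verify(onnx.TypeProto.Map.toObject(mapType)) === null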
- * @function toJSON - * @memberof onnx.TypeProto.Map - * @instance - * @returns {Object.} JSON object - */ - Map.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Map - * @function getTypeUrl - * @memberof onnx.TypeProto.Map - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Map"; - }; - - return Map; - })(); - - TypeProto.Optional = (function() { - - /** - * Properties of an Optional. - * @memberof onnx.TypeProto - * @interface IOptional - * @property {onnx.ITypeProto|null} [elemType] Optional elemType - */ - - /** - * Constructs a new Optional. - * @memberof onnx.TypeProto - * @classdesc Represents an Optional. - * @implements IOptional - * @constructor - * @param {onnx.TypeProto.IOptional=} [properties] Properties to set - */ - function Optional(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto output. + * @member {Array.} output + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.output = $util.emptyArray; - /** - * Optional elemType. - * @member {onnx.ITypeProto|null|undefined} elemType - * @memberof onnx.TypeProto.Optional - * @instance - */ - Optional.prototype.elemType = null; - - /** - * Creates a new Optional instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.IOptional=} [properties] Properties to set - * @returns {onnx.TypeProto.Optional} Optional instance - */ - Optional.create = function create(properties) { - return new Optional(properties); - }; - - /** - * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. - * @function encode - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Optional.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - Optional.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes an Optional message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.TypeProto.Optional - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.Optional} Optional - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Optional.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes an Optional message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.Optional - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.Optional} Optional - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - Optional.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies an Optional message. - * @function verify - * @memberof onnx.TypeProto.Optional - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - Optional.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) { - var error = $root.onnx.TypeProto.verify(message.elemType); - if (error) - return "elemType." + error; - } - return null; - }; - - /** - * Creates an Optional message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.Optional - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.Optional} Optional - */ - Optional.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.Optional) - return object; - var message = new $root.onnx.TypeProto.Optional(); - if (object.elemType != null) { - if (typeof object.elemType !== "object") - throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); - message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); - } - return message; - }; - - /** - * Creates a plain object from an Optional message. Also converts values to other types if specified. 
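// Editor's note (sketch only): as the checks above show, fromObject throws a TypeError
// naming the offending field when a nested message is not supplied as an object.
// onnx.TypeProto.Optional.fromObject({ elemType: 'float' })
//   -> throws TypeError('.onnx.TypeProto.Optional.elemType: object expected')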
- * @function toObject - * @memberof onnx.TypeProto.Optional - * @static - * @param {onnx.TypeProto.Optional} message Optional - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - Optional.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) - object.elemType = null; - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); - return object; - }; - - /** - * Converts this Optional to JSON. - * @function toJSON - * @memberof onnx.TypeProto.Optional - * @instance - * @returns {Object.} JSON object - */ - Optional.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for Optional - * @function getTypeUrl - * @memberof onnx.TypeProto.Optional - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.Optional"; - }; - - return Optional; - })(); - - TypeProto.SparseTensor = (function() { - - /** - * Properties of a SparseTensor. - * @memberof onnx.TypeProto - * @interface ISparseTensor - * @property {number|null} [elemType] SparseTensor elemType - * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape - */ - - /** - * Constructs a new SparseTensor. - * @memberof onnx.TypeProto - * @classdesc Represents a SparseTensor. - * @implements ISparseTensor - * @constructor - * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set - */ - function SparseTensor(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto attribute. + * @member {Array.} attribute + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attribute = $util.emptyArray; - /** - * SparseTensor elemType. - * @member {number} elemType - * @memberof onnx.TypeProto.SparseTensor - * @instance - */ - SparseTensor.prototype.elemType = 0; - - /** - * SparseTensor shape. - * @member {onnx.ITensorShapeProto|null|undefined} shape - * @memberof onnx.TypeProto.SparseTensor - * @instance - */ - SparseTensor.prototype.shape = null; - - /** - * Creates a new SparseTensor instance using the specified properties. - * @function create - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set - * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance - */ - SparseTensor.create = function create(properties) { - return new SparseTensor(properties); - }; - - /** - * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. 
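// Editor's note (sketch only): getTypeUrl above just prefixes the fully qualified
// message name, which is the form expected by google.protobuf.Any's type_url field.
// onnx.TypeProto.Optional.getTypeUrl() === 'type.googleapis.com/onnx.TypeProto.Optional'
// onnx.TypeProto.Optional.getTypeUrl('example.com') === 'example.com/onnx.TypeProto.Optional'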
- * @function encode - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensor.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) - writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); - if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) - $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a SparseTensor message from the specified reader or buffer. - * @function decode - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.TypeProto.SparseTensor} SparseTensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensor.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.elemType = reader.int32(); - break; - } - case 2: { - message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a SparseTensor message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.TypeProto.SparseTensor} SparseTensor - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - SparseTensor.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a SparseTensor message. 
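// Editor's note (usage sketch, require path illustrative): the length-delimited
// variants above prefix the payload with its byte length, so several messages can
// be written back to back and read again with decodeDelimited.
const { onnx } = require('./onnx');
const delimited = onnx.TypeProto.SparseTensor.encodeDelimited({ elemType: 1 }).finish();
const sparse = onnx.TypeProto.SparseTensor.decodeDelimited(delimited);
// sparse.elemType === 1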
- * @function verify - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - SparseTensor.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.elemType != null && message.hasOwnProperty("elemType")) - if (!$util.isInteger(message.elemType)) - return "elemType: integer expected"; - if (message.shape != null && message.hasOwnProperty("shape")) { - var error = $root.onnx.TensorShapeProto.verify(message.shape); - if (error) - return "shape." + error; - } - return null; - }; - - /** - * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {Object.} object Plain object - * @returns {onnx.TypeProto.SparseTensor} SparseTensor - */ - SparseTensor.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TypeProto.SparseTensor) - return object; - var message = new $root.onnx.TypeProto.SparseTensor(); - if (object.elemType != null) - message.elemType = object.elemType | 0; - if (object.shape != null) { - if (typeof object.shape !== "object") - throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); - message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); - } - return message; - }; - - /** - * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. - * @function toObject - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {onnx.TypeProto.SparseTensor} message SparseTensor - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - SparseTensor.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.elemType = 0; - object.shape = null; - } - if (message.elemType != null && message.hasOwnProperty("elemType")) - object.elemType = message.elemType; - if (message.shape != null && message.hasOwnProperty("shape")) - object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); - return object; - }; - - /** - * Converts this SparseTensor to JSON. - * @function toJSON - * @memberof onnx.TypeProto.SparseTensor - * @instance - * @returns {Object.} JSON object - */ - SparseTensor.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for SparseTensor - * @function getTypeUrl - * @memberof onnx.TypeProto.SparseTensor - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; - }; - - return SparseTensor; - })(); - - return TypeProto; - })(); + /** + * FunctionProto attributeProto. + * @member {Array.} attributeProto + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attributeProto = $util.emptyArray; - onnx.OperatorSetIdProto = (function() { - - /** - * Properties of an OperatorSetIdProto. 
- * @memberof onnx - * @interface IOperatorSetIdProto - * @property {string|null} [domain] OperatorSetIdProto domain - * @property {number|Long|null} [version] OperatorSetIdProto version - */ - - /** - * Constructs a new OperatorSetIdProto. - * @memberof onnx - * @classdesc Represents an OperatorSetIdProto. - * @implements IOperatorSetIdProto - * @constructor - * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set - */ - function OperatorSetIdProto(properties) { - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; - } + /** + * FunctionProto node. + * @member {Array.} node + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.node = $util.emptyArray; - /** - * OperatorSetIdProto domain. - * @member {string} domain - * @memberof onnx.OperatorSetIdProto - * @instance - */ - OperatorSetIdProto.prototype.domain = ""; - - /** - * OperatorSetIdProto version. - * @member {number|Long} version - * @memberof onnx.OperatorSetIdProto - * @instance - */ - OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; - - /** - * Creates a new OperatorSetIdProto instance using the specified properties. - * @function create - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance - */ - OperatorSetIdProto.create = function create(properties) { - return new OperatorSetIdProto(properties); - }; - - /** - * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. - * @function encode - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - OperatorSetIdProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); - if (message.version != null && Object.hasOwnProperty.call(message, "version")) - writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); - return writer; - }; - - /** - * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. - * @function encodeDelimited - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes an OperatorSetIdProto message from the specified reader or buffer. 
- * @function decode - * @memberof onnx.OperatorSetIdProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - OperatorSetIdProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.domain = reader.string(); - break; - } - case 2: { - message.version = reader.int64(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.OperatorSetIdProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies an OperatorSetIdProto message. - * @function verify - * @memberof onnx.OperatorSetIdProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - OperatorSetIdProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - if (message.version != null && message.hasOwnProperty("version")) - if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) - return "version: integer|Long expected"; - return null; - }; - - /** - * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. - * @function fromObject - * @memberof onnx.OperatorSetIdProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto - */ - OperatorSetIdProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.OperatorSetIdProto) - return object; - var message = new $root.onnx.OperatorSetIdProto(); - if (object.domain != null) - message.domain = String(object.domain); - if (object.version != null) - if ($util.Long) - (message.version = $util.Long.fromValue(object.version)).unsigned = false; - else if (typeof object.version === "string") - message.version = parseInt(object.version, 10); - else if (typeof object.version === "number") - message.version = object.version; - else if (typeof object.version === "object") - message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); - return message; - }; - - /** - * Creates a plain object from an OperatorSetIdProto message. 
Also converts values to other types if specified. - * @function toObject - * @memberof onnx.OperatorSetIdProto - * @static - * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - OperatorSetIdProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.defaults) { - object.domain = ""; - if ($util.Long) { - var long = new $util.Long(0, 0, false); - object.version = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else - object.version = options.longs === String ? "0" : 0; - } - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - if (message.version != null && message.hasOwnProperty("version")) - if (typeof message.version === "number") - object.version = options.longs === String ? String(message.version) : message.version; - else - object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version; - return object; - }; - - /** - * Converts this OperatorSetIdProto to JSON. - * @function toJSON - * @memberof onnx.OperatorSetIdProto - * @instance - * @returns {Object.} JSON object - */ - OperatorSetIdProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for OperatorSetIdProto - * @function getTypeUrl - * @memberof onnx.OperatorSetIdProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.OperatorSetIdProto"; - }; + /** + * FunctionProto docString. + * @member {string} docString + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.docString = ''; - return OperatorSetIdProto; - })(); + /** + * FunctionProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.opsetImport = $util.emptyArray; /** - * OperatorStatus enum. - * @name onnx.OperatorStatus - * @enum {number} - * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value - * @property {number} STABLE=1 STABLE value - */ - onnx.OperatorStatus = (function() { - var valuesById = {}, values = Object.create(valuesById); - values[valuesById[0] = "EXPERIMENTAL"] = 0; - values[valuesById[1] = "STABLE"] = 1; - return values; - })(); + * FunctionProto domain. + * @member {string} domain + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.domain = ''; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @function create + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto=} [properties] Properties to set + * @returns {onnx.FunctionProto} FunctionProto instance + */ + FunctionProto.create = function create(properties) { + return new FunctionProto(properties); + }; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. 
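// Editor's note (sketch only, require path illustrative): `version` above is an int64
// field, so fromObject stores it as a Long when the long.js dependency is available,
// and toObject can render it back as a decimal string via the `longs` option.
const { onnx } = require('./onnx');
const opset = onnx.OperatorSetIdProto.fromObject({ domain: '', version: 13 });
const plainOpset = onnx.OperatorSetIdProto.toObject(opset, { longs: String });
// plainOpset.version === '13'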
+ * @function encode + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encode = function encode(message, writer) { + if (!writer) writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 5, wireType 2 =*/ 42).string(message.output[i]); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.attribute[i]); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/ 58).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + writer.uint32(/* id 8, wireType 2 =*/ 66).string(message.docString); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode( + message.opsetImport[i], + writer.uint32(/* id 9, wireType 2 =*/ 74).fork(), + ).ldelim(); + if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + writer.uint32(/* id 10, wireType 2 =*/ 82).string(message.domain); + if (message.attributeProto != null && message.attributeProto.length) + for (var i = 0; i < message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode( + message.attributeProto[i], + writer.uint32(/* id 11, wireType 2 =*/ 90).fork(), + ).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; - onnx.FunctionProto = (function() { - - /** - * Properties of a FunctionProto. - * @memberof onnx - * @interface IFunctionProto - * @property {string|null} [name] FunctionProto name - * @property {Array.|null} [input] FunctionProto input - * @property {Array.|null} [output] FunctionProto output - * @property {Array.|null} [attribute] FunctionProto attribute - * @property {Array.|null} [attributeProto] FunctionProto attributeProto - * @property {Array.|null} [node] FunctionProto node - * @property {string|null} [docString] FunctionProto docString - * @property {Array.|null} [opsetImport] FunctionProto opsetImport - * @property {string|null} [domain] FunctionProto domain - */ - - /** - * Constructs a new FunctionProto. - * @memberof onnx - * @classdesc Represents a FunctionProto. 
- * @implements IFunctionProto - * @constructor - * @param {onnx.IFunctionProto=} [properties] Properties to set - */ - function FunctionProto(properties) { - this.input = []; - this.output = []; - this.attribute = []; - this.attributeProto = []; - this.node = []; - this.opsetImport = []; - if (properties) - for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) - this[keys[i]] = properties[keys[i]]; + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, + message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; } + } + return message; + }; - /** - * FunctionProto name. - * @member {string} name - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.name = ""; - - /** - * FunctionProto input. - * @member {Array.} input - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.input = $util.emptyArray; - - /** - * FunctionProto output. - * @member {Array.} output - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.output = $util.emptyArray; - - /** - * FunctionProto attribute. - * @member {Array.} attribute - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.attribute = $util.emptyArray; - - /** - * FunctionProto attributeProto. - * @member {Array.} attributeProto - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.attributeProto = $util.emptyArray; - - /** - * FunctionProto node. 
- * @member {Array.} node - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.node = $util.emptyArray; - - /** - * FunctionProto docString. - * @member {string} docString - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.docString = ""; - - /** - * FunctionProto opsetImport. - * @member {Array.} opsetImport - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.opsetImport = $util.emptyArray; - - /** - * FunctionProto domain. - * @member {string} domain - * @memberof onnx.FunctionProto - * @instance - */ - FunctionProto.prototype.domain = ""; - - /** - * Creates a new FunctionProto instance using the specified properties. - * @function create - * @memberof onnx.FunctionProto - * @static - * @param {onnx.IFunctionProto=} [properties] Properties to set - * @returns {onnx.FunctionProto} FunctionProto instance - */ - FunctionProto.create = function create(properties) { - return new FunctionProto(properties); - }; - - /** - * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. - * @function encode - * @memberof onnx.FunctionProto - * @static - * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - FunctionProto.encode = function encode(message, writer) { - if (!writer) - writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, "name")) - writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); - if (message.input != null && message.input.length) - for (var i = 0; i < message.input.length; ++i) - writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]); - if (message.output != null && message.output.length) - for (var i = 0; i < message.output.length; ++i) - writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]); - if (message.attribute != null && message.attribute.length) - for (var i = 0; i < message.attribute.length; ++i) - writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]); - if (message.node != null && message.node.length) - for (var i = 0; i < message.node.length; ++i) - $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) - writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString); - if (message.opsetImport != null && message.opsetImport.length) - for (var i = 0; i < message.opsetImport.length; ++i) - $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); - if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) - writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain); - if (message.attributeProto != null && message.attributeProto.length) - for (var i = 0; i < message.attributeProto.length; ++i) - $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); - return writer; - }; - - /** - * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. 
- * @function encodeDelimited - * @memberof onnx.FunctionProto - * @static - * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode - * @param {$protobuf.Writer} [writer] Writer to encode to - * @returns {$protobuf.Writer} Writer - */ - FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { - return this.encode(message, writer).ldelim(); - }; - - /** - * Decodes a FunctionProto message from the specified reader or buffer. - * @function decode - * @memberof onnx.FunctionProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @param {number} [length] Message length if known beforehand - * @returns {onnx.FunctionProto} FunctionProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - FunctionProto.decode = function decode(reader, length) { - if (!(reader instanceof $Reader)) - reader = $Reader.create(reader); - var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); - while (reader.pos < end) { - var tag = reader.uint32(); - switch (tag >>> 3) { - case 1: { - message.name = reader.string(); - break; - } - case 4: { - if (!(message.input && message.input.length)) - message.input = []; - message.input.push(reader.string()); - break; - } - case 5: { - if (!(message.output && message.output.length)) - message.output = []; - message.output.push(reader.string()); - break; - } - case 6: { - if (!(message.attribute && message.attribute.length)) - message.attribute = []; - message.attribute.push(reader.string()); - break; - } - case 11: { - if (!(message.attributeProto && message.attributeProto.length)) - message.attributeProto = []; - message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); - break; - } - case 7: { - if (!(message.node && message.node.length)) - message.node = []; - message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); - break; - } - case 8: { - message.docString = reader.string(); - break; - } - case 9: { - if (!(message.opsetImport && message.opsetImport.length)) - message.opsetImport = []; - message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); - break; - } - case 10: { - message.domain = reader.string(); - break; - } - default: - reader.skipType(tag & 7); - break; - } - } - return message; - }; - - /** - * Decodes a FunctionProto message from the specified reader or buffer, length delimited. - * @function decodeDelimited - * @memberof onnx.FunctionProto - * @static - * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from - * @returns {onnx.FunctionProto} FunctionProto - * @throws {Error} If the payload is not a reader or valid buffer - * @throws {$protobuf.util.ProtocolError} If required fields are missing - */ - FunctionProto.decodeDelimited = function decodeDelimited(reader) { - if (!(reader instanceof $Reader)) - reader = new $Reader(reader); - return this.decode(reader, reader.uint32()); - }; - - /** - * Verifies a FunctionProto message. 
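// Editor's note (usage sketch, not part of the generated file): FunctionProto's
// repeated string fields (input/output/attribute) are written one tag per element,
// and its nested messages (node, opsetImport, attributeProto) as length-delimited
// submessages, so a fromObject -> encode -> decode round trip preserves them. The
// function name, domain, node and opset values below are illustrative.
const { onnx } = require('./onnx');
const fn = onnx.FunctionProto.fromObject({
  name: 'MyFunc',
  domain: 'custom.domain',
  input: ['X'],
  output: ['Y'],
  node: [{ opType: 'Identity', input: ['X'], output: ['Y'] }],
  opsetImport: [{ version: 13 }],
});
const roundTripped = onnx.FunctionProto.decode(onnx.FunctionProto.encode(fn).finish());
// roundTripped.input[0] === 'X'; roundTripped.node[0].opType === 'Identity'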
- * @function verify - * @memberof onnx.FunctionProto - * @static - * @param {Object.} message Plain object to verify - * @returns {string|null} `null` if valid, otherwise the reason why it is not - */ - FunctionProto.verify = function verify(message) { - if (typeof message !== "object" || message === null) - return "object expected"; - if (message.name != null && message.hasOwnProperty("name")) - if (!$util.isString(message.name)) - return "name: string expected"; - if (message.input != null && message.hasOwnProperty("input")) { - if (!Array.isArray(message.input)) - return "input: array expected"; - for (var i = 0; i < message.input.length; ++i) - if (!$util.isString(message.input[i])) - return "input: string[] expected"; - } - if (message.output != null && message.hasOwnProperty("output")) { - if (!Array.isArray(message.output)) - return "output: array expected"; - for (var i = 0; i < message.output.length; ++i) - if (!$util.isString(message.output[i])) - return "output: string[] expected"; - } - if (message.attribute != null && message.hasOwnProperty("attribute")) { - if (!Array.isArray(message.attribute)) - return "attribute: array expected"; - for (var i = 0; i < message.attribute.length; ++i) - if (!$util.isString(message.attribute[i])) - return "attribute: string[] expected"; - } - if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { - if (!Array.isArray(message.attributeProto)) - return "attributeProto: array expected"; - for (var i = 0; i < message.attributeProto.length; ++i) { - var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); - if (error) - return "attributeProto." + error; - } - } - if (message.node != null && message.hasOwnProperty("node")) { - if (!Array.isArray(message.node)) - return "node: array expected"; - for (var i = 0; i < message.node.length; ++i) { - var error = $root.onnx.NodeProto.verify(message.node[i]); - if (error) - return "node." + error; - } - } - if (message.docString != null && message.hasOwnProperty("docString")) - if (!$util.isString(message.docString)) - return "docString: string expected"; - if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { - if (!Array.isArray(message.opsetImport)) - return "opsetImport: array expected"; - for (var i = 0; i < message.opsetImport.length; ++i) { - var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); - if (error) - return "opsetImport." + error; - } - } - if (message.domain != null && message.hasOwnProperty("domain")) - if (!$util.isString(message.domain)) - return "domain: string expected"; - return null; - }; - - /** - * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
- * @function fromObject - * @memberof onnx.FunctionProto - * @static - * @param {Object.} object Plain object - * @returns {onnx.FunctionProto} FunctionProto - */ - FunctionProto.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.FunctionProto) - return object; - var message = new $root.onnx.FunctionProto(); - if (object.name != null) - message.name = String(object.name); - if (object.input) { - if (!Array.isArray(object.input)) - throw TypeError(".onnx.FunctionProto.input: array expected"); - message.input = []; - for (var i = 0; i < object.input.length; ++i) - message.input[i] = String(object.input[i]); - } - if (object.output) { - if (!Array.isArray(object.output)) - throw TypeError(".onnx.FunctionProto.output: array expected"); - message.output = []; - for (var i = 0; i < object.output.length; ++i) - message.output[i] = String(object.output[i]); - } - if (object.attribute) { - if (!Array.isArray(object.attribute)) - throw TypeError(".onnx.FunctionProto.attribute: array expected"); - message.attribute = []; - for (var i = 0; i < object.attribute.length; ++i) - message.attribute[i] = String(object.attribute[i]); - } - if (object.attributeProto) { - if (!Array.isArray(object.attributeProto)) - throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); - message.attributeProto = []; - for (var i = 0; i < object.attributeProto.length; ++i) { - if (typeof object.attributeProto[i] !== "object") - throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); - message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); - } - } - if (object.node) { - if (!Array.isArray(object.node)) - throw TypeError(".onnx.FunctionProto.node: array expected"); - message.node = []; - for (var i = 0; i < object.node.length; ++i) { - if (typeof object.node[i] !== "object") - throw TypeError(".onnx.FunctionProto.node: object expected"); - message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); - } - } - if (object.docString != null) - message.docString = String(object.docString); - if (object.opsetImport) { - if (!Array.isArray(object.opsetImport)) - throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); - message.opsetImport = []; - for (var i = 0; i < object.opsetImport.length; ++i) { - if (typeof object.opsetImport[i] !== "object") - throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); - message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); - } - } - if (object.domain != null) - message.domain = String(object.domain); - return message; - }; - - /** - * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. 
- * @function toObject - * @memberof onnx.FunctionProto - * @static - * @param {onnx.FunctionProto} message FunctionProto - * @param {$protobuf.IConversionOptions} [options] Conversion options - * @returns {Object.} Plain object - */ - FunctionProto.toObject = function toObject(message, options) { - if (!options) - options = {}; - var object = {}; - if (options.arrays || options.defaults) { - object.input = []; - object.output = []; - object.attribute = []; - object.node = []; - object.opsetImport = []; - object.attributeProto = []; - } - if (options.defaults) { - object.name = ""; - object.docString = ""; - object.domain = ""; - } - if (message.name != null && message.hasOwnProperty("name")) - object.name = message.name; - if (message.input && message.input.length) { - object.input = []; - for (var j = 0; j < message.input.length; ++j) - object.input[j] = message.input[j]; - } - if (message.output && message.output.length) { - object.output = []; - for (var j = 0; j < message.output.length; ++j) - object.output[j] = message.output[j]; - } - if (message.attribute && message.attribute.length) { - object.attribute = []; - for (var j = 0; j < message.attribute.length; ++j) - object.attribute[j] = message.attribute[j]; - } - if (message.node && message.node.length) { - object.node = []; - for (var j = 0; j < message.node.length; ++j) - object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); - } - if (message.docString != null && message.hasOwnProperty("docString")) - object.docString = message.docString; - if (message.opsetImport && message.opsetImport.length) { - object.opsetImport = []; - for (var j = 0; j < message.opsetImport.length; ++j) - object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); - } - if (message.domain != null && message.hasOwnProperty("domain")) - object.domain = message.domain; - if (message.attributeProto && message.attributeProto.length) { - object.attributeProto = []; - for (var j = 0; j < message.attributeProto.length; ++j) - object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); - } - return object; - }; - - /** - * Converts this FunctionProto to JSON. - * @function toJSON - * @memberof onnx.FunctionProto - * @instance - * @returns {Object.} JSON object - */ - FunctionProto.prototype.toJSON = function toJSON() { - return this.constructor.toObject(this, $protobuf.util.toJSONOptions); - }; - - /** - * Gets the default type url for FunctionProto - * @function getTypeUrl - * @memberof onnx.FunctionProto - * @static - * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") - * @returns {string} The default type url - */ - FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { - if (typeUrlPrefix === undefined) { - typeUrlPrefix = "type.googleapis.com"; - } - return typeUrlPrefix + "/onnx.FunctionProto"; - }; + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; - return FunctionProto; - })(); + /** + * Verifies a FunctionProto message. + * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== 'object' || message === null) return 'object expected'; + if (message.name != null && message.hasOwnProperty('name')) + if (!$util.isString(message.name)) return 'name: string expected'; + if (message.input != null && message.hasOwnProperty('input')) { + if (!Array.isArray(message.input)) return 'input: array expected'; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) return 'input: string[] expected'; + } + if (message.output != null && message.hasOwnProperty('output')) { + if (!Array.isArray(message.output)) return 'output: array expected'; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) return 'output: string[] expected'; + } + if (message.attribute != null && message.hasOwnProperty('attribute')) { + if (!Array.isArray(message.attribute)) return 'attribute: array expected'; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) return 'attribute: string[] expected'; + } + if (message.attributeProto != null && message.hasOwnProperty('attributeProto')) { + if (!Array.isArray(message.attributeProto)) return 'attributeProto: array expected'; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) return 'attributeProto.' + error; + } + } + if (message.node != null && message.hasOwnProperty('node')) { + if (!Array.isArray(message.node)) return 'node: array expected'; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) return 'node.' + error; + } + } + if (message.docString != null && message.hasOwnProperty('docString')) + if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.opsetImport != null && message.hasOwnProperty('opsetImport')) { + if (!Array.isArray(message.opsetImport)) return 'opsetImport: array expected'; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) return 'opsetImport.' + error; + } + } + if (message.domain != null && message.hasOwnProperty('domain')) + if (!$util.isString(message.domain)) return 'domain: string expected'; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) throw TypeError('.onnx.FunctionProto.input: array expected'); + message.input = []; + for (var i = 0; i < object.input.length; ++i) message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) throw TypeError('.onnx.FunctionProto.output: array expected'); + message.output = []; + for (var i = 0; i < object.output.length; ++i) message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) throw TypeError('.onnx.FunctionProto.attribute: array expected'); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError('.onnx.FunctionProto.attributeProto: array expected'); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== 'object') + throw TypeError('.onnx.FunctionProto.attributeProto: object expected'); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) throw TypeError('.onnx.FunctionProto.node: array expected'); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== 'object') throw TypeError('.onnx.FunctionProto.node: object expected'); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.docString != null) message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) throw TypeError('.onnx.FunctionProto.opsetImport: array expected'); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== 'object') + throw TypeError('.onnx.FunctionProto.opsetImport: object expected'); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ''; + object.docString = ''; + object.domain = ''; + } + if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) object.attribute[j] = message.attribute[j]; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. + * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = 'type.googleapis.com'; + } + return typeUrlPrefix + '/onnx.FunctionProto'; + }; + + return FunctionProto; + })(); - return onnx; + return onnx; })(); module.exports = $root; diff --git a/js/web/lib/onnxjs/session-handler-inference.ts b/js/web/lib/onnxjs/session-handler-inference.ts index 47e50aeab673a..c1c2576971840 100644 --- a/js/web/lib/onnxjs/session-handler-inference.ts +++ b/js/web/lib/onnxjs/session-handler-inference.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; +import { InferenceSession, InferenceSessionHandler, SessionHandler, Tensor } from 'onnxruntime-common'; -import {Session} from './session'; -import {Tensor as OnnxjsTensor} from './tensor'; +import { Session } from './session'; +import { Tensor as OnnxjsTensor } from './tensor'; export class OnnxjsSessionHandler implements InferenceSessionHandler { constructor(private session: Session) { @@ -16,17 +16,24 @@ export class OnnxjsSessionHandler implements InferenceSessionHandler { inputNames: readonly string[]; outputNames: readonly string[]; async run( - feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, - _options: InferenceSession.RunOptions): Promise { + feeds: SessionHandler.FeedsType, + _fetches: SessionHandler.FetchesType, + _options: InferenceSession.RunOptions, + ): Promise { const inputMap = new Map(); for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { const feed = feeds[name]; inputMap.set( - name, - new OnnxjsTensor( - feed.dims, feed.type as OnnxjsTensor.DataType, undefined, undefined, - feed.data as OnnxjsTensor.NumberType)); + name, + new OnnxjsTensor( + feed.dims, + feed.type as OnnxjsTensor.DataType, + undefined, + undefined, + feed.data as OnnxjsTensor.NumberType, + ), + ); } } const outputMap = await this.session.run(inputMap); diff --git a/js/web/lib/onnxjs/session.ts b/js/web/lib/onnxjs/session.ts index 73e656f3b04b5..26243ed9fe509 100644 --- a/js/web/lib/onnxjs/session.ts +++ b/js/web/lib/onnxjs/session.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {resolveBackend, SessionHandlerType} from './backend'; -import {ExecutionPlan} from './execution-plan'; -import {Graph} from './graph'; -import {Profiler} from './instrument'; -import {Model} from './model'; -import {Operator} from './operators'; -import {Tensor} from './tensor'; +import { resolveBackend, SessionHandlerType } from './backend'; +import { ExecutionPlan } from './execution-plan'; +import { Graph } from './graph'; +import { Profiler } from './instrument'; +import { Model } from './model'; +import { Operator } from './operators'; +import { Tensor } from './tensor'; export declare namespace Session { export interface Config { @@ -27,7 +27,7 @@ export class Session { this._initialized = false; this.backendHint = config.backendHint; this.profiler = Profiler.create(config.profiler); - this.context = {profiler: this.profiler, graphInputTypes: [], graphInputDims: []}; + this.context = { profiler: this.profiler, graphInputTypes: [], graphInputDims: [] }; } get inputNames(): readonly string[] { @@ -48,7 +48,7 @@ export class Session { async loadModel(uri: string): Promise; async loadModel(buffer: ArrayBuffer, byteOffset?: number, length?: number): Promise; async loadModel(buffer: Uint8Array): Promise; - async loadModel(arg: string|ArrayBuffer|Uint8Array, byteOffset?: number, length?: number): Promise { + async loadModel(arg: string | ArrayBuffer | Uint8Array, byteOffset?: number, length?: number): Promise { await this.profiler.event('session', 'Session.loadModel', async () => { // resolve backend and session handler const backend = await resolveBackend(this.backendHint); @@ -59,7 +59,7 @@ export class Session { const isOrtFormat = arg.endsWith('.ort'); if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node - const {readFile} = require('node:fs/promises'); + const { 
readFile } = require('node:fs/promises'); const buf = await readFile(arg); this.initialize(buf, isOrtFormat); } else { @@ -86,8 +86,9 @@ export class Session { this.profiler.event('session', 'Session.initialize', () => { // load graph - const graphInitializer = - this.sessionHandler.transformGraph ? this.sessionHandler as Graph.Initializer : undefined; + const graphInitializer = this.sessionHandler.transformGraph + ? (this.sessionHandler as Graph.Initializer) + : undefined; this._model.load(modelProtoBlob, graphInitializer, isOrtFormat); // graph is completely initialzied at this stage , let the interested handlers know @@ -104,7 +105,7 @@ export class Session { this._initialized = true; } - async run(inputs: Map|Tensor[]): Promise> { + async run(inputs: Map | Tensor[]): Promise> { if (!this._initialized) { throw new Error('session not initialized yet'); } @@ -118,7 +119,7 @@ export class Session { }); } - private normalizeAndValidateInputs(inputs: Map|Tensor[]): Tensor[] { + private normalizeAndValidateInputs(inputs: Map | Tensor[]): Tensor[] { const modelInputNames = this._model.graph.getInputNames(); // normalize inputs @@ -150,8 +151,12 @@ export class Session { // validate dims requirements // First session run - graph input data is not cached for the session - if (!this.context.graphInputTypes || this.context.graphInputTypes.length === 0 || !this.context.graphInputDims || - this.context.graphInputDims.length === 0) { + if ( + !this.context.graphInputTypes || + this.context.graphInputTypes.length === 0 || + !this.context.graphInputDims || + this.context.graphInputDims.length === 0 + ) { const modelInputIndices = this._model.graph.getInputIndices(); const modelValues = this._model.graph.getValues(); @@ -192,19 +197,28 @@ export class Session { } private validateInputTensorDims( - graphInputDims: Array, givenInputs: Tensor[], noneDimSupported: boolean) { + graphInputDims: Array, + givenInputs: Tensor[], + noneDimSupported: boolean, + ) { for (let i = 0; i < givenInputs.length; i++) { const expectedDims = graphInputDims[i]; const actualDims = givenInputs[i].dims; if (!this.compareTensorDims(expectedDims, actualDims, noneDimSupported)) { - throw new Error(`input tensor[${i}] check failed: expected shape '[${expectedDims.join(',')}]' but got [${ - actualDims.join(',')}]`); + throw new Error( + `input tensor[${i}] check failed: expected shape '[${expectedDims.join(',')}]' but got [${actualDims.join( + ',', + )}]`, + ); } } } - private compareTensorDims(expectedDims: readonly number[], actualDims: readonly number[], noneDimSupported: boolean): - boolean { + private compareTensorDims( + expectedDims: readonly number[], + actualDims: readonly number[], + noneDimSupported: boolean, + ): boolean { if (expectedDims.length !== actualDims.length) { return false; } diff --git a/js/web/lib/onnxjs/tensor.ts b/js/web/lib/onnxjs/tensor.ts index 1a4c1dfe7494d..6e9ecf8006d4d 100644 --- a/js/web/lib/onnxjs/tensor.ts +++ b/js/web/lib/onnxjs/tensor.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Guid} from 'guid-typescript'; +import { Guid } from 'guid-typescript'; import Long from 'long'; -import {onnxruntime} from './ort-schema/flatbuffers/ort-generated'; -import {onnx} from './ort-schema/protobuf/onnx'; -import {decodeUtf8String, ProtoUtil, ShapeUtil} from './util'; +import { onnxruntime } from './ort-schema/flatbuffers/ort-generated'; +import { onnx } from './ort-schema/protobuf/onnx'; +import { decodeUtf8String, ProtoUtil, ShapeUtil } from './util'; import ortFbs = onnxruntime.experimental.fbs; @@ -29,10 +29,15 @@ export declare namespace Tensor { export type StringType = Tensor.DataTypeMap['string']; export type BooleanType = Tensor.DataTypeMap['bool']; - export type IntegerType = Tensor.DataTypeMap['int8']|Tensor.DataTypeMap['uint8']|Tensor.DataTypeMap['int16']| - Tensor.DataTypeMap['uint16']|Tensor.DataTypeMap['int32']|Tensor.DataTypeMap['uint32']; - export type FloatType = Tensor.DataTypeMap['float32']|Tensor.DataTypeMap['float64']; - export type NumberType = BooleanType|IntegerType|FloatType; + export type IntegerType = + | Tensor.DataTypeMap['int8'] + | Tensor.DataTypeMap['uint8'] + | Tensor.DataTypeMap['int16'] + | Tensor.DataTypeMap['uint16'] + | Tensor.DataTypeMap['int32'] + | Tensor.DataTypeMap['uint32']; + export type FloatType = Tensor.DataTypeMap['float32'] | Tensor.DataTypeMap['float64']; + export type NumberType = BooleanType | IntegerType | FloatType; export type Id = Guid; } @@ -154,31 +159,34 @@ export class Tensor { } constructor( - /** - * get the dimensions of the tensor - */ - public readonly dims: readonly number[], - /** - * get the type of the tensor - */ - public readonly type: Tensor.DataType, private dataProvider?: DataProvider, - private asyncDataProvider?: AsyncDataProvider, private cache?: TensorData, - /** - * get the data ID that used to map to a tensor data - */ - public readonly dataId: Guid = Guid.create()) { + /** + * get the dimensions of the tensor + */ + public readonly dims: readonly number[], + /** + * get the type of the tensor + */ + public readonly type: Tensor.DataType, + private dataProvider?: DataProvider, + private asyncDataProvider?: AsyncDataProvider, + private cache?: TensorData, + /** + * get the data ID that used to map to a tensor data + */ + public readonly dataId: Guid = Guid.create(), + ) { this.size = ShapeUtil.validateDimsAndCalcSize(dims); const size = this.size; - const empty = (dataProvider === undefined && asyncDataProvider === undefined && cache === undefined); + const empty = dataProvider === undefined && asyncDataProvider === undefined && cache === undefined; if (cache !== undefined) { if (cache.length !== size) { - throw new RangeError('Input dims doesn\'t match data length.'); + throw new RangeError("Input dims doesn't match data length."); } } if (type === 'string') { - if (cache !== undefined && (!Array.isArray(cache) || !cache.every(i => typeof i === 'string'))) { + if (cache !== undefined && (!Array.isArray(cache) || !cache.every((i) => typeof i === 'string'))) { throw new TypeError('cache should be a string array'); } @@ -219,16 +227,20 @@ export class Tensor { tensorProto.stringData!.forEach((str, i) => { value.data[i] = decodeUtf8String(str); }); - } else if ( - tensorProto.rawData && typeof tensorProto.rawData.byteLength === 'number' && - tensorProto.rawData.byteLength > 0) { + tensorProto.rawData && + typeof tensorProto.rawData.byteLength === 'number' && + tensorProto.rawData.byteLength > 0 + ) { // NOT considering segment for now (IMPORTANT) // populate value from rawData const dataDest = 
value.data; - const dataSource = - new DataView(tensorProto.rawData.buffer, tensorProto.rawData.byteOffset, tensorProto.rawData.byteLength); + const dataSource = new DataView( + tensorProto.rawData.buffer, + tensorProto.rawData.byteOffset, + tensorProto.rawData.byteLength, + ); const elementSize = sizeofProto(tensorProto.dataType!); const length = tensorProto.rawData.byteLength / elementSize; @@ -245,7 +257,7 @@ export class Tensor { } } else { // populate value from array - let array: Array; + let array: Array; switch (tensorProto.dataType) { case onnx.TensorProto.DataType.FLOAT: array = tensorProto.floatData!; @@ -321,15 +333,20 @@ export class Tensor { for (let i = 0; i < ortTensor.stringDataLength(); i++) { value.data[i] = ortTensor.stringData(i); } - } else if ( - ortTensor.rawDataArray() && typeof ortTensor.rawDataLength() === 'number' && ortTensor.rawDataLength() > 0) { + ortTensor.rawDataArray() && + typeof ortTensor.rawDataLength() === 'number' && + ortTensor.rawDataLength() > 0 + ) { // NOT considering segment for now (IMPORTANT) // populate value from rawData const dataDest = value.data; const dataSource = new DataView( - ortTensor.rawDataArray()!.buffer, ortTensor.rawDataArray()!.byteOffset, ortTensor.rawDataLength()); + ortTensor.rawDataArray()!.buffer, + ortTensor.rawDataArray()!.byteOffset, + ortTensor.rawDataLength(), + ); const elementSize = sizeofProto(ortTensor.dataType()); const length = ortTensor.rawDataLength() / elementSize; @@ -369,7 +386,7 @@ function sizeof(type: Tensor.DataType): number { } } -function sizeofProto(type: onnx.TensorProto.DataType|ortFbs.TensorDataType): number { +function sizeofProto(type: onnx.TensorProto.DataType | ortFbs.TensorDataType): number { switch (type) { case onnx.TensorProto.DataType.UINT8: case onnx.TensorProto.DataType.INT8: @@ -423,15 +440,18 @@ function dataviewConstructor(type: Tensor.DataType) { } // convert a long number to a 32-bit integer (cast-down) -function longToNumber(i: Long, type: onnx.TensorProto.DataType|ortFbs.TensorDataType): number { +function longToNumber(i: Long, type: onnx.TensorProto.DataType | ortFbs.TensorDataType): number { // INT64, UINT32, UINT64 if (type === onnx.TensorProto.DataType.INT64 || type === ortFbs.TensorDataType.INT64) { if (i.greaterThanOrEqual(2147483648) || i.lessThan(-2147483648)) { throw new TypeError('int64 is not supported'); } } else if ( - type === onnx.TensorProto.DataType.UINT32 || type === ortFbs.TensorDataType.UINT32 || - type === onnx.TensorProto.DataType.UINT64 || type === ortFbs.TensorDataType.UINT64) { + type === onnx.TensorProto.DataType.UINT32 || + type === ortFbs.TensorDataType.UINT32 || + type === onnx.TensorProto.DataType.UINT64 || + type === ortFbs.TensorDataType.UINT64 + ) { if (i.greaterThanOrEqual(4294967296) || i.lessThan(0)) { throw new TypeError('uint64 is not supported'); } @@ -443,7 +463,11 @@ function longToNumber(i: Long, type: onnx.TensorProto.DataType|ortFbs.TensorData } // read one value from TensorProto -function readProto(view: DataView, type: onnx.TensorProto.DataType|ortFbs.TensorDataType, byteOffset: number): number { +function readProto( + view: DataView, + type: onnx.TensorProto.DataType | ortFbs.TensorDataType, + byteOffset: number, +): number { switch (type) { case onnx.TensorProto.DataType.BOOL: case onnx.TensorProto.DataType.UINT8: @@ -462,12 +486,16 @@ function readProto(view: DataView, type: onnx.TensorProto.DataType|ortFbs.Tensor return view.getUint32(byteOffset, true); case onnx.TensorProto.DataType.INT64: return longToNumber( - 
Long.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), false), type); + Long.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), false), + type, + ); case onnx.TensorProto.DataType.DOUBLE: return view.getFloat64(byteOffset, true); case onnx.TensorProto.DataType.UINT64: return longToNumber( - Long.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), true), type); + Long.fromBits(view.getUint32(byteOffset, true), view.getUint32(byteOffset + 4, true), true), + type, + ); default: throw new Error(`cannot read from DataView for type ${onnx.TensorProto.DataType[type]}`); } diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index 22c4e4c755f55..e1a6966c7b0a3 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {flatbuffers} from 'flatbuffers'; +import { flatbuffers } from 'flatbuffers'; import Long from 'long'; -import {Graph} from './graph'; -import {onnxruntime} from './ort-schema/flatbuffers/ort-generated'; -import {onnx} from './ort-schema/protobuf/onnx'; -import {Tensor} from './tensor'; +import { Graph } from './graph'; +import { onnxruntime } from './ort-schema/flatbuffers/ort-generated'; +import { onnx } from './ort-schema/protobuf/onnx'; +import { Tensor } from './tensor'; // check the inputs shape before running an OP. // return true when the inputs pass the check @@ -40,10 +40,29 @@ export class ArrayUtil { * @returns Whether these 2 are equal */ static arraysEqual( - n1: readonly number[]|Int8Array|Uint8Array|Int16Array|Uint16Array|Int32Array|Uint32Array|Uint8ClampedArray| - Float32Array|Float64Array, - n2: readonly number[]|Int8Array|Uint8Array|Int16Array|Uint16Array|Int32Array|Uint32Array|Uint8ClampedArray| - Float32Array|Float64Array) { + n1: + | readonly number[] + | Int8Array + | Uint8Array + | Int16Array + | Uint16Array + | Int32Array + | Uint32Array + | Uint8ClampedArray + | Float32Array + | Float64Array, + n2: + | readonly number[] + | Int8Array + | Uint8Array + | Int16Array + | Uint16Array + | Int32Array + | Uint32Array + | Uint8ClampedArray + | Float32Array + | Float64Array, + ) { if (n1.length !== n2.length) { return false; } @@ -63,17 +82,19 @@ export class MatMulUtil { * @param dimsB The shape of tensor B. Should be an array of positive integers * @returns A tuple containing the preprocessed input shapes as required by ONNX specifications */ - static preprocessInputShapes(dimsA: readonly number[], dimsB: readonly number[]): - [readonly number[], readonly number[]] { + static preprocessInputShapes( + dimsA: readonly number[], + dimsB: readonly number[], + ): [readonly number[], readonly number[]] { // If the first argument is 1-D, it is promoted to a matrix by prepending // a 1 to its dimensions. After matrix multiplication the prepended 1 is // removed. - const a = (dimsA.length === 1) ? [1, dimsA[0]] : dimsA; + const a = dimsA.length === 1 ? [1, dimsA[0]] : dimsA; // If the second argument is 1-D, it is promoted to a matrix by appending // a 1 to its dimensions. After matrix multiplication the appended 1 is // removed. - const b = (dimsB.length === 1) ? [dimsB[0], 1] : dimsB; + const b = dimsB.length === 1 ? [dimsB[0], 1] : dimsB; return [a, b]; } @@ -103,8 +124,8 @@ export class MatMulUtil { * @param b The shape of tensor B. 
Should be a tuple of 2 positive integers * @returns The expected shape of the result, or undefined if N/A */ - static calcMatMulShape(a: [number, number], b: [number, number]): [number, number]|undefined { - return (a[1] !== b[0]) ? undefined : [a[0], b[1]]; + static calcMatMulShape(a: [number, number], b: [number, number]): [number, number] | undefined { + return a[1] !== b[0] ? undefined : [a[0], b[1]]; } } @@ -116,7 +137,11 @@ export class BroadcastUtil { * @param isMatMul Whether the operation is MatMul * @returns The expected shape of the result, or undefined if N/A */ - static calcShape(adims: readonly number[], bdims: readonly number[], isMatMul = false): readonly number[]|undefined { + static calcShape( + adims: readonly number[], + bdims: readonly number[], + isMatMul = false, + ): readonly number[] | undefined { const arank = adims.length; const brank = bdims.length; if (arank === 0) { @@ -133,8 +158,10 @@ export class BroadcastUtil { if (arank < 2 || brank < 2) { return undefined; } - const cShapeMatMul = - MatMulUtil.calcMatMulShape([adims[arank - 2], adims[arank - 1]], [bdims[brank - 2], bdims[brank - 1]]); + const cShapeMatMul = MatMulUtil.calcMatMulShape( + [adims[arank - 2], adims[arank - 1]], + [bdims[brank - 2], bdims[brank - 1]], + ); if (cShapeMatMul === undefined) { return undefined; } @@ -195,8 +222,12 @@ export class BroadcastUtil { * @returns The result tensor, or undefined if input not broadcastable. */ static calc( - a: Tensor, b: Tensor, op: (a: string|number, b: string|number) => (string | number), inplace: boolean, - resultType?: Tensor.DataType): Tensor|undefined { + a: Tensor, + b: Tensor, + op: (a: string | number, b: string | number) => string | number, + inplace: boolean, + resultType?: Tensor.DataType, + ): Tensor | undefined { const outputShape = BroadcastUtil.calcShape(a.dims, b.dims); if (outputShape) { @@ -218,8 +249,8 @@ export class BroadcastUtil { const outputIndices = new Array(outputShape.length); const originalIndicesA = new Array(a.dims.length); const originalIndicesB = new Array(b.dims.length); - let valA: string|number = 0; - let valB: string|number = 0; + let valA: string | number = 0; + let valB: string | number = 0; let isAScalar = false; let isBScalar = false; if (a.dims.length === 0) { @@ -304,8 +335,12 @@ export class BroadcastUtil { // copy array helper // mimics memcpy as much as possible export function arrayCopyHelper( - target: number[]|Tensor.NumberType, source: number[]|Tensor.NumberType, targetIndex: number, sourceIndex: number, - blockSize: number) { + target: number[] | Tensor.NumberType, + source: number[] | Tensor.NumberType, + targetIndex: number, + sourceIndex: number, + blockSize: number, +) { if (sourceIndex < 0 || sourceIndex >= source.length) { throw new Error('sourceIndex out of bounds'); } @@ -329,8 +364,12 @@ export class GemmUtil { // and return back the shape of the output in the form of a tuple // will throw exception if the input shapes are not compatible static getShapeOfGemmResult( - leftShape: readonly number[], transLeft: boolean, rightShape: readonly number[], transRight: boolean, - biasShape?: readonly number[]): readonly number[] { + leftShape: readonly number[], + transLeft: boolean, + rightShape: readonly number[], + transRight: boolean, + biasShape?: readonly number[], + ): readonly number[] { if (leftShape.length !== 2 || rightShape.length !== 2) { throw new Error('shape need to be of size 2'); } @@ -374,8 +413,9 @@ export class GemmUtil { } export class ProtoUtil { - static 
tensorDataTypeFromProto(typeProto: onnx.TensorProto.DataType| - onnxruntime.experimental.fbs.TensorDataType): Tensor.DataType { + static tensorDataTypeFromProto( + typeProto: onnx.TensorProto.DataType | onnxruntime.experimental.fbs.TensorDataType, + ): Tensor.DataType { switch (typeProto) { case onnx.TensorProto.DataType.INT8: return 'int8'; @@ -442,15 +482,15 @@ export class ProtoUtil { } } - static tensorDimsFromProto(dims: Array): number[] { + static tensorDimsFromProto(dims: Array): number[] { // get rid of Long type for dims - return dims.map(d => Long.isLong(d) ? d.toNumber() : d); + return dims.map((d) => (Long.isLong(d) ? d.toNumber() : d)); } static tensorValueTypeFromProto(valueType: onnx.TypeProto.ITensor): Graph.ValueType { return { tensorType: ProtoUtil.tensorDataTypeFromProto(valueType.elemType!), - shape: {dims: ProtoUtil.tensorDimsFromProto(valueType.shape!.dim!.map(d => d.dimValue!))} + shape: { dims: ProtoUtil.tensorDimsFromProto(valueType.shape!.dim!.map((d) => d.dimValue!)) }, }; } @@ -475,11 +515,11 @@ export class LongUtil { // This function is called to get a number from long type of data for attribute, dim, and ir version, // which values are signed integers. // To make it more generic, add an optional parameter to convert to a unsigned number. - static longToNumber(n: Long|flatbuffers.Long|number, unsigned?: boolean) { + static longToNumber(n: Long | flatbuffers.Long | number, unsigned?: boolean) { if (Long.isLong(n)) { return n.toNumber(); } else if (n instanceof flatbuffers.Long) { - return Long.fromValue({low: n.low, high: n.high, unsigned: unsigned ?? false}).toNumber(); + return Long.fromValue({ low: n.low, high: n.high, unsigned: unsigned ?? false }).toNumber(); } return n; } @@ -516,8 +556,9 @@ export class ShapeUtil { // size cannot be 0 or negative. if (dims[i] <= 0) { throw new Error( - // eslint-disable-next-line max-len - 'cannot get valid size from specified dimension range. Most likely the range contains 0 or negative values in them.'); + // eslint-disable-next-line max-len + 'cannot get valid size from specified dimension range. Most likely the range contains 0 or negative values in them.', + ); } size *= dims[i]; } @@ -583,7 +624,7 @@ export class ShapeUtil { } static normalizeAxes(axes: readonly number[], tensorRank: number): number[] { - return axes.map(x => this.normalizeAxis(x, tensorRank)); + return axes.map((x) => this.normalizeAxis(x, tensorRank)); } // Increment an index into a tensor (in lexicographic @@ -666,15 +707,18 @@ export class ShapeUtil { const oldTensorSize = ShapeUtil.size(originalDims); if (unknownDimension !== -1) { if (oldTensorSize % newTensorSize !== 0) { - throw new Error(`the input tensor cannot be reshaped to the requested shape. Input shape: [${ - originalDims}] Output shape: [${shapeHints}]`); + throw new Error( + `the input tensor cannot be reshaped to the requested shape. 
Input shape: [${ + originalDims + }] Output shape: [${shapeHints}]`, + ); } reshapedDims[unknownDimension] = oldTensorSize / newTensorSize; } // validate sizes from originalDims and reshapedDims match else { if (newTensorSize !== oldTensorSize) { - throw new Error('reshapedDims and originalDims don\'t have matching sizes'); + throw new Error("reshapedDims and originalDims don't have matching sizes"); } } return reshapedDims; @@ -793,10 +837,10 @@ export class ShapeUtil { for (let i = 0; i < axes.length; i++) { const axis = ShapeUtil.normalizeAxis(axes[i], outputDims.length); if (axis >= outputDims.length) { - throw new Error('\'axes\' has an out of range axis'); + throw new Error("'axes' has an out of range axis"); } if (outputDims[axis] !== 0) { - throw new Error('\'axes\' has a duplicate axis'); + throw new Error("'axes' has a duplicate axis"); } outputDims[axis] = 1; @@ -824,8 +868,12 @@ export class ShapeUtil { export class MathUtil { // y = (x*x) + y static sqr( - target: number[]|Tensor.NumberType, source: number[]|Tensor.NumberType, targetIndex: number, sourceIndex: number, - blockSize: number) { + target: number[] | Tensor.NumberType, + source: number[] | Tensor.NumberType, + targetIndex: number, + sourceIndex: number, + blockSize: number, + ) { if (sourceIndex < 0 || sourceIndex >= source.length) { throw new Error('sourceIndex out of bounds'); } @@ -846,8 +894,13 @@ export class MathUtil { // y = ax + y static axpy( - target: number[]|Tensor.NumberType, source: number[]|Tensor.NumberType, targetIndex: number, sourceIndex: number, - blockSize: number, alpha: number) { + target: number[] | Tensor.NumberType, + source: number[] | Tensor.NumberType, + targetIndex: number, + sourceIndex: number, + blockSize: number, + alpha: number, + ) { if (sourceIndex < 0 || sourceIndex >= source.length) { throw new Error('sourceIndex out of bounds'); } @@ -862,14 +915,19 @@ export class MathUtil { } for (let offset = 0; offset < blockSize; offset++) { - target[targetIndex + offset] += (alpha * source[sourceIndex + offset]); + target[targetIndex + offset] += alpha * source[sourceIndex + offset]; } } // y = pow(x, b) static powx( - target: number[]|Tensor.NumberType, source: number[]|Tensor.NumberType, targetIndex: number, sourceIndex: number, - blockSize: number, b: number) { + target: number[] | Tensor.NumberType, + source: number[] | Tensor.NumberType, + targetIndex: number, + sourceIndex: number, + blockSize: number, + b: number, + ) { if (sourceIndex < 0 || sourceIndex >= source.length) { throw new Error('sourceIndex out of bounds'); } @@ -890,8 +948,12 @@ export class MathUtil { // y = x * y static mul( - target: number[]|Tensor.NumberType, source: number[]|Tensor.NumberType, targetIndex: number, sourceIndex: number, - blockSize: number) { + target: number[] | Tensor.NumberType, + source: number[] | Tensor.NumberType, + targetIndex: number, + sourceIndex: number, + blockSize: number, + ) { if (sourceIndex < 0 || sourceIndex >= source.length) { throw new Error('sourceIndex out of bounds'); } @@ -906,7 +968,7 @@ export class MathUtil { } for (let offset = 0; offset < blockSize; offset++) { - target[targetIndex + offset] = (source[sourceIndex + offset] * target[targetIndex + offset]); + target[targetIndex + offset] = source[sourceIndex + offset] * target[targetIndex + offset]; } } } @@ -918,11 +980,15 @@ export class SplitUtil { * @param axis The dimension along which the Tensor will be split * @param splits Offsets for the start of each split */ - static splitShape(dims: readonly number[], axis: 
number, split: number[], numOutputs?: number): - [number[][], number[]] { + static splitShape( + dims: readonly number[], + axis: number, + split: number[], + numOutputs?: number, + ): [number[][], number[]] { if (split.length === 0) { if (!numOutputs) { - throw new Error('need to know number of outputs when the \'split\' attribute is not specified'); + throw new Error("need to know number of outputs when the 'split' attribute is not specified"); } SplitUtil.determineSplit(dims[axis], numOutputs, split); } @@ -962,8 +1028,12 @@ export class ReduceUtil { * @param op2 The operation to be performed between elements in the tensor */ static calcReduce( - a: Tensor, axes: number[], keepdims: boolean, op1: (b: number) => number, - op2: (a: number, b: number) => number): Tensor { + a: Tensor, + axes: number[], + keepdims: boolean, + op1: (b: number) => number, + op2: (a: number, b: number) => number, + ): Tensor { const dims = a.dims.slice(0); // if axes is not set, perform reduce on all axes if (axes.length === 0) { @@ -983,9 +1053,17 @@ export class ReduceUtil { // map index BroadcastUtil.fillIndex(indices, dims, indicesY); y.set( - indices, - ReduceUtil.calcReduceByAxis( - a.numberData, axes, dims, 0, ShapeUtil.indicesToOffset(indicesY, inputStrides), op1, op2)); + indices, + ReduceUtil.calcReduceByAxis( + a.numberData, + axes, + dims, + 0, + ShapeUtil.indicesToOffset(indicesY, inputStrides), + op1, + op2, + ), + ); } if (keepdims) { @@ -993,7 +1071,13 @@ export class ReduceUtil { } else { // keepdims == 0, calculate the expected shape return new Tensor( - ReduceUtil.calcReduceShape(dims, axes, keepdims), y.type, undefined, undefined, y.data, y.dataId); + ReduceUtil.calcReduceShape(dims, axes, keepdims), + y.type, + undefined, + undefined, + y.data, + y.dataId, + ); } } @@ -1009,8 +1093,14 @@ export class ReduceUtil { * @param op2 The operation to be performed between elements in the tensor */ static calcReduceByAxis( - input: Tensor.NumberType, axes: number[], dims: number[], curAxisInd: number, pos: number, - op1: (b: number) => number, op2: (a: number, b: number) => number): number { + input: Tensor.NumberType, + axes: number[], + dims: number[], + curAxisInd: number, + pos: number, + op1: (b: number) => number, + op2: (a: number, b: number) => number, + ): number { let res = 0; if (curAxisInd >= axes.length) { return op1(input[pos]); @@ -1018,8 +1108,10 @@ export class ReduceUtil { const axis = axes[curAxisInd]; const step = axis >= dims.length ? 1 : ShapeUtil.size(dims.slice(axis + 1)); for (let i = 0; i < dims[axis]; i++) { - res = i === 0 ? ReduceUtil.calcReduceByAxis(input, axes, dims, curAxisInd + 1, pos, op1, op2) : - op2(res, ReduceUtil.calcReduceByAxis(input, axes, dims, curAxisInd + 1, pos, op1, op2)); + res = + i === 0 + ? ReduceUtil.calcReduceByAxis(input, axes, dims, curAxisInd + 1, pos, op1, op2) + : op2(res, ReduceUtil.calcReduceByAxis(input, axes, dims, curAxisInd + 1, pos, op1, op2)); pos += step; } return res; @@ -1041,7 +1133,7 @@ export class ReduceUtil { outputDims[axes[i]] = 0; } } - return outputDims.filter(dim => dim !== 0); + return outputDims.filter((dim) => dim !== 0); } } @@ -1056,8 +1148,13 @@ export class PoolConvUtil { * @param pads Padding for the beginning and ending along each axis. 
*/ static adjustPoolAttributes( - isGlobalOperator: boolean, inputDims: readonly number[], kernelShape: number[], strides: number[], - dilations: number[], pads: number[]) { + isGlobalOperator: boolean, + inputDims: readonly number[], + kernelShape: number[], + strides: number[], + dilations: number[], + pads: number[], + ) { if (!isGlobalOperator && kernelShape.length !== inputDims.length - 2) { throw new Error('length of specified kernel shapes should be 2 less than length of input dimensions'); } @@ -1120,8 +1217,13 @@ export class PoolConvUtil { // adjust pad values based on 'autoPad' attribute static adjustPadsBasedOnAutoPad( - inputDims: readonly number[], strides: readonly number[], dilations: readonly number[], - kernelShape: readonly number[], pads: number[], autoPad?: string) { + inputDims: readonly number[], + strides: readonly number[], + dilations: readonly number[], + kernelShape: readonly number[], + pads: number[], + autoPad?: string, + ) { if (!autoPad) { return; } @@ -1130,18 +1232,25 @@ export class PoolConvUtil { throw new Error('length of pads should be twice the length of data dimensions'); } - if (strides.length !== (inputDims.length - 2)) { + if (strides.length !== inputDims.length - 2) { throw new Error('length of strides should be the length of data dimensions'); } - if (kernelShape.length !== (inputDims.length - 2)) { + if (kernelShape.length !== inputDims.length - 2) { throw new Error('length of kernel shapes should be the length of data dimensions'); } for (let dim = 0; dim < inputDims.length - 2; dim++) { PoolConvUtil.adjustPadAndReturnShape( - inputDims[dim + 2], strides[dim], dilations[dim], kernelShape[dim], pads, dim, dim + inputDims.length - 2, - autoPad); + inputDims[dim + 2], + strides[dim], + dilations[dim], + kernelShape[dim], + pads, + dim, + dim + inputDims.length - 2, + autoPad, + ); } } @@ -1157,8 +1266,14 @@ export class PoolConvUtil { * dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID. */ static computePoolOutputShape( - isGlobalOperator: boolean, inputDims: readonly number[], strides: number[], dilations: number[], - kernelShape: number[], pads: number[], autoPad?: string): number[] { + isGlobalOperator: boolean, + inputDims: readonly number[], + strides: number[], + dilations: number[], + kernelShape: number[], + pads: number[], + autoPad?: string, + ): number[] { if (inputDims.length <= 0) { throw new Error('input shape must be of size greater than 0'); } @@ -1167,7 +1282,15 @@ export class PoolConvUtil { const outputDims = [inputDims[0], inputDims[1]]; PoolConvUtil.computeShapeHelper( - isGlobalOperator, inputDims, outputDims, strides, dilations, kernelShape, pads, autoPad); + isGlobalOperator, + inputDims, + outputDims, + strides, + dilations, + kernelShape, + pads, + autoPad, + ); return outputDims; } @@ -1182,8 +1305,14 @@ export class PoolConvUtil { * dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID. 
*/ static computeConvOutputShape( - inputDims: readonly number[], filterDims: readonly number[], strides: number[], dilations: number[], - kernelShape: number[], pads: number[], autoPad?: string): number[] { + inputDims: readonly number[], + filterDims: readonly number[], + strides: number[], + dilations: number[], + kernelShape: number[], + pads: number[], + autoPad?: string, + ): number[] { if (inputDims.length <= 0 || filterDims.length <= 0) { throw new Error('invalid input tensor dims or invalid filter tensor dims'); } @@ -1199,17 +1328,33 @@ export class PoolConvUtil { // called by computePoolOutputShape() and computeConvOutputShape() // adjust pads based on 'autoPad' attribute prior to shape computation private static computeShapeHelper( - isGlobalOperator: boolean, inputDims: readonly number[], outputDims: number[], strides: readonly number[], - dilations: readonly number[], kernelShape: readonly number[], pads: number[], autoPad?: string) { + isGlobalOperator: boolean, + inputDims: readonly number[], + outputDims: number[], + strides: readonly number[], + dilations: readonly number[], + kernelShape: readonly number[], + pads: number[], + autoPad?: string, + ) { if (isGlobalOperator) { for (let dim = 0; dim < inputDims.length - 2; dim++) { outputDims.push(1); } } else { for (let dim = 0; dim < inputDims.length - 2; dim++) { - outputDims.push(PoolConvUtil.adjustPadAndReturnShape( - inputDims[dim + 2], strides[dim], dilations[dim], kernelShape[dim], pads, dim, dim + inputDims.length - 2, - autoPad)); + outputDims.push( + PoolConvUtil.adjustPadAndReturnShape( + inputDims[dim + 2], + strides[dim], + dilations[dim], + kernelShape[dim], + pads, + dim, + dim + inputDims.length - 2, + autoPad, + ), + ); } } } @@ -1217,15 +1362,22 @@ export class PoolConvUtil { // helper for computeShapeHelper() and adjustPadsBasedOnAutoPad() // adjusts pad value for given 'autoPad' string and computes output shape along a particular dimension private static adjustPadAndReturnShape( - inSize: number, stride: number, dilation: number, kernel: number, pads: number[], padHeadIndex: number, - padTailIndex: number, autoPad?: string): number { + inSize: number, + stride: number, + dilation: number, + kernel: number, + pads: number[], + padHeadIndex: number, + padTailIndex: number, + autoPad?: string, + ): number { const dkernel = dilation * (kernel - 1) + 1; if (autoPad && autoPad !== 'NOTSET') { switch (autoPad) { case 'VALID': pads[padHeadIndex] = 0; pads[padTailIndex] = 0; - return Math.floor(((inSize - dkernel) / stride) + 1); + return Math.floor((inSize - dkernel) / stride + 1); case 'SAME_LOWER': case 'SAME_UPPER': if (dilation !== 1) { @@ -1233,22 +1385,21 @@ export class PoolConvUtil { } else { const legacyTargetSize = (inSize + stride - 1) / stride; const padNeeded = (legacyTargetSize - 1) * stride + kernel - inSize; - pads[padHeadIndex] = - (autoPad === 'SAME_LOWER') ? Math.floor((padNeeded + 1) / 2) : Math.floor(padNeeded / 2); + pads[padHeadIndex] = autoPad === 'SAME_LOWER' ? 
Math.floor((padNeeded + 1) / 2) : Math.floor(padNeeded / 2); pads[padTailIndex] = padNeeded - pads[padHeadIndex]; - return Math.floor(((inSize + padNeeded - kernel) / stride) + 1); + return Math.floor((inSize + padNeeded - kernel) / stride + 1); } default: throw new Error('Unsupported AutoPad type'); } } else { - return Math.floor(((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride) + 1); + return Math.floor((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride + 1); } } } -export const MIN_CLIP = -3.4028234663852886e+38; -export const MAX_CLIP = 3.4028234663852886e+38; +export const MIN_CLIP = -3.4028234663852886e38; +export const MAX_CLIP = 3.4028234663852886e38; export function decodeUtf8String(buffer: Uint8Array): string { return new TextDecoder().decode(buffer); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index c701cf3a6df85..78147ffc09ab7 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -1,16 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; - -import {DataType, tensorDataTypeEnumToString} from '../wasm-common'; - -import {configureLogger, LOG_DEBUG} from './log'; -import {createView, TensorView} from './tensor-view'; -import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; -import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; -import {ProgramManager} from './webgpu/program-manager'; -import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; +import { Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END } from 'onnxruntime-common'; + +import { DataType, tensorDataTypeEnumToString } from '../wasm-common'; + +import { configureLogger, LOG_DEBUG } from './log'; +import { createView, TensorView } from './tensor-view'; +import { createGpuDataManager, downloadGpuData, GpuDataManager } from './webgpu/gpu-data-manager'; +import { RunFunction, WEBGPU_OP_RESOLVE_RULES } from './webgpu/op-resolve-rules'; +import { ProgramManager } from './webgpu/program-manager'; +import { + AdapterInfo, + ComputeContext, + GpuArchitecture, + GpuData, + GpuVendor, + ProgramInfo, + ProgramInputTensorInfoDependency, + SessionState, + TimestampQuery, +} from './webgpu/types'; interface CommandInfo { readonly kernelId: number; @@ -23,7 +33,7 @@ interface KernelInfo { readonly kernelType: string; readonly kernelName: string; readonly kernelEntry: RunFunction; - readonly attributes: [((attribute: unknown) => unknown)|undefined, unknown]; + readonly attributes: [((attribute: unknown) => unknown) | undefined, unknown]; } interface PendingKernelInfo { @@ -33,42 +43,47 @@ interface PendingKernelInfo { readonly outputTensorViews: readonly TensorView[]; } -const getProgramInputTensorInfoDependencyKey = - (inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => { - if (inputDependencies.length !== inputTensors.length) { - throw new Error(`inputDependencies length ${inputDependencies.length} is not equal to inputTensors length ${ - inputTensors.length}.`); - } +const getProgramInputTensorInfoDependencyKey = ( + inputTensors: readonly TensorView[], + inputDependencies: readonly ProgramInputTensorInfoDependency[], 
+): string => { + if (inputDependencies.length !== inputTensors.length) { + throw new Error( + `inputDependencies length ${inputDependencies.length} is not equal to inputTensors length ${ + inputTensors.length + }.`, + ); + } - const inputInfos: string[] = []; - for (let i = 0; i < inputTensors.length; ++i) { - const type = inputTensors[i].dataType; - switch (inputDependencies[i]) { - case 'none': { - inputInfos.push(''); - break; - } - case 'type': { - inputInfos.push(`${type}`); - break; - } - case 'rank': { - const rank = inputTensors[i].dims.length; - inputInfos.push(`${type};${rank}`); - break; - } - case 'dims': { - const dims = inputTensors[i].dims.join(','); - inputInfos.push(`${type};${dims}`); - break; - } - default: - throw new Error(`unsupported input dependency: ${inputDependencies[i]}`); - } + const inputInfos: string[] = []; + for (let i = 0; i < inputTensors.length; ++i) { + const type = inputTensors[i].dataType; + switch (inputDependencies[i]) { + case 'none': { + inputInfos.push(''); + break; + } + case 'type': { + inputInfos.push(`${type}`); + break; + } + case 'rank': { + const rank = inputTensors[i].dims.length; + inputInfos.push(`${type};${rank}`); + break; } + case 'dims': { + const dims = inputTensors[i].dims.join(','); + inputInfos.push(`${type};${dims}`); + break; + } + default: + throw new Error(`unsupported input dependency: ${inputDependencies[i]}`); + } + } - return inputInfos.join('|'); - }; + return inputInfos.join('|'); +}; /** * get a unique key representing the program from the program info, input shapes and types. @@ -77,22 +92,27 @@ const getProgramInputTensorInfoDependencyKey = * program. if the key is the same, the program shader source should be the same, so we can reuse the program. * */ -const getProgramInfoUniqueKey = - (programInfo: ProgramInfo, inputTensors: readonly TensorView[], is1DimensionDispatch: boolean): string => { - // final key format: - // []:is1DimensionDispatch:||... - let key = programInfo.name; - if (programInfo.shaderCache?.hint) { - key += '[' + programInfo.shaderCache.hint + ']'; - } - key += ':' + is1DimensionDispatch + - `:${ - getProgramInputTensorInfoDependencyKey( - inputTensors, - programInfo.shaderCache?.inputDependencies ?? - new Array(inputTensors.length).fill('dims'))}`; - return key; - }; +const getProgramInfoUniqueKey = ( + programInfo: ProgramInfo, + inputTensors: readonly TensorView[], + is1DimensionDispatch: boolean, +): string => { + // final key format: + // []:is1DimensionDispatch:||... + let key = programInfo.name; + if (programInfo.shaderCache?.hint) { + key += '[' + programInfo.shaderCache.hint + ']'; + } + key += + ':' + + is1DimensionDispatch + + `:${getProgramInputTensorInfoDependencyKey( + inputTensors, + programInfo.shaderCache?.inputDependencies ?? + new Array(inputTensors.length).fill('dims'), + )}`; + return key; +}; class AdapterInfoImpl implements AdapterInfo { readonly architecture?: string; @@ -136,14 +156,14 @@ export class WebGpuBackend { * `null` means no session is being run. * only valid when session.run is executed. */ - currentSessionId: number|null = null; + currentSessionId: number | null = null; /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. * only one kernel can be computed at a moment. */ - currentKernelId: number|null = null; + currentKernelId: number | null = null; /** * a list of temporary GPU data for the current kernel. should release when the kernel done computation. 
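// Editor's note: a standalone sketch (not the backend code itself) of how the program cache key
// described above is assembled: "<name>[hint]:<is1DimensionDispatch>:<info>|<info>|...", where each
// <info> encodes only as much of an input tensor as its declared dependency requires. All names
// below are hypothetical illustrations.
type InputDependency = 'none' | 'type' | 'rank' | 'dims';

interface SketchTensorInfo {
  dataType: number;
  dims: readonly number[];
}

const sketchProgramKey = (
  name: string,
  hint: string | undefined,
  is1DimensionDispatch: boolean,
  inputs: readonly SketchTensorInfo[],
  deps?: readonly InputDependency[],
): string => {
  // when no explicit dependencies are given, fall back to full 'dims' for every input
  const dependencies = deps ?? new Array<InputDependency>(inputs.length).fill('dims');
  const infos = inputs.map((input, i) => {
    switch (dependencies[i]) {
      case 'none':
        return '';
      case 'type':
        return `${input.dataType}`;
      case 'rank':
        return `${input.dataType};${input.dims.length}`;
      case 'dims':
        return `${input.dataType};${input.dims.join(',')}`;
      default:
        throw new Error(`unsupported input dependency: ${dependencies[i]}`);
    }
  });
  return `${name}${hint ? `[${hint}]` : ''}:${is1DimensionDispatch}:${infos.join('|')}`;
};

// e.g. sketchProgramKey('MatMul', 'vec4', true, [{ dataType: 1, dims: [2, 4] }])
//      === 'MatMul[vec4]:true:1;2,4'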
*/ @@ -155,11 +175,11 @@ export class WebGpuBackend { /** * a KernelID -> a custom data, which stores custom data owned by the specific kernel. */ - private kernelCustomData: Map; + private kernelCustomData: Map; /** * get the custom data of the current kernel */ - get currentKernelCustomData(): {[key: string]: unknown} { + get currentKernelCustomData(): { [key: string]: unknown } { if (this.currentKernelId === null) { throw new Error('currentKernelCustomData(): currentKernelId is null. (should not happen)'); } @@ -175,8 +195,8 @@ export class WebGpuBackend { // KernelID -> kernelInfo mapping kernels: Map; - private commandEncoder: GPUCommandEncoder|null = null; - private computePassEncoder: GPUComputePassEncoder|null = null; + private commandEncoder: GPUCommandEncoder | null = null; + private computePassEncoder: GPUComputePassEncoder | null = null; maxDispatchNumber = 16; pendingDispatchNumber = 0; @@ -233,7 +253,7 @@ export class WebGpuBackend { } this.device = await adapter.requestDevice(deviceDescriptor); - this.adapterInfo = new AdapterInfoImpl(adapter.info || await adapter.requestAdapterInfo()); + this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo())); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); this.kernels = new Map(); @@ -245,17 +265,25 @@ export class WebGpuBackend { // TODO: set up flags - this.device.onuncapturederror = ev => { + this.device.onuncapturederror = (ev) => { if (ev.error instanceof GPUValidationError) { // eslint-disable-next-line no-console console.error(`An uncaught WebGPU validation error was raised: ${ev.error.message}`); } }; - Object.defineProperty( - this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false}); - Object.defineProperty( - this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false}); + Object.defineProperty(this.env.webgpu, 'device', { + value: this.device, + writable: false, + enumerable: true, + configurable: false, + }); + Object.defineProperty(this.env.webgpu, 'adapter', { + value: adapter, + writable: false, + enumerable: true, + configurable: false, + }); // init queryType, which is necessary for InferenceSession.create this.setQueryType(); @@ -311,16 +339,27 @@ export class WebGpuBackend { let queryReadBuffer: GPUBuffer; if (this.queryType !== 'none') { this.commandEncoder.resolveQuerySet( - this.querySet!, 0, this.pendingDispatchNumber * 2, this.queryResolveBuffer!, 0); + this.querySet!, + 0, + this.pendingDispatchNumber * 2, + this.queryResolveBuffer!, + 0, + ); queryReadBuffer = this.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: this.pendingDispatchNumber * 2 * 8, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST}); + // eslint-disable-next-line no-bitwise + { size: this.pendingDispatchNumber * 2 * 8, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }, + ); this.pendingQueries.set(queryReadBuffer, this.pendingKernels); this.pendingKernels = []; this.commandEncoder.copyBufferToBuffer( - this.queryResolveBuffer!, 0, queryReadBuffer, 0, this.pendingDispatchNumber * 2 * 8); + this.queryResolveBuffer!, + 0, + queryReadBuffer, + 0, + this.pendingDispatchNumber * 2 * 8, + ); } this.device.queue.submit([this.commandEncoder.finish()]); @@ -358,10 +397,14 @@ export class WebGpuBackend { if (this.env.webgpu.profiling?.ondata) { this.env.webgpu.profiling.ondata({ version: 1, - inputsMetadata: inputTensorViews.map( - value => ({dims: 
value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), - outputsMetadata: outputTensorViews.map( - value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})), + inputsMetadata: inputTensorViews.map((value) => ({ + dims: value.dims, + dataType: tensorDataTypeEnumToString(value.dataType), + })), + outputsMetadata: outputTensorViews.map((value) => ({ + dims: value.dims, + dataType: tensorDataTypeEnumToString(value.dataType), + })), kernelId, kernelType, kernelName, @@ -380,8 +423,11 @@ export class WebGpuBackend { outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; }); // eslint-disable-next-line no-console - console.log(`[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${ - outputShapes}execution time: ${endTime - startTime} ns`); + console.log( + `[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${ + outputShapes + }execution time: ${endTime - startTime} ns`, + ); } TRACE('GPU', `${programName}::${startTimeU64}::${endTimeU64}`); } @@ -403,10 +449,14 @@ export class WebGpuBackend { * or persistent (owned by the current kernel) * @returns a TensorView array representing the result. */ - run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[], - createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView, - createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView, - outputCount: number): TensorView[] { + run( + program: ProgramInfo, + inputTensorViews: readonly TensorView[], + outputIndices: readonly number[], + createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView, + createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView, + outputCount: number, + ): TensorView[] { TRACE_FUNC_BEGIN(program.name); // create info for inputs const inputDatas: GpuData[] = []; @@ -423,7 +473,7 @@ export class WebGpuBackend { inputDatas.push(gpuData); } - const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); + const { outputs, dispatchGroup, programUniforms } = program.getRunData(inputTensorViews); // check output indices const validatedOutputIndices = outputIndices.length === 0 ? outputs.map((_, i) => i) : outputIndices; @@ -438,8 +488,11 @@ export class WebGpuBackend { // value -1 and -2 are used for creating temporary and persistent outputs. // value -3 is used for placeholder output. So -3, -2, -1 and 0, 1, 2, ... are valid // output indices. see type definition of ComputeContextInputsOutputsMapping for more details. - if (!Number.isInteger(validatedOutputIndices[i]) || validatedOutputIndices[i] < -3 || - validatedOutputIndices[i] >= outputCount) { + if ( + !Number.isInteger(validatedOutputIndices[i]) || + validatedOutputIndices[i] < -3 || + validatedOutputIndices[i] >= outputCount + ) { throw new Error(`Invalid output index: ${validatedOutputIndices[i]}`); } if (validatedOutputIndices[i] === -3) { @@ -447,9 +500,10 @@ export class WebGpuBackend { } const isTemporary = validatedOutputIndices[i] === -1; const isPersistent = validatedOutputIndices[i] === -2; - const tensorView = (isTemporary || isPersistent) ? - createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : - createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + const tensorView = + isTemporary || isPersistent + ? 
createIntermediateOutput(outputs[i].dataType, outputs[i].dims) + : createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); outputTensorViews.push(tensorView); // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. if (tensorView.data === 0) { @@ -486,18 +540,19 @@ export class WebGpuBackend { // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors. // If we see such use case, we need to make a change here to support it. throw new Error( - `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`); + `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`, + ); } // load uniforms // TODO: add cache for uniform (is it necessary?) // - let uniformBufferBinding: GPUBindingResource|undefined; + let uniformBufferBinding: GPUBindingResource | undefined; if (programUniforms) { let currentOffset = 0; const offsets: number[] = []; - programUniforms.forEach(v => { + programUniforms.forEach((v) => { const data = typeof v.data === 'number' ? [v.data] : v.data; if (data.length === 0) { return; @@ -507,7 +562,7 @@ export class WebGpuBackend { let sizeOfVecOrMat; let baseAlignment; if (v.type === DataType.float16) { - baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); + baseAlignment = data.length > 4 ? 16 : data.length > 2 ? 8 : data.length * sizeOfElement; sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; } else { baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; @@ -521,8 +576,8 @@ export class WebGpuBackend { // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte // length is N * SizeOf(mat2x4). const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4; - currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : - data.length * sizeOfElement; + currentOffset += + data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. 
For simplicity, set @@ -548,11 +603,11 @@ export class WebGpuBackend { }); const uniformBufferData = - // eslint-disable-next-line no-bitwise - this.gpuDataManager.create(currentOffset, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM); + // eslint-disable-next-line no-bitwise + this.gpuDataManager.create(currentOffset, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM); this.device.queue.writeBuffer(uniformBufferData.buffer, 0, arrayBuffer, 0, currentOffset); this.gpuDataManager.release(uniformBufferData.id); - uniformBufferBinding = {offset: 0, size: currentOffset, buffer: uniformBufferData.buffer}; + uniformBufferBinding = { offset: 0, size: currentOffset, buffer: uniformBufferData.buffer }; } const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(dispatchGroup); @@ -569,8 +624,11 @@ export class WebGpuBackend { // validate uniform variables if (programUniforms && artifact.uniformVariablesInfo) { if (programUniforms.length !== artifact.uniformVariablesInfo.length) { - throw new Error(`Uniform variables count mismatch: expect ${artifact.uniformVariablesInfo.length}, got ${ - programUniforms.length} in program "${artifact.programInfo.name}".`); + throw new Error( + `Uniform variables count mismatch: expect ${artifact.uniformVariablesInfo.length}, got ${ + programUniforms.length + } in program "${artifact.programInfo.name}".`, + ); } for (let i = 0; i < programUniforms.length; i++) { const uniform = programUniforms[i]; @@ -578,16 +636,22 @@ export class WebGpuBackend { const actualLength = typeof uniform.data === 'number' ? 1 : uniform.data.length; const [type, length] = artifact.uniformVariablesInfo[i]; if (actualType !== type || actualLength !== length) { - throw new Error(`Uniform variable ${i} mismatch: expect type ${type} with size ${length}, got type ${ - actualType} with size ${actualLength} in program "${artifact.programInfo.name}".`); + throw new Error( + `Uniform variable ${i} mismatch: expect type ${type} with size ${length}, got type ${ + actualType + } with size ${actualLength} in program "${artifact.programInfo.name}".`, + ); } } } LOG_DEBUG( - 'info', - () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ - normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); + 'info', + () => + `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ + normalizedDispatchGroup[1] + }x${normalizedDispatchGroup[2]}`, + ); if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { const pendingKernelInfo: PendingKernelInfo = { @@ -660,7 +724,7 @@ export class WebGpuBackend { this.kernels.delete(kernelId); } - computeKernel(kernelId: number, context: ComputeContext, errors: Array>): number { + computeKernel(kernelId: number, context: ComputeContext, errors: Array>): number { const kernel = this.kernels.get(kernelId); if (!kernel) { throw new Error(`kernel not created: ${kernelId}`); @@ -691,14 +755,19 @@ export class WebGpuBackend { } kernelEntry(context, attributes[1]); - return 0; // ORT_OK + return 0; // ORT_OK } catch (e) { errors.push(Promise.resolve(`[WebGPU] Kernel "[${kernelType}] ${kernelName}" failed. ${e}`)); - return 1; // ORT_FAIL + return 1; // ORT_FAIL } finally { if (useErrorScope) { - errors.push(this.device.popErrorScope().then( - err => err ? `GPU validation error for kernel "[${kernelType}] ${kernelName}": ${err.message}` : null)); + errors.push( + this.device + .popErrorScope() + .then((err) => + err ? 
`GPU validation error for kernel "[${kernelType}] ${kernelName}": ${err.message}` : null, + ), + ); } for (const data of this.temporaryData) { @@ -725,7 +794,7 @@ export class WebGpuBackend { unregisterBuffers(sessionId: number): void { const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); if (sessionInputOutputMapping) { - sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); + sessionInputOutputMapping.forEach((bufferInfo) => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); this.sessionExternalDataMapping.delete(sessionId); } } @@ -736,8 +805,11 @@ export class WebGpuBackend { } return gpuData.buffer; } - createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes): - () => Promise { + createDownloader( + gpuBuffer: GPUBuffer, + size: number, + type: Tensor.GpuBufferDataTypes, + ): () => Promise { return async () => { const data = await downloadGpuData(this, gpuBuffer, size); return createView(data.buffer, type); @@ -754,8 +826,10 @@ export class WebGpuBackend { } setQueryType(): void { this.queryType = 'none'; - if (this.env.webgpu.profiling?.mode === 'default' || - (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) { + if ( + this.env.webgpu.profiling?.mode === 'default' || + (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace) + ) { if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) { this.queryType = 'inside-passes'; } else if (this.device.features.has('timestamp-query')) { @@ -768,8 +842,9 @@ export class WebGpuBackend { count: this.maxDispatchNumber * 2, }); this.queryResolveBuffer = this.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); + // eslint-disable-next-line no-bitwise + { size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE }, + ); } } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 242f7e939cda0..ab24fa31909be 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -1,31 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
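// Editor's note: a simplified, framework-free sketch of the WGSL uniform packing rule referenced in
// the backend-webgpu.ts hunk above (https://www.w3.org/TR/WGSL/#alignment-and-size): f16 data longer
// than four elements is laid out as array<mat2x4<f16>, N> with N = ceil(length / 8). This is an
// illustration, not the backend implementation; the align-before-write step is an assumption about
// code outside the hunk, and all names here are hypothetical.
const packUniformOffsets = (uniforms: Array<{ isFp16: boolean; data: number | number[] }>): number[] => {
  const offsets: number[] = [];
  let currentOffset = 0;
  for (const u of uniforms) {
    const data = typeof u.data === 'number' ? [u.data] : u.data;
    if (data.length === 0) {
      continue; // zero-length uniforms contribute nothing to the buffer
    }
    const sizeOfElement = u.isFp16 ? 2 : 4;
    let baseAlignment: number;
    let sizeOfVecOrMat: number;
    if (u.isFp16) {
      baseAlignment = data.length > 4 ? 16 : data.length > 2 ? 8 : data.length * sizeOfElement;
      sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
    } else {
      baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
      sizeOfVecOrMat = 16;
    }
    // assumed: round the running offset up to the field's base alignment before recording it
    currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
    offsets.push(currentOffset);
    const elementPerVecOrMat = u.isFp16 ? 8 : 4;
    currentOffset +=
      data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : data.length * sizeOfElement;
  }
  return offsets;
};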
-import {Env} from 'onnxruntime-common'; +import { Env } from 'onnxruntime-common'; -import type {OrtWasmModule} from '../wasm-types'; -import {DataType, getTensorElementSize} from '../wasm-common'; +import type { OrtWasmModule } from '../wasm-types'; +import { DataType, getTensorElementSize } from '../wasm-common'; -import {WebGpuBackend} from './backend-webgpu'; -import {LOG_DEBUG} from './log'; -import {TensorView} from './tensor-view'; -import {ShapeUtil} from './util'; -import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; +import { WebGpuBackend } from './backend-webgpu'; +import { LOG_DEBUG } from './log'; +import { TensorView } from './tensor-view'; +import { ShapeUtil } from './util'; +import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; /* eslint-disable no-bitwise */ class TensorViewImpl implements TensorView { constructor( - private module: OrtWasmModule, public readonly dataType: number, public readonly data: number, - public readonly dims: readonly number[]) {} + private module: OrtWasmModule, + public readonly dataType: number, + public readonly data: number, + public readonly dims: readonly number[], + ) {} getFloat32Array(): Float32Array { if (this.dataType !== DataType.float) { throw new Error('Invalid data type'); } const elementCount = ShapeUtil.size(this.dims); - return elementCount === 0 ? new Float32Array() : - new Float32Array(this.module.HEAP8.buffer, this.data, elementCount); + return elementCount === 0 + ? new Float32Array() + : new Float32Array(this.module.HEAP8.buffer, this.data, elementCount); } getBigInt64Array(): BigInt64Array { @@ -33,8 +37,9 @@ class TensorViewImpl implements TensorView { throw new Error('Invalid data type'); } const elementCount = ShapeUtil.size(this.dims); - return elementCount === 0 ? new BigInt64Array() : - new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount); + return elementCount === 0 + ? 
new BigInt64Array() + : new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount); } getInt32Array(): Int32Array { @@ -58,7 +63,7 @@ class ComputeContextImpl implements ComputeContext { readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; - get kernelCustomData(): {[key: string]: unknown} { + get kernelCustomData(): { [key: string]: unknown } { return this.backend.currentKernelCustomData; } get customDataBuffer(): Uint8Array { @@ -66,12 +71,16 @@ class ComputeContextImpl implements ComputeContext { } private customDataOffset = 0; private customDataSize = 0; - constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { + constructor( + private module: OrtWasmModule, + private backend: WebGpuBackend, + contextDataOffset: number, + ) { this.adapterInfo = backend.adapterInfo; const heapU32 = module.HEAPU32; // extract context data - let dataIndex = (contextDataOffset >>> 2); + let dataIndex = contextDataOffset >>> 2; this.opKernelContext = heapU32[dataIndex++]; const inputCount = heapU32[dataIndex++]; this.outputCount = heapU32[dataIndex++]; @@ -94,8 +103,9 @@ class ComputeContextImpl implements ComputeContext { getMaxComputeWorkgroupSizes(): [number, number, number] { return [ - this.backend.device.limits.maxComputeWorkgroupSizeX, this.backend.device.limits.maxComputeWorkgroupSizeY, - this.backend.device.limits.maxComputeWorkgroupSizeZ + this.backend.device.limits.maxComputeWorkgroupSizeX, + this.backend.device.limits.maxComputeWorkgroupSizeY, + this.backend.device.limits.maxComputeWorkgroupSizeZ, ]; } @@ -106,11 +116,11 @@ class ComputeContextImpl implements ComputeContext { compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] { // prepare inputs. inputs should always be valid data. const mappedInputs = - inputsOutputsMapping?.inputs?.map(i => typeof i === 'number' ? this.inputs[i] : i) ?? this.inputs; + inputsOutputsMapping?.inputs?.map((i) => (typeof i === 'number' ? this.inputs[i] : i)) ?? this.inputs; // prepare outputs. const outputIndices = inputsOutputsMapping?.outputs ?? []; const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView => - new TensorViewImpl(this.module, dataType, this.output(index, dims), dims); + new TensorViewImpl(this.module, dataType, this.output(index, dims), dims); const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => { const elementSize = getTensorElementSize(dataType); if (!elementSize) { @@ -121,7 +131,13 @@ class ComputeContextImpl implements ComputeContext { return new TensorViewImpl(this.module, dataType, gpuDataId, dims); }; return this.backend.run( - program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput, this.outputCount); + program, + mappedInputs, + outputIndices, + createKernelOutput, + createTemporaryOutput, + this.outputCount, + ); } output(index: number, dims: readonly number[]): number { @@ -136,9 +152,10 @@ class ComputeContextImpl implements ComputeContext { return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { throw new Error( - `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + + `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + 'If you are running with pre-allocated output, please make sure the output type/dims are correct. 
' + - `Error: ${e}`); + `Error: ${e}`, + ); } finally { this.module.stackRestore(stack); } @@ -169,8 +186,12 @@ class ComputeContextImpl implements ComputeContext { * @param env - the ORT environment variable (ort.env) * @param gpuAdapter - the pre-created GPU adapter */ -export const init = - async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise => { +export const init = async ( + name: 'webgpu' | 'webnn', + module: OrtWasmModule, + env: Env, + gpuAdapter?: GPUAdapter, +): Promise => { const jsepInit = module.jsepInit; if (!jsepInit) { throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.'); @@ -203,29 +224,31 @@ export const init = }, // jsepCopyAsync(src, dst, size) - async(gpuDataId: number, dataOffset: number, size: number): - Promise => { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); + async (gpuDataId: number, dataOffset: number, size: number): Promise => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, + ); - await backend.download( - gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); - }, + await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); + }, // jsepCreateKernel - (kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel( - kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), + (kernelType: string, kernelId: number, attribute: unknown) => + backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), // jsepRun - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ - contextDataOffset}`); + 'verbose', + () => + `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ + contextDataOffset + }`, + ); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); }, @@ -234,7 +257,7 @@ export const init = // jsepCaptureEnd () => backend.captureEnd(), // jsepReplay - () => backend.replay() + () => backend.replay(), ]); } else { jsepInit('webnn'); diff --git a/js/web/lib/wasm/jsep/log.ts b/js/web/lib/wasm/jsep/log.ts index cb7d828611206..27a0f7b11a2be 100644 --- a/js/web/lib/wasm/jsep/log.ts +++ b/js/web/lib/wasm/jsep/log.ts @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
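// Editor's note: a small sketch of how ComputeContextImpl in the init.ts hunk above reads its header
// out of the WASM heap. The context data is a sequence of u32 words starting at contextDataOffset, so
// the byte offset is shifted right by 2 (>>> 2) to index HEAPU32. Only the three header words shown in
// that hunk are decoded here; the per-input layout that follows is not part of this diff, and the
// names below are hypothetical.
interface ContextHeaderSketch {
  opKernelContext: number;
  inputCount: number;
  outputCount: number;
}

const readContextHeader = (heapU32: Uint32Array, contextDataOffset: number): ContextHeaderSketch => {
  let dataIndex = contextDataOffset >>> 2; // byte offset -> u32 index
  return {
    opKernelContext: heapU32[dataIndex++],
    inputCount: heapU32[dataIndex++],
    outputCount: heapU32[dataIndex++],
  };
};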
-import {Env} from 'onnxruntime-common'; +import { Env } from 'onnxruntime-common'; -import {logLevelStringToEnum} from '../wasm-common'; +import { logLevelStringToEnum } from '../wasm-common'; type LogLevel = NonNullable; type MessageString = string; type MessageFunction = () => string; -type Message = MessageString|MessageFunction; +type Message = MessageString | MessageFunction; const logLevelPrefix = ['V', 'I', 'W', 'E', 'F']; @@ -17,8 +17,8 @@ const doLog = (level: number, message: string): void => { console.log(`[${logLevelPrefix[level]},${new Date().toISOString()}]${message}`); }; -let configLogLevel: LogLevel|undefined; -let debug: boolean|undefined; +let configLogLevel: LogLevel | undefined; +let debug: boolean | undefined; export const configureLogger = ($configLogLevel: LogLevel, $debug: boolean): void => { configLogLevel = $configLogLevel; diff --git a/js/web/lib/wasm/jsep/tensor-view.ts b/js/web/lib/wasm/jsep/tensor-view.ts index 69b9287f6de29..defc418c29264 100644 --- a/js/web/lib/wasm/jsep/tensor-view.ts +++ b/js/web/lib/wasm/jsep/tensor-view.ts @@ -1,13 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from 'onnxruntime-common'; - -import {tensorTypeToTypedArrayConstructor} from '../wasm-common'; - -export const createView = (dataBuffer: ArrayBuffer, type: Tensor.Type): Int32Array|Uint32Array|BigInt64Array| - BigUint64Array|Uint8Array|Float32Array|Float64Array|Int8Array|Int16Array|Uint16Array => - new (tensorTypeToTypedArrayConstructor(type))(dataBuffer); +import { Tensor } from 'onnxruntime-common'; + +import { tensorTypeToTypedArrayConstructor } from '../wasm-common'; + +export const createView = ( + dataBuffer: ArrayBuffer, + type: Tensor.Type, +): + | Int32Array + | Uint32Array + | BigInt64Array + | BigUint64Array + | Uint8Array + | Float32Array + | Float64Array + | Int8Array + | Int16Array + | Uint16Array => new (tensorTypeToTypedArrayConstructor(type))(dataBuffer); /** * a TensorView does not own the data. diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 9a1d5463f7843..5ae16d5625dc8 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -10,12 +10,11 @@ export class MatMulUtil { * @param b The shape of tensor B. Should be a tuple of 2 positive integers * @returns The expected shape of the result, or undefined if N/A */ - static calcMatMulShape(a: [number, number], b: [number, number]): [number, number]|undefined { - return (a[1] !== b[0]) ? undefined : [a[0], b[1]]; + static calcMatMulShape(a: [number, number], b: [number, number]): [number, number] | undefined { + return a[1] !== b[0] ? 
undefined : [a[0], b[1]]; } } - export class BroadcastUtil { /** * Calculate the expected shape when broadcasting 2 tensors @@ -24,7 +23,11 @@ export class BroadcastUtil { * @param isMatMul Whether the operation is MatMul * @returns The expected shape of the result, or undefined if N/A */ - static calcShape(adims: readonly number[], bdims: readonly number[], isMatMul = false): readonly number[]|undefined { + static calcShape( + adims: readonly number[], + bdims: readonly number[], + isMatMul = false, + ): readonly number[] | undefined { const arank = adims.length; const brank = bdims.length; if (arank === 0) { @@ -41,8 +44,10 @@ export class BroadcastUtil { if (arank < 2 || brank < 2) { return undefined; } - const cShapeMatMul = - MatMulUtil.calcMatMulShape([adims[arank - 2], adims[arank - 1]], [bdims[brank - 2], bdims[brank - 1]]); + const cShapeMatMul = MatMulUtil.calcMatMulShape( + [adims[arank - 2], adims[arank - 1]], + [bdims[brank - 2], bdims[brank - 1]], + ); if (cShapeMatMul === undefined) { return undefined; } @@ -92,7 +97,6 @@ export class BroadcastUtil { } } - export class ShapeUtil { /** * calculate the size (number of elements) @@ -159,8 +163,9 @@ export class ShapeUtil { // size cannot be negative. if (dims[i] < 0) { throw new Error( - // eslint-disable-next-line max-len - 'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.'); + // eslint-disable-next-line max-len + 'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.', + ); } size *= dims[i]; } @@ -194,7 +199,7 @@ export class ShapeUtil { } static normalizeAxes(axes: readonly number[], tensorRank?: number): number[] { - return axes.map(x => this.normalizeAxis(x, tensorRank ?? axes.length)); + return axes.map((x) => this.normalizeAxis(x, tensorRank ?? axes.length)); } /** @@ -245,8 +250,13 @@ export class PoolConvUtil { * @param pads Padding for the beginning and ending along each axis. 
*/ static adjustPoolAttributes( - isGlobalOperator: boolean, inputDims: readonly number[], kernelShape: number[], strides: number[], - dilations: number[], pads: number[]): void { + isGlobalOperator: boolean, + inputDims: readonly number[], + kernelShape: number[], + strides: number[], + dilations: number[], + pads: number[], + ): void { if (!isGlobalOperator && kernelShape.length !== inputDims.length - 2) { throw new Error('length of specified kernel shapes should be 2 less than length of input dimensions'); } @@ -309,8 +319,14 @@ export class PoolConvUtil { // adjust pad values based on 'autoPad' attribute static adjustPadsBasedOnAutoPad( - inputDims: readonly number[], strides: readonly number[], dilations: readonly number[], - kernelShape: readonly number[], pads: number[], isChannelLast: boolean, autoPad?: string): void { + inputDims: readonly number[], + strides: readonly number[], + dilations: readonly number[], + kernelShape: readonly number[], + pads: number[], + isChannelLast: boolean, + autoPad?: string, + ): void { if (!autoPad) { return; } @@ -319,18 +335,25 @@ export class PoolConvUtil { throw new Error('length of pads should be twice the length of data dimensions'); } - if (strides.length !== (inputDims.length - 2)) { + if (strides.length !== inputDims.length - 2) { throw new Error('length of strides should be the length of data dimensions'); } - if (kernelShape.length !== (inputDims.length - 2)) { + if (kernelShape.length !== inputDims.length - 2) { throw new Error('length of kernel shapes should be the length of data dimensions'); } for (let dim = 0; dim < inputDims.length - 2; dim++) { PoolConvUtil.adjustPadAndReturnShape( - inputDims[dim + (isChannelLast ? 1 : 2)], strides[dim], dilations[dim], kernelShape[dim], pads, dim, - dim + inputDims.length - 2, autoPad); + inputDims[dim + (isChannelLast ? 1 : 2)], + strides[dim], + dilations[dim], + kernelShape[dim], + pads, + dim, + dim + inputDims.length - 2, + autoPad, + ); } } @@ -346,8 +369,14 @@ export class PoolConvUtil { * dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID. */ static computePoolOutputShape( - isGlobalOperator: boolean, inputDims: readonly number[], strides: number[], dilations: number[], - kernelShape: number[], pads: number[], autoPad?: string): number[] { + isGlobalOperator: boolean, + inputDims: readonly number[], + strides: number[], + dilations: number[], + kernelShape: number[], + pads: number[], + autoPad?: string, + ): number[] { if (inputDims.length <= 0) { throw new Error('input shape must be of size greater than 0'); } @@ -356,7 +385,15 @@ export class PoolConvUtil { const outputDims = [inputDims[0], inputDims[1]]; PoolConvUtil.computeShapeHelper( - isGlobalOperator, inputDims, outputDims, strides, dilations, kernelShape, pads, autoPad); + isGlobalOperator, + inputDims, + outputDims, + strides, + dilations, + kernelShape, + pads, + autoPad, + ); return outputDims; } @@ -371,8 +408,14 @@ export class PoolConvUtil { * dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID. 
*/ static computeConvOutputShape( - inputDims: readonly number[], filterDims: readonly number[], strides: number[], dilations: number[], - kernelShape: number[], pads: number[], autoPad?: string): number[] { + inputDims: readonly number[], + filterDims: readonly number[], + strides: number[], + dilations: number[], + kernelShape: number[], + pads: number[], + autoPad?: string, + ): number[] { if (inputDims.length <= 0 || filterDims.length <= 0) { throw new Error('invalid input tensor dims or invalid filter tensor dims'); } @@ -388,17 +431,33 @@ export class PoolConvUtil { // called by computePoolOutputShape() and computeConvOutputShape() // adjust pads based on 'autoPad' attribute prior to shape computation private static computeShapeHelper( - isGlobalOperator: boolean, inputDims: readonly number[], outputDims: number[], strides: readonly number[], - dilations: readonly number[], kernelShape: readonly number[], pads: number[], autoPad?: string) { + isGlobalOperator: boolean, + inputDims: readonly number[], + outputDims: number[], + strides: readonly number[], + dilations: readonly number[], + kernelShape: readonly number[], + pads: number[], + autoPad?: string, + ) { if (isGlobalOperator) { for (let dim = 0; dim < inputDims.length - 2; dim++) { outputDims.push(1); } } else { for (let dim = 0; dim < inputDims.length - 2; dim++) { - outputDims.push(PoolConvUtil.adjustPadAndReturnShape( - inputDims[dim + 2], strides[dim], dilations[dim], kernelShape[dim], pads, dim, dim + inputDims.length - 2, - autoPad)); + outputDims.push( + PoolConvUtil.adjustPadAndReturnShape( + inputDims[dim + 2], + strides[dim], + dilations[dim], + kernelShape[dim], + pads, + dim, + dim + inputDims.length - 2, + autoPad, + ), + ); } } } @@ -406,15 +465,22 @@ export class PoolConvUtil { // helper for computeShapeHelper() and adjustPadsBasedOnAutoPad() // adjusts pad value for given 'autoPad' string and computes output shape along a particular dimension private static adjustPadAndReturnShape( - inSize: number, stride: number, dilation: number, kernel: number, pads: number[], padHeadIndex: number, - padTailIndex: number, autoPad?: string): number { + inSize: number, + stride: number, + dilation: number, + kernel: number, + pads: number[], + padHeadIndex: number, + padTailIndex: number, + autoPad?: string, + ): number { const dkernel = dilation * (kernel - 1) + 1; if (autoPad && autoPad !== 'NOTSET') { switch (autoPad) { case 'VALID': pads[padHeadIndex] = 0; pads[padTailIndex] = 0; - return Math.floor(((inSize - dkernel) / stride) + 1); + return Math.floor((inSize - dkernel) / stride + 1); case 'SAME_LOWER': case 'SAME_UPPER': if (dilation !== 1) { @@ -422,16 +488,15 @@ export class PoolConvUtil { } else { const legacyTargetSize = (inSize + stride - 1) / stride; const padNeeded = (legacyTargetSize - 1) * stride + kernel - inSize; - pads[padHeadIndex] = - (autoPad === 'SAME_LOWER') ? Math.floor((padNeeded + 1) / 2) : Math.floor(padNeeded / 2); + pads[padHeadIndex] = autoPad === 'SAME_LOWER' ? 
Math.floor((padNeeded + 1) / 2) : Math.floor(padNeeded / 2); pads[padTailIndex] = padNeeded - pads[padHeadIndex]; - return Math.floor(((inSize + padNeeded - kernel) / stride) + 1); + return Math.floor((inSize + padNeeded - kernel) / stride + 1); } default: throw new Error('Unsupported AutoPad type'); } } else { - return Math.floor(((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride) + 1); + return Math.floor((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride + 1); } } } @@ -441,8 +506,12 @@ export class GemmUtil { // and return back the shape of the output in the form of a tuple // will throw exception if the input shapes are not compatible static getShapeOfGemmResult( - leftShape: readonly number[], transLeft: boolean, rightShape: readonly number[], transRight: boolean, - biasShape?: readonly number[]): readonly number[] { + leftShape: readonly number[], + transLeft: boolean, + rightShape: readonly number[], + transRight: boolean, + biasShape?: readonly number[], + ): readonly number[] { if (leftShape.length !== 2 || rightShape.length !== 2) { throw new Error('shape need to be of size 2'); } @@ -485,6 +554,5 @@ export class GemmUtil { } } - -export const MIN_CLIP = -3.4028234663852886e+38; -export const MAX_CLIP = 3.4028234663852886e+38; +export const MIN_CLIP = -3.4028234663852886e38; +export const MAX_CLIP = 3.4028234663852886e38; diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts index ad56b92c1d869..19c25f9cba761 100644 --- a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts +++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts @@ -9,8 +9,10 @@ class AttributeWithCacheKeyImpl { private key: string; public get cacheKey(): string { if (!this.key) { - this.key = - Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); + this.key = Object.getOwnPropertyNames(this) + .sort() + .map((name) => `${(this as Record)[name]}`) + .join(';'); } return this.key; } @@ -23,5 +25,6 @@ export interface AttributeWithCacheKey { /** * create a new object from the given attribute, and add a cacheKey property to it */ -export const createAttributeWithCacheKey = >(attribute: T): T&AttributeWithCacheKey => - new AttributeWithCacheKeyImpl(attribute) as unknown as T & AttributeWithCacheKey; +export const createAttributeWithCacheKey = >( + attribute: T, +): T & AttributeWithCacheKey => new AttributeWithCacheKeyImpl(attribute) as unknown as T & AttributeWithCacheKey; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index a5c0a088efa6e..8e18a28acc364 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {WebGpuBackend} from '../backend-webgpu'; -import {LOG_DEBUG} from '../log'; +import { WebGpuBackend } from '../backend-webgpu'; +import { LOG_DEBUG } from '../log'; -import {GpuData, GpuDataId, GpuDataType} from './types'; +import { GpuData, GpuDataId, GpuDataType } from './types'; /** * manages GpuDataId -> GpuBuffer @@ -25,7 +25,7 @@ export interface GpuDataManager { /** * get GPU data by ID. */ - get(id: GpuDataId): GpuData|undefined; + get(id: GpuDataId): GpuData | undefined; /** * release the data on GPU by ID. 
* @@ -141,39 +141,46 @@ const createNewGpuDataId = () => guid++; * @param getTargetBuffer - optional. If provided, the data will be copied to the target buffer. Otherwise, a new buffer * will be created and returned. */ -export const downloadGpuData = - async(backend: WebGpuBackend, gpuBuffer: GPUBuffer, originalSize: number, getTargetBuffer?: () => Uint8Array): - Promise => { - const bufferSize = calcNormalizedBufferSize(originalSize); - const gpuReadBuffer = backend.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); - try { - const commandEncoder = backend.getCommandEncoder(); - backend.endComputePass(); - commandEncoder.copyBufferToBuffer( - gpuBuffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, - 0 /* destination offset */, bufferSize /* size */ - ); - backend.flush(); - - await gpuReadBuffer.mapAsync(GPUMapMode.READ); - - const arrayBuffer = gpuReadBuffer.getMappedRange(); - if (getTargetBuffer) { - // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer. - const targetBuffer = getTargetBuffer(); - targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize)); - return targetBuffer; - } else { - // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the - // ArrayBuffer. - return new Uint8Array(arrayBuffer.slice(0, originalSize)); - } - } finally { - gpuReadBuffer.destroy(); - } - }; +export const downloadGpuData = async ( + backend: WebGpuBackend, + gpuBuffer: GPUBuffer, + originalSize: number, + getTargetBuffer?: () => Uint8Array, +): Promise => { + const bufferSize = calcNormalizedBufferSize(originalSize); + const gpuReadBuffer = backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + { size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }, + ); + try { + const commandEncoder = backend.getCommandEncoder(); + backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + gpuBuffer /* source buffer */, + 0 /* source offset */, + gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, + bufferSize /* size */, + ); + backend.flush(); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + if (getTargetBuffer) { + // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer. + const targetBuffer = getTargetBuffer(); + targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize)); + return targetBuffer; + } else { + // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the + // ArrayBuffer. 
+ return new Uint8Array(arrayBuffer.slice(0, originalSize)); + } + } finally { + gpuReadBuffer.destroy(); + } +}; class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) @@ -205,7 +212,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.externalBuffers = new Map(); this.capturedPendingBuffers = new Map(); - for (const [key, ] of bucketFreelist) { + for (const [key] of bucketFreelist) { bucketArr.push(key); this.freeBuffers.set(key, []); this.freeUniformBuffers.set(key, []); @@ -229,15 +236,15 @@ class GpuDataManagerImpl implements GpuDataManager { // create gpu buffer const gpuBufferForUploading = this.backend.device.createBuffer( - // eslint-disable-next-line no-bitwise - {mappedAtCreation: true, size, usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC}); + // eslint-disable-next-line no-bitwise + { mappedAtCreation: true, size, usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC }, + ); // copy (upload) data const arrayBuffer = gpuBufferForUploading.getMappedRange(); new Uint8Array(arrayBuffer).set(new Uint8Array(srcArrayBuffer, srcOffset, srcLength)); gpuBufferForUploading.unmap(); - // GPU copy const commandEncoder = this.backend.getCommandEncoder(); this.backend.endComputePass(); @@ -269,11 +276,16 @@ class GpuDataManagerImpl implements GpuDataManager { const commandEncoder = this.backend.getCommandEncoder(); this.backend.endComputePass(); commandEncoder.copyBufferToBuffer( - sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size); + sourceGpuDataCache.gpuData.buffer, + 0, + destinationGpuDataCache.gpuData.buffer, + 0, + size, + ); } registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { - let id: number|undefined; + let id: number | undefined; if (previousBuffer) { id = this.externalBuffers.get(previousBuffer); if (id === undefined) { @@ -281,9 +293,12 @@ class GpuDataManagerImpl implements GpuDataManager { } if (buffer === previousBuffer) { LOG_DEBUG( - 'verbose', - () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ - id}, buffer is the same, skip.`); + 'verbose', + () => + `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ + id + }, buffer is the same, skip.`, + ); return id; } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. 
@@ -294,11 +309,12 @@ class GpuDataManagerImpl implements GpuDataManager { id = createNewGpuDataId(); } - this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); + this.storageCache.set(id, { gpuData: { id, type: GpuDataType.default, buffer }, originalSize }); this.externalBuffers.set(buffer, id); LOG_DEBUG( - 'verbose', - () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`, + ); return id; } @@ -326,29 +342,29 @@ class GpuDataManagerImpl implements GpuDataManager { const buffers = freeBuffers.get(bufferSize); if (!buffers) { // no such bucket/freelist - create gpu buffer - gpuBuffer = this.backend.device.createBuffer({size: bufferSize, usage}); + gpuBuffer = this.backend.device.createBuffer({ size: bufferSize, usage }); } else { if (buffers.length > 0) { // in freelist, use it gpuBuffer = buffers.pop() as GPUBuffer; } else { // bucket empty, create gpu buffer - gpuBuffer = this.backend.device.createBuffer({size: bufferSize, usage}); + gpuBuffer = this.backend.device.createBuffer({ size: bufferSize, usage }); } } } else { // create gpu buffer - gpuBuffer = this.backend.device.createBuffer({size: bufferSize, usage}); + gpuBuffer = this.backend.device.createBuffer({ size: bufferSize, usage }); } - const gpuData = {id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer}; - this.storageCache.set(gpuData.id, {gpuData, originalSize: size}); + const gpuData = { id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer }; + this.storageCache.set(gpuData.id, { gpuData, originalSize: size }); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`); return gpuData; } - get(id: GpuDataId): GpuData|undefined { + get(id: GpuDataId): GpuData | undefined { return this.storageCache.get(id)?.gpuData; } @@ -430,12 +446,12 @@ class GpuDataManagerImpl implements GpuDataManager { dispose() { this.freeBuffers.forEach((buffers) => { - buffers.forEach(buffer => { + buffers.forEach((buffer) => { buffer.destroy(); }); }); this.freeUniformBuffers.forEach((buffers) => { - buffers.forEach(buffer => { + buffers.forEach((buffer) => { buffer.destroy(); }); }); @@ -445,7 +461,7 @@ class GpuDataManagerImpl implements GpuDataManager { }); this.capturedPendingBuffers.forEach((buffers) => { - buffers.forEach(buffer => { + buffers.forEach((buffer) => { buffer.destroy(); }); }); @@ -459,7 +475,7 @@ class GpuDataManagerImpl implements GpuDataManager { // release the captured pending buffers. const pendingBuffers = this.capturedPendingBuffers.get(sessionId); if (pendingBuffers) { - pendingBuffers.forEach(buffer => { + pendingBuffers.forEach((buffer) => { buffer.destroy(); }); this.capturedPendingBuffers.delete(sessionId); @@ -468,4 +484,4 @@ class GpuDataManagerImpl implements GpuDataManager { } export const createGpuDataManager = (...args: ConstructorParameters): GpuDataManager => - new GpuDataManagerImpl(...args); + new GpuDataManagerImpl(...args); diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e0288eebbe604..0808d45a307ca 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -1,49 +1,60 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
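// Editor's note: the create()/release() paths in the gpu-data-manager.ts hunk above reuse GPUBuffers
// through per-size freelists. Below is a minimal, generic sketch of that bucketing idea under stated
// assumptions: bucket sizes are sorted ascending, `createRawBuffer` stands in for device.createBuffer,
// and the storage-cache bookkeeping is omitted. Names are hypothetical.
class BufferPoolSketch<T> {
  private freelists = new Map<number, T[]>();

  constructor(
    private bucketSizes: readonly number[], // assumed sorted ascending
    private createRawBuffer: (size: number) => T,
  ) {}

  // round the requested size up to the nearest bucket; fall back to the raw size for large requests
  private bucketFor(size: number): number {
    for (const bucket of this.bucketSizes) {
      if (size <= bucket) {
        return bucket;
      }
    }
    return size;
  }

  acquire(size: number): T {
    const bucket = this.bucketFor(size);
    const list = this.freelists.get(bucket);
    if (list && list.length > 0) {
      return list.pop()!; // reuse a previously released buffer from the same bucket
    }
    return this.createRawBuffer(bucket);
  }

  release(buffer: T, size: number): void {
    const bucket = this.bucketFor(size);
    const list = this.freelists.get(bucket) ?? [];
    list.push(buffer);
    this.freelists.set(bucket, list);
  }
}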
-import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; -import {attention} from './ops/attention'; -import {batchNorm} from './ops/batch-norm'; -import {biasAdd} from './ops/bias-add'; -import {biasSplitGelu} from './ops/bias-split-gelu'; +import { argMax, argMin, parseArgMinMaxAttributes } from './ops/argminmax'; +import { attention } from './ops/attention'; +import { batchNorm } from './ops/batch-norm'; +import { biasAdd } from './ops/bias-add'; +import { biasSplitGelu } from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; -import {concat, parseConcatAttributes} from './ops/concat'; -import {conv, parseConvAttributes} from './ops/conv'; -import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; -import {cumsum, parseCumSumAttributes} from './ops/cumsum'; -import {depthToSpace, parseDepthToSpaceAttributes} from './ops/depth-to-space'; -import {einsum, parseEinsumAttributes} from './ops/einsum'; -import {expand} from './ops/expand'; -import {fastGelu} from './ops/fast-gelu'; -import {gather, parseGatherAttributes} from './ops/gather'; -import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; -import {gemm, parseGemmAttributes} from './ops/gemm'; -import {groupQueryAttention, parseGroupQueryAttentionAttributes} from './ops/group-query-attention'; -import {instanceNorm} from './ops/instance-norm'; -import {layerNorm} from './ops/layer-norm'; -import {matMul} from './ops/matmul'; -import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits'; -import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multihead-attention'; -import {pad} from './ops/pad'; +import { concat, parseConcatAttributes } from './ops/concat'; +import { conv, parseConvAttributes } from './ops/conv'; +import { convTranspose, parseConvTransposeAttributes } from './ops/conv-transpose'; +import { cumsum, parseCumSumAttributes } from './ops/cumsum'; +import { depthToSpace, parseDepthToSpaceAttributes } from './ops/depth-to-space'; +import { einsum, parseEinsumAttributes } from './ops/einsum'; +import { expand } from './ops/expand'; +import { fastGelu } from './ops/fast-gelu'; +import { gather, parseGatherAttributes } from './ops/gather'; +import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; +import { gemm, parseGemmAttributes } from './ops/gemm'; +import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention'; +import { instanceNorm } from './ops/instance-norm'; +import { layerNorm } from './ops/layer-norm'; +import { matMul } from './ops/matmul'; +import { matMulNBits, parseMatMulNBitsAttributes } from './ops/matmulnbits'; +import { multiHeadAttention, parseMultiHeadAttentionAttributes } from './ops/multihead-attention'; +import { pad } from './ops/pad'; import * as pool from './ops/pool'; -import {dequantizeLinear, parseDequantizeLinearAttributes} from './ops/quantize-linear'; -import {range} from './ops/range'; -import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; -import {parseResizeAttributes, resize} from './ops/resize'; -import {rotaryEmbedding} from './ops/rotary-embedding'; -import {skipLayerNorm} from './ops/skip-layer-norm'; -import {parseSliceAttributes, slice} from './ops/slice'; -import {parseSoftmaxAttributes, softmax} from './ops/softmax'; -import {parseSplitAttributes, split} from './ops/split'; -import {tile} from 
'./ops/tile'; -import {parseTransposeAttributes, transpose} from './ops/transpose'; +import { dequantizeLinear, parseDequantizeLinearAttributes } from './ops/quantize-linear'; +import { range } from './ops/range'; +import { + reduceL1, + reduceL2, + reduceLogSum, + reduceLogSumExp, + reduceMax, + reduceMean, + reduceMin, + reduceProd, + reduceSum, + reduceSumSquare, +} from './ops/reduce'; +import { parseResizeAttributes, resize } from './ops/resize'; +import { rotaryEmbedding } from './ops/rotary-embedding'; +import { skipLayerNorm } from './ops/skip-layer-norm'; +import { parseSliceAttributes, slice } from './ops/slice'; +import { parseSoftmaxAttributes, softmax } from './ops/softmax'; +import { parseSplitAttributes, split } from './ops/split'; +import { tile } from './ops/tile'; +import { parseTransposeAttributes, transpose } from './ops/transpose'; import * as unaryOps from './ops/unary-op'; -import {where} from './ops/where'; -import {ComputeContext} from './types'; +import { where } from './ops/where'; +import { ComputeContext } from './types'; export type RunFunction = (context: ComputeContext, attribute?: unknown) => void; export type ParseAttributeFunction = (attributeRaw: unknown) => unknown; -export type OperatorImplementation = [RunFunction]|[RunFunction, ParseAttributeFunction]; +export type OperatorImplementation = [RunFunction] | [RunFunction, ParseAttributeFunction]; export const WEBGPU_OP_RESOLVE_RULES: Map = new Map([ ['Abs', [unaryOps.abs]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 24006d393592a..7884a3cd1a684 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -19,59 +19,76 @@ // // modified to fit the needs of the project -import {DataType} from '../../../../wasm-common'; -import {LOG_DEBUG} from '../../../log'; -import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {ConvAttributes} from '../conv'; -import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; +import { DataType } from '../../../../wasm-common'; +import { LOG_DEBUG } from '../../../log'; +import { TensorView } from '../../../tensor-view'; +import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types'; +import { + createTensorShapeVariables, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from '../common'; +import { ConvAttributes } from '../conv'; +import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; -import {utilFunctions} from './conv_util'; -import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; +import { biasSnippet, typeSnippet } from './activation_util'; +import { utilFunctions } from './conv_util'; +import { makeMatMulPackedSource, makeMatMulPackedVec4Source } from './matmul_packed_webgpu'; -const conv2dCommonSnippet = - (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, - attributes: ConvAttributes, innerElementSizeX = 4, 
innerElementSizeW = 4, innerElementSize = 4, - dataType = 'f32'): string => { - const getXSnippet = (innerElementSize: number) => { - switch (innerElementSize) { - case 1: - return 'resData = x[xIndex];'; - case 3: - return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; - case 4: - return 'resData = x[xIndex / 4];'; - default: - throw new Error(`innerElementSize ${innerElementSize} is not supported.`); - } - }; - const getWSnippet = (innerElementSize: number) => { - switch (innerElementSize) { - case 1: - return 'return w[row * i32(uniforms.w_shape[3]) + colIn];'; - case 4: - return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];'; - default: - throw new Error(`innerElementSize ${innerElementSize} is not supported.`); - } - }; - const coordASnippet = isChannelsLast ? ` +const conv2dCommonSnippet = ( + isChannelsLast: boolean, + fitAOuter: boolean, + fitBOuter: boolean, + fitInner: boolean, + addBias = false, + attributes: ConvAttributes, + innerElementSizeX = 4, + innerElementSizeW = 4, + innerElementSize = 4, + dataType = 'f32', +): string => { + const getXSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'resData = x[xIndex];'; + case 3: + return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; + case 4: + return 'resData = x[xIndex / 4];'; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return w[row * i32(uniforms.w_shape[3]) + colIn];'; + case 4: + return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];'; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast + ? ` let coord = vec4(batch, xRow, xCol, xCh); - ` : - ` + ` + : ` let coord = vec4(batch, xCh, xRow, xCol); `; - const coordResSnippet = isChannelsLast ? ` + const coordResSnippet = isChannelsLast + ? ` let coords = vec4( batch, row / outWidth, row % outWidth, col); - ` : - ` + ` + : ` let coords = vec4( batch, row, @@ -79,11 +96,11 @@ const conv2dCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; - const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; - const row = isChannelsLast ? 'row' : 'col'; - const col = isChannelsLast ? 'col' : 'row'; - const readXSnippet = ` + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + const readXSnippet = ` let inChannels = i32(uniforms.w_shape[2]); let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; @@ -104,34 +121,35 @@ const conv2dCommonSnippet = } return resData;`; - const sampleX = isChannelsLast ? (fitAOuter && fitInner ? ` + const sampleX = isChannelsLast + ? fitAOuter && fitInner + ? ` let col = colIn * ${innerElementSizeX}; - ${readXSnippet}` : - ` + ${readXSnippet}` + : ` let col = colIn * ${innerElementSizeX}; if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : - (fitInner && fitBOuter ? 
` + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);` + : fitInner && fitBOuter + ? ` let col = colIn * ${innerElementSizeX}; - ${readXSnippet}` : - ` + ${readXSnippet}` + : ` let col = colIn * ${innerElementSizeX}; if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`; - const sampleW = `${getWSnippet(innerElementSizeW)}`; + const sampleW = `${getWSnippet(innerElementSizeW)}`; - const resType = typeSnippet(innerElementSize, dataType); - const aType = - isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); - const bType = - isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const applyActivation = getActivationSnippet(attributes, resType, dataType); - const userCode = ` + const resType = typeSnippet(innerElementSize, dataType); + const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); + const applyActivation = getActivationSnippet(attributes, resType, dataType); + const userCode = ` fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -152,69 +170,82 @@ const conv2dCommonSnippet = setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value); } }`; - return userCode; - }; + return userCode; +}; -export const createConv2DMatMulProgramInfo = - (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], dimAOuter: number, - dimBOuter: number, dimInner: number, hasBias: boolean, sequentialAccessByThreads: boolean): ProgramInfo => { - const isChannelsLast = attributes.format === 'NHWC'; - const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; - const batchSize = outputShape[0]; - const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; - const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; - const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - // TODO: enable vec4 for NCHW - const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0; +export const createConv2DMatMulProgramInfo = ( + inputs: readonly TensorView[], + attributes: ConvAttributes, + outputShape: readonly number[], + dimAOuter: number, + dimBOuter: number, + dimInner: number, + hasBias: boolean, + sequentialAccessByThreads: boolean, +): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0; - // TODO: fine tune size - const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; - const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = [8, 8, 1]; - const elementsPerThread = dimAOuter <= 8 ? 
[4, 1, 1] : [4, 4, 1]; - const dispatch = [ - Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), - Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), - Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) - ]; + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]), + ]; - LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); + LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const tileAOuter = workGroupSize[1] * elementsPerThread[1]; - const tileBOuter = workGroupSize[0] * elementsPerThread[0]; - const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const fitAOuter = dimAOuter % tileAOuter === 0; - const fitBOuter = dimBOuter % tileBOuter === 0; - const fitInner = dimInner % tileInner === 0; - const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; + const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; + const tileAOuter = workGroupSize[1] * elementsPerThread[1]; + const tileBOuter = workGroupSize[0] * elementsPerThread[0]; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + const fitAOuter = dimAOuter % tileAOuter === 0; + const fitBOuter = dimBOuter % tileBOuter === 0; + const fitInner = dimInner % tileInner === 0; + const elementsSize = isVec4 ? 
[innerElementSize, 4, 4] : [1, 1, 1]; - const programUniforms: ProgramUniform[] = [ - {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, - {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}, - {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} - ]; - appendActivationUniformsData(attributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); + const programUniforms: ProgramUniform[] = [ + { type: DataType.int32, data: dimAOuter }, + { type: DataType.int32, data: dimBOuter }, + { type: DataType.int32, data: dimInner }, + { type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]] }, + { type: DataType.int32, data: attributes.strides }, + { type: DataType.int32, data: attributes.dilations }, + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const uniforms: UniformsArrayType = [ - {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, - {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, - {name: 'dilation', type: 'i32', length: 2} - ]; - appendActivationUniforms(attributes, uniforms); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + { name: 'dim_a_outer', type: 'i32' }, + { name: 'dim_b_outer', type: 'i32' }, + { name: 'dim_inner', type: 'i32' }, + { name: 'pad', type: 'i32', length: 2 }, + { name: 'stride', type: 'i32', length: 2 }, + { name: 'dilation', type: 'i32', length: 2 }, + ]; + appendActivationUniforms(attributes, uniforms); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); - let declareFunctions = ` + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } @@ -222,50 +253,72 @@ export const createConv2DMatMulProgramInfo = let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; - const x = inputVariable( - 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 
1 : innerElementSize); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - declareFunctions += ` + const x = inputVariable( + 'x', + inputs[0].dataType, + inputs[0].dims.length, + innerElementSize === 3 ? 1 : innerElementSize, + ); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } + } - return ` + return ` ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} + ${conv2dCommonSnippet( + isChannelsLast, + fitAOuter, + fitBOuter, + fitInner, + hasBias, + attributes, + elementsSize[0], + elementsSize[1], + elementsSize[2], + t, + )} ${ - conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], - elementsSize[2], t)} - ${ - isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : - makeMatMulPackedSource( - elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}`; - }; - return { - name: 'Conv2DMatMul', - shaderCache: { - hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ - tileAOuter};${tileBOuter};${tileInner}`, - inputDependencies - }, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource - }; - }; + isVec4 + ? 
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) + : makeMatMulPackedSource( + elementsPerThread, + workGroupSize, + t, + undefined, + !isChannelsLast, + tileInner, + false, + undefined, + sequentialAccessByThreads, + ) + }`; + }; + return { + name: 'Conv2DMatMul', + shaderCache: { + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ + tileAOuter + };${tileBOuter};${tileInner}`, + inputDependencies, + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts index a2e5428385101..b5cf049346f6f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts @@ -19,16 +19,24 @@ // // modified to fit the needs of the project -import {DataType} from '../../../../wasm-common'; -import {LOG_DEBUG} from '../../../log'; -import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {ConvAttributes} from '../conv'; -import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; +import { DataType } from '../../../../wasm-common'; +import { LOG_DEBUG } from '../../../log'; +import { TensorView } from '../../../tensor-view'; +import { ShapeUtil } from '../../../util'; +import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types'; +import { + createTensorShapeVariables, + getElementAt, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from '../common'; +import { ConvAttributes } from '../conv'; +import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from '../fuse-utils'; -import {typeSnippet} from './activation_util'; +import { typeSnippet } from './activation_util'; const arrayProduct = (arr: number[]) => { let product = 1; @@ -38,8 +46,8 @@ const arrayProduct = (arr: number[]) => { return product; }; -const parse3TupleParam = (param: number|[number, number, number]): [number, number, number] => - typeof param === 'number' ? [param, param, param] : param; +const parse3TupleParam = (param: number | [number, number, number]): [number, number, number] => + typeof param === 'number' ? 
[param, param, param] : param; const getEffectiveFilterSize = (filterSize: number, dilation: number): number => { if (dilation <= 1) { @@ -49,90 +57,123 @@ const getEffectiveFilterSize = (filterSize: number, dilation: number): number => return filterSize + (filterSize - 1) * (dilation - 1); }; -const computeDefaultPad = - (inputShape: [number, number]|[number, number, number, number], fieldSize: number, stride: number, dilation = 1): - number => { - const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation); - return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2); - }; +const computeDefaultPad = ( + inputShape: [number, number] | [number, number, number, number], + fieldSize: number, + stride: number, + dilation = 1, +): number => { + const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation); + return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2); +}; -const computeOutputShape4D = - (inShape: [number, number, number, number], filterShape: [number, number, number], outChannels: number, - strides: [number, number, number], zeroPad?: number): [number, number, number, number] => { - if (zeroPad == null) { - // eslint-disable-next-line no-param-reassign - zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]); - } - const outShape: [number, number, number, number] = [0, 0, 0, outChannels]; - for (let index = 0; index < 3; index++) { - if (inShape[index] + 2 * zeroPad >= filterShape[index]) { - outShape[index] = Math.trunc((inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] + 1); - } - } - return outShape; - }; +const computeOutputShape4D = ( + inShape: [number, number, number, number], + filterShape: [number, number, number], + outChannels: number, + strides: [number, number, number], + zeroPad?: number, +): [number, number, number, number] => { + if (zeroPad == null) { + // eslint-disable-next-line no-param-reassign + zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]); + } + const outShape: [number, number, number, number] = [0, 0, 0, outChannels]; + for (let index = 0; index < 3; index++) { + if (inShape[index] + 2 * zeroPad >= filterShape[index]) { + outShape[index] = Math.trunc((inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] + 1); + } + } + return outShape; +}; -const get3DPadAndOutInfo = - (pad: number|string|number[], inDepth: number, inHeight: number, inWidth: number, strideDepth: number, - strideHeight: number, strideWidth: number, filterDepth: number, filterHeight: number, - filterWidth: number): {padInfo: PadInfo3D; outDepth: number; outHeight: number; outWidth: number} => { - let padInfo: PadInfo3D; - let outDepth: number; - let outHeight: number; - let outWidth: number; +const get3DPadAndOutInfo = ( + pad: number | string | number[], + inDepth: number, + inHeight: number, + inWidth: number, + strideDepth: number, + strideHeight: number, + strideWidth: number, + filterDepth: number, + filterHeight: number, + filterWidth: number, +): { padInfo: PadInfo3D; outDepth: number; outHeight: number; outWidth: number } => { + let padInfo: PadInfo3D; + let outDepth: number; + let outHeight: number; + let outWidth: number; - if (pad === 'VALID') { - // eslint-disable-next-line no-param-reassign - pad = 0; - } + if (pad === 'VALID') { + // eslint-disable-next-line no-param-reassign + pad = 0; + } - if (typeof pad === 'number') { - padInfo = {top: pad, bottom: pad, left: pad, right: pad, front: pad, back: pad}; - const outShape = computeOutputShape4D( 
- [inDepth, inHeight, inWidth, 1], [filterDepth, filterHeight, filterWidth], 1, - [strideDepth, strideHeight, strideWidth], pad); - outDepth = outShape[0]; - outHeight = outShape[1]; - outWidth = outShape[2]; - } else if (Array.isArray(pad)) { - if (!pad.every((val, _, arr) => val === arr[0])) { - throw Error(`Unsupported padding parameter: ${pad}`); - } - padInfo = {top: pad[0], bottom: pad[1], left: pad[2], right: pad[3], front: pad[4], back: pad[5]}; - const outShape = computeOutputShape4D( - [inDepth, inHeight, inWidth, 1], [filterDepth, filterHeight, filterWidth], 1, - [strideDepth, strideHeight, strideWidth], pad[0]); - outDepth = outShape[0]; - outHeight = outShape[1]; - outWidth = outShape[2]; - } else if (pad === 'SAME_UPPER') { - // TODO: support 'SAME_LOWER'. - outDepth = Math.ceil(inDepth / strideDepth); - outHeight = Math.ceil(inHeight / strideHeight); - outWidth = Math.ceil(inWidth / strideWidth); - const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth; - const padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; - const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; - const front = Math.floor(padAlongDepth / 2); - const back = padAlongDepth - front; - const top = Math.floor(padAlongHeight / 2); - const bottom = padAlongHeight - top; - const left = Math.floor(padAlongWidth / 2); - const right = padAlongWidth - left; + if (typeof pad === 'number') { + padInfo = { top: pad, bottom: pad, left: pad, right: pad, front: pad, back: pad }; + const outShape = computeOutputShape4D( + [inDepth, inHeight, inWidth, 1], + [filterDepth, filterHeight, filterWidth], + 1, + [strideDepth, strideHeight, strideWidth], + pad, + ); + outDepth = outShape[0]; + outHeight = outShape[1]; + outWidth = outShape[2]; + } else if (Array.isArray(pad)) { + if (!pad.every((val, _, arr) => val === arr[0])) { + throw Error(`Unsupported padding parameter: ${pad}`); + } + padInfo = { top: pad[0], bottom: pad[1], left: pad[2], right: pad[3], front: pad[4], back: pad[5] }; + const outShape = computeOutputShape4D( + [inDepth, inHeight, inWidth, 1], + [filterDepth, filterHeight, filterWidth], + 1, + [strideDepth, strideHeight, strideWidth], + pad[0], + ); + outDepth = outShape[0]; + outHeight = outShape[1]; + outWidth = outShape[2]; + } else if (pad === 'SAME_UPPER') { + // TODO: support 'SAME_LOWER'. 
+ outDepth = Math.ceil(inDepth / strideDepth); + outHeight = Math.ceil(inHeight / strideHeight); + outWidth = Math.ceil(inWidth / strideWidth); + const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth; + const padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; + const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; + const front = Math.floor(padAlongDepth / 2); + const back = padAlongDepth - front; + const top = Math.floor(padAlongHeight / 2); + const bottom = padAlongHeight - top; + const left = Math.floor(padAlongWidth / 2); + const right = padAlongWidth - left; - padInfo = {top, bottom, left, right, front, back}; - } else { - throw Error(`Unknown padding parameter: ${pad}`); - } - return {padInfo, outDepth, outHeight, outWidth}; - }; + padInfo = { top, bottom, left, right, front, back }; + } else { + throw Error(`Unknown padding parameter: ${pad}`); + } + return { padInfo, outDepth, outHeight, outWidth }; +}; type PadInfo3D = { - top: number; left: number; right: number; bottom: number; front: number; back: number; + top: number; + left: number; + right: number; + bottom: number; + front: number; + back: number; }; export type Conv3DInfo = { - batchSize: number; inDepth: number; inHeight: number; inWidth: number; inChannels: number; outDepth: number; + batchSize: number; + inDepth: number; + inHeight: number; + inWidth: number; + inChannels: number; + outDepth: number; outHeight: number; outWidth: number; outChannels: number; @@ -155,130 +196,157 @@ export type Conv3DInfo = { filterShape: [number, number, number, number, number]; }; -export const computeConv3DInfo = - (inShape: [number, number, number, number, number], filterShape: [number, number, number, number, number], - strides: number|[number, number, number], dilations: number|[number, number, number], pad: number|string|number[], - depthwise = false, dataFormat: 'channelsFirst'|'channelsLast' = 'channelsLast'): Conv3DInfo => { - let batchSize, inDepth, inHeight, inWidth, inChannels; - if (dataFormat === 'channelsLast') { - [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape; - } else if (dataFormat === 'channelsFirst') { - [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape; - } else { - throw new Error(`Unknown dataFormat ${dataFormat}`); - } - const [filterChannels, , filterDepth, filterHeight, filterWidth] = filterShape; +export const computeConv3DInfo = ( + inShape: [number, number, number, number, number], + filterShape: [number, number, number, number, number], + strides: number | [number, number, number], + dilations: number | [number, number, number], + pad: number | string | number[], + depthwise = false, + dataFormat: 'channelsFirst' | 'channelsLast' = 'channelsLast', +): Conv3DInfo => { + let batchSize, inDepth, inHeight, inWidth, inChannels; + if (dataFormat === 'channelsLast') { + [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape; + } else if (dataFormat === 'channelsFirst') { + [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape; + } else { + throw new Error(`Unknown dataFormat ${dataFormat}`); + } + const [filterChannels, , filterDepth, filterHeight, filterWidth] = filterShape; - const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides); - const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations); + const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides); + const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations); - const 
effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth); - const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); - const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); - const {padInfo, outDepth, outHeight, outWidth} = get3DPadAndOutInfo( - pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, - effectiveFilterHeight, effectiveFilterWidth); + const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth); + const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); + const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); + const { padInfo, outDepth, outHeight, outWidth } = get3DPadAndOutInfo( + pad, + inDepth, + inHeight, + inWidth, + strideDepth, + strideHeight, + strideWidth, + effectiveFilterDepth, + effectiveFilterHeight, + effectiveFilterWidth, + ); - const outChannels = depthwise ? filterChannels * inChannels : filterChannels; + const outChannels = depthwise ? filterChannels * inChannels : filterChannels; - let outShape: [number, number, number, number, number] = [0, 0, 0, 0, 0]; - if (dataFormat === 'channelsFirst') { - outShape = [batchSize, outChannels, outDepth, outHeight, outWidth]; - } else if (dataFormat === 'channelsLast') { - outShape = [batchSize, outDepth, outHeight, outWidth, outChannels]; - } + let outShape: [number, number, number, number, number] = [0, 0, 0, 0, 0]; + if (dataFormat === 'channelsFirst') { + outShape = [batchSize, outChannels, outDepth, outHeight, outWidth]; + } else if (dataFormat === 'channelsLast') { + outShape = [batchSize, outDepth, outHeight, outWidth, outChannels]; + } - return { - batchSize, - dataFormat, - inDepth, - inHeight, - inWidth, - inChannels, - outDepth, - outHeight, - outWidth, - outChannels, - padInfo, - strideDepth, - strideHeight, - strideWidth, - filterDepth, - filterHeight, - filterWidth, - effectiveFilterDepth, - effectiveFilterHeight, - effectiveFilterWidth, - dilationDepth, - dilationHeight, - dilationWidth, - inShape, - outShape, - filterShape - }; - }; + return { + batchSize, + dataFormat, + inDepth, + inHeight, + inWidth, + inChannels, + outDepth, + outHeight, + outWidth, + outChannels, + padInfo, + strideDepth, + strideHeight, + strideWidth, + filterDepth, + filterHeight, + filterWidth, + effectiveFilterDepth, + effectiveFilterHeight, + effectiveFilterWidth, + dilationDepth, + dilationHeight, + dilationWidth, + inShape, + outShape, + filterShape, + }; +}; -export const createConv3DNaiveProgramInfo = - (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], - filterDims: readonly number[], pads: readonly number[], dataFormat: string): ProgramInfo => { - const isChannelLast = dataFormat === 'channelsLast'; - const inChannels = isChannelLast ? inputs[0].dims[3] : inputs[0].dims[1]; - // TODO: enable vec4. - const isVec4 = false; - const workGroupSize: [number, number, number] = [64, 1, 1]; - const dispatchLayout = {x: outputShape.map((_, i) => i)}; - const dispatch = [Math.ceil(arrayProduct(dispatchLayout.x.map(d => outputShape[d])) / (workGroupSize[0])), 1, 1]; +export const createConv3DNaiveProgramInfo = ( + inputs: readonly TensorView[], + attributes: ConvAttributes, + outputShape: readonly number[], + filterDims: readonly number[], + pads: readonly number[], + dataFormat: string, +): ProgramInfo => { + const isChannelLast = dataFormat === 'channelsLast'; + const inChannels = isChannelLast ? 
inputs[0].dims[3] : inputs[0].dims[1]; + // TODO: enable vec4. + const isVec4 = false; + const workGroupSize: [number, number, number] = [64, 1, 1]; + const dispatchLayout = { x: outputShape.map((_, i) => i) }; + const dispatch = [Math.ceil(arrayProduct(dispatchLayout.x.map((d) => outputShape[d])) / workGroupSize[0]), 1, 1]; - LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`); + LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: filterDims}, - {type: DataType.uint32, data: pads}, {type: DataType.uint32, data: attributes.strides}, - {type: DataType.uint32, data: attributes.dilations} - ]; - appendActivationUniformsData(attributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; - const hasBias = inputs.length === 3; - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); + const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1; + const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: filterDims }, + { type: DataType.uint32, data: pads }, + { type: DataType.uint32, data: attributes.strides }, + { type: DataType.uint32, data: attributes.dilations }, + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const hasBias = inputs.length === 3; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'filter_dims', type: 'u32', length: filterDims.length}, - {name: 'pads', type: 'u32', length: pads.length}, - {name: 'strides', type: 'u32', length: attributes.strides.length}, - {name: 'dilations', type: 'u32', length: attributes.dilations.length} - ]; - appendActivationUniforms(attributes, uniforms); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'filter_dims', type: 'u32', length: filterDims.length }, + { name: 'pads', type: 'u32', length: pads.length }, + { name: 'strides', type: 'u32', length: attributes.strides.length }, + { name: 'dilations', type: 'u32', length: attributes.dilations.length }, + ]; + appendActivationUniforms(attributes, uniforms); + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); - const x = inputVariable( - 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 
1 : innerElementSize); - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - let declareFunctions = ''; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - declareFunctions += ` + const x = inputVariable( + 'x', + inputs[0].dataType, + inputs[0].dims.length, + innerElementSize === 3 ? 1 : innerElementSize, + ); + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : array) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${ - isVec4 ? '/ 4' : ''}]; + isVec4 ? '/ 4' : '' + }]; }`; - } - const resType = typeSnippet(innerElementSize, t); - const applyActivation = getActivationSnippet(attributes, resType, t); + } + const resType = typeSnippet(innerElementSize, t); + const applyActivation = getActivationSnippet(attributes, resType, t); - return ` + return ` ${declareFunctions} fn getX(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> f32 { let aIndices = array(d0, d1, d2, d3, d4); @@ -294,24 +362,38 @@ export const createConv3DNaiveProgramInfo = let coords = ${output.offsetToIndices('global_idx')}; let batch = ${getElementAt('coords', 0, x.rank)}; let d2 = ${ - isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)}; + isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank) + }; let xFRCCorner = vec3(${ - isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)}, + isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank) + }, ${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)}, ${ - isChannelLast ? getElementAt('coords', 3, x.rank) : - getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads; + isChannelLast ? getElementAt('coords', 3, x.rank) : getElementAt('coords', 4, x.rank) + }) * uniforms.strides - uniforms.pads; let xFCorner = xFRCCorner.x; let xRCorner = xFRCCorner.y; let xCCorner = xFRCCorner.z; let xShapeY = ${ - isChannelLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; + isChannelLast + ? getElementAt('uniforms.x_shape', 1, x.rank) + : getElementAt('uniforms.x_shape', 2, x.rank) + }; let xShapeZ = ${ - isChannelLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; + isChannelLast + ? getElementAt('uniforms.x_shape', 2, x.rank) + : getElementAt('uniforms.x_shape', 3, x.rank) + }; let xShapeW = ${ - isChannelLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; + isChannelLast + ? getElementAt('uniforms.x_shape', 3, x.rank) + : getElementAt('uniforms.x_shape', 4, x.rank) + }; let xShapeU = ${ - isChannelLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; + isChannelLast + ? 
getElementAt('uniforms.x_shape', 4, x.rank) + : getElementAt('uniforms.x_shape', 1, x.rank) + }; let inputDepthNearestVec4 = (xShapeU / 4) * 4; let inputDepthVec4Remainder = xShapeU % 4; @@ -336,18 +418,20 @@ export const createConv3DNaiveProgramInfo = for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) { ${ - isChannelLast ? `let xValues = vec4( + isChannelLast + ? `let xValues = vec4( getX(batch, xF, xR, xC, d1), getX(batch, xF, xR, xC, d1 + 1), getX(batch, xF, xR, xC, d1 + 2), getX(batch, xF, xR, xC, d1 + 3)); - ` : - `let xValues = vec4( + ` + : `let xValues = vec4( getX(batch, d1, xF, xR, xC), getX(batch, d1 + 1, xF, xR, xC), getX(batch, d1 + 2, xF, xR, xC), getX(batch, d1 + 3, xF, xR, xC)); - `} + ` + } let wValues = vec4( getW(d2, d1, wF, wR, wC), getW(d2, d1 + 1, wF, wR, wC), @@ -357,36 +441,42 @@ export const createConv3DNaiveProgramInfo = } if (inputDepthVec4Remainder == 1) { ${ - isChannelLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4) - * getW(d2, inputDepthNearestVec4, wF, wR, wC);` : - `value += getX(batch, inputDepthNearestVec4, xF, xR, xC) - * getW(d2, inputDepthNearestVec4, wF, wR, wC);`} + isChannelLast + ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4) + * getW(d2, inputDepthNearestVec4, wF, wR, wC);` + : `value += getX(batch, inputDepthNearestVec4, xF, xR, xC) + * getW(d2, inputDepthNearestVec4, wF, wR, wC);` + } } else if (inputDepthVec4Remainder == 2) { ${ - isChannelLast ? `let xValues = vec2( + isChannelLast + ? `let xValues = vec2( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1)); - ` : - `let xValues = vec2( + ` + : `let xValues = vec2( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC)); - `} + ` + } let wValues = vec2( getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC)); value += dot(xValues, wValues); } else if (inputDepthVec4Remainder == 3) { ${ - isChannelLast ? `let xValues = vec3( + isChannelLast + ? 
`let xValues = vec3( getX(batch, xF, xR, xC, inputDepthNearestVec4), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1), getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2)); - ` : - `let xValues = vec3( + ` + : `let xValues = vec3( getX(batch, inputDepthNearestVec4, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC), getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC)); - `} + ` + } let wValues = vec3( getW(d2, inputDepthNearestVec4, wF, wR, wC), getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC), @@ -400,16 +490,15 @@ export const createConv3DNaiveProgramInfo = ${applyActivation} result[global_idx] = f32(value); }`; - }; - return { - name: 'Conv3DNaive', - shaderCache: - {hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource - }; - }; + }; + return { + name: 'Conv3DNaive', + shaderCache: { hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 080b24a2432aa..ca0ec0f9e6674 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -19,27 +19,38 @@ // // modified to fit the needs of the project -import {DataType} from '../../../../wasm-common'; -import {LOG_DEBUG} from '../../../log'; -import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {ConvTransposeAttributes} from '../conv-transpose'; -import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; +import { DataType } from '../../../../wasm-common'; +import { LOG_DEBUG } from '../../../log'; +import { TensorView } from '../../../tensor-view'; +import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types'; +import { + createTensorShapeVariables, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from '../common'; +import { ConvTransposeAttributes } from '../conv-transpose'; +import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from '../fuse-utils'; -import {biasSnippet} from './activation_util'; -import {utilFunctions} from './conv_util'; -import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; +import { biasSnippet } from './activation_util'; +import { utilFunctions } from './conv_util'; +import { makeMatMulPackedSource, makeMatMulPackedVec4Source } from './matmul_packed_webgpu'; -const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, - innerElementSize = 4): string => { - const getWSnippet = (innerElementSize: number) => { - switch (innerElementSize) { - case 1: - 
return 'return w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))];'; - case 4: - return ` +const conv2dTransposeCommonSnippet = ( + isChannelsLast: boolean, + addBias = false, + attributes: ConvTransposeAttributes, + type: string, + innerElementSize = 4, +): string => { + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))];'; + case 4: + return ` let coord1 = vec4(coordX, coordY, col + 1, rowInner); let coord2 = vec4(coordX, coordY, col + 2, rowInner); let coord3 = vec4(coordX, coordY, col + 3, rowInner); @@ -49,25 +60,27 @@ const conv2dTransposeCommonSnippet = let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; return ${type}(v0, v1, v2, v3); `; - default: - throw new Error(`innerElementSize ${innerElementSize} is not supported.`); - } - }; - const coordASnippet = isChannelsLast ? ` + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast + ? ` let coord = vec4(batch, iXR, iXC, xCh); - ` : - ` + ` + : ` let coord = vec4(batch, xCh, iXR, iXC); `; - const coordResSnippet = isChannelsLast ? ` + const coordResSnippet = isChannelsLast + ? ` let coords = vec4( batch, row / outWidth, row % outWidth, col); - ` : - ` + ` + : ` let coords = vec4( batch, row, @@ -75,12 +88,12 @@ const conv2dTransposeCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; - const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; - const row = isChannelsLast ? 'row' : 'col'; - const col = isChannelsLast ? 'col' : 'row'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; - const readASnippet = ` + const readASnippet = ` let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; @@ -102,27 +115,30 @@ const conv2dTransposeCommonSnippet = ${coordASnippet} return x[getIndexFromCoords4D(coord, vec4(uniforms.x_shape))/${innerElementSize}];`; - const sampleA = isChannelsLast ? ` + const sampleA = isChannelsLast + ? ` let col = colIn * ${innerElementSize}; if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readASnippet} } - return ${type}(0.0);` : - ` + return ${type}(0.0);` + : ` let col = colIn * ${innerElementSize}; if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readASnippet} } return ${type}(0.0);`; - const sampleW = ` + const sampleW = ` let col = colIn * ${innerElementSize}; let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; if (${ - isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : - 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { + isChannelsLast + ? 
'row < uniforms.dim_inner && col < uniforms.dim_b_outer' + : 'row < uniforms.dim_inner && col < uniforms.dim_a_outer' + } && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -130,8 +146,8 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - const applyActivation = getActivationSnippet(attributes, type); - const userCode = ` + const applyActivation = getActivationSnippet(attributes, type); + const userCode = ` fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -151,114 +167,140 @@ const conv2dTransposeCommonSnippet = result[getIndexFromCoords4D(coords, vec4(uniforms.result_shape))/${innerElementSize}] = value; } }`; - return userCode; - }; + return userCode; +}; -export const createConv2DTransposeMatMulProgramInfo = - (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[], - dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, - sequentialAccessByThreads: boolean): ProgramInfo => { - const isChannelsLast = attributes.format === 'NHWC'; - const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; - const batchSize = outputShape[0]; - const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; - const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; - const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - // TODO: enable vec4 for NCHW - const isVec4 = isChannelsLast && (inChannels % 4 === 0 && inChannels % 3) && outChannels % 4 === 0; +export const createConv2DTransposeMatMulProgramInfo = ( + inputs: readonly TensorView[], + attributes: ConvTransposeAttributes, + outputShape: readonly number[], + dimAOuter: number, + dimBOuter: number, + dimInner: number, + hasBias: boolean, + sequentialAccessByThreads: boolean, +): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && inChannels % 4 === 0 && inChannels % 3 && outChannels % 4 === 0; - // TODO: fine tune size - const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; - const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = [8, 8, 1]; - const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; - const dispatch = [ - Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), - Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), - Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) - ]; + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ? 
[4, 1, 1] : [4, 4, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]), + ]; - LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? 4 : 1; - const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const components = isVec4 ? 4 : 1; - const filterDims = - [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; - const effectiveFilterDims = [ - filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), - filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) - ]; - const pads = [ - effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), - effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) - ]; + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + const components = isVec4 ? 4 : 1; + const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const effectiveFilterDims = [ + filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)), + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2), + ]; - const programUniforms: ProgramUniform[] = [ - {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, - {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides}, - {type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims}, - {type: DataType.int32, data: pads} - ]; - appendActivationUniformsData(attributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); + const programUniforms: ProgramUniform[] = [ + { type: DataType.int32, data: dimAOuter }, + { type: DataType.int32, data: dimBOuter }, + { type: DataType.int32, data: dimInner }, + { type: DataType.int32, data: attributes.strides }, + { type: DataType.int32, data: attributes.dilations }, + { type: DataType.int32, data: filterDims }, + { type: DataType.int32, data: pads }, + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const x = 
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const inputVariables = [x, w]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; - let declareFunctions = ''; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - declareFunctions += ` + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } + } - const uniforms: UniformsArrayType = [ - {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, - {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, - {name: 'filter_dims', type: 'i32', length: filterDims.length}, - {name: 'pads', type: 'i32', length: pads.length} - ]; - appendActivationUniforms(attributes, uniforms); - const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); - if (elemType !== 'f16' && elemType !== 'f32') { - throw new Error(`elemType ${elemType} is not supported.`); - } - return ` + const uniforms: UniformsArrayType = [ + { name: 'dim_a_outer', type: 'i32' }, + { name: 'dim_b_outer', type: 'i32' }, + { name: 'dim_inner', type: 'i32' }, + { name: 'strides', type: 'i32', length: 2 }, + { name: 'dilations', type: 'i32', length: 2 }, + { name: 'filter_dims', type: 'i32', length: filterDims.length }, + { name: 'pads', type: 'i32', length: pads.length }, + ]; + appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); + } + return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} ${ - isVec4 ? makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : - makeMatMulPackedSource( - elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, - undefined, sequentialAccessByThreads)}`; - }; + isVec4 + ? 
makeMatMulPackedVec4Source( + elementsPerThread, + workGroupSize, + elemType, + undefined, + !isChannelsLast, + tileInner, + ) + : makeMatMulPackedSource( + elementsPerThread, + workGroupSize, + elemType, + undefined, + !isChannelsLast, + tileInner, + false, + undefined, + sequentialAccessByThreads, + ) + }`; + }; - return { - name: 'Conv2DTransposeMatMul', - shaderCache: - {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms - }), - getShaderSource - }; - }; + return { + name: 'Conv2DTransposeMatMul', + shaderCache: { hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 45c89406e1731..2a8756e435b8e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -17,43 +17,57 @@ // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts -import {DataType} from '../../../../wasm-common'; -import {LOG_DEBUG} from '../../../log'; -import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {ConvTransposeAttributes} from '../conv-transpose'; +import { DataType } from '../../../../wasm-common'; +import { LOG_DEBUG } from '../../../log'; +import { TensorView } from '../../../tensor-view'; +import { ShapeUtil } from '../../../util'; +import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types'; +import { + createTensorShapeVariables, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from '../common'; +import { ConvTransposeAttributes } from '../conv-transpose'; -const createConvTranspose2DOpProgramShaderSource = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, - is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, - isChannelsLast = false): string => { - const rowDim = isChannelsLast ? 1 : 2; - const colDim = isChannelsLast ? 2 : 3; - const channelDim = isChannelsLast ? 3 : 1; - const workPerThread = isVec4 ? 2 : 1; +const createConvTranspose2DOpProgramShaderSource = ( + shaderHelper: ShaderHelper, + inputs: readonly TensorView[], + outputShape: readonly number[], + hasBias: boolean, + is1DimensionDispatch: boolean, + isVec4 = false, + dataType: string, + uniforms: UniformsArrayType, + isChannelsLast = false, +): string => { + const rowDim = isChannelsLast ? 1 : 2; + const colDim = isChannelsLast ? 2 : 3; + const channelDim = isChannelsLast ? 3 : 1; + const workPerThread = isVec4 ? 2 : 1; - let declareFunctions = ` + let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 
`vec4<${dataType}>` : dataType}) { result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); }`; - if (hasBias) { - declareFunctions += ` + if (hasBias) { + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } - const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); - const inputVariables = [dy, w]; - if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + } + const components = isVec4 ? 4 : 1; + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); + const inputVariables = [dy, w]; + if (hasBias) { + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); + } + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const codeSnippet4 = `{ + const codeSnippet4 = `{ let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; @@ -157,7 +171,7 @@ const createConvTranspose2DOpProgramShaderSource = ${output.set('batch', 'r', 'c + i', 'd1', 'value')}; } }`; - const codeSnippet = ` + const codeSnippet = ` let outputIndices = ${output.offsetToIndices('global_idx')}; let batch = ${output.indicesGet('outputIndices', 0)}; let d1 = ${output.indicesGet('outputIndices', channelDim)}; @@ -197,8 +211,10 @@ const createConvTranspose2DOpProgramShaderSource = var inputChannel = groupId * uniforms.input_channels_per_group; for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { let xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : - dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; + isChannelsLast + ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') + : dy.get('batch', 'inputChannel', 'idyR', 'idyC') + }; let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; inputChannel = inputChannel + 1; @@ -209,101 +225,113 @@ const createConvTranspose2DOpProgramShaderSource = ${output.setByOffset('global_idx', 'value')}; `; - return ` + return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ${isVec4 ? 
codeSnippet4 : codeSnippet}}`; - }; +}; -export const createConvTranspose2DProgramInfo = - (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => { - const hasBias = inputs.length > 2; - // const isChannelsLast = attributes.format === 'NHWC'; - const outputShape = attributes.outputShape; - const outputSize = ShapeUtil.size(outputShape); +export const createConvTranspose2DProgramInfo = ( + inputs: readonly TensorView[], + attributes: ConvTransposeAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], +): ProgramInfo => { + const hasBias = inputs.length > 2; + // const isChannelsLast = attributes.format === 'NHWC'; + const outputShape = attributes.outputShape; + const outputSize = ShapeUtil.size(outputShape); - // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // TODO Enable isVec4 for performance - // Disabled due to weight matrix layout issue - // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0; - const dispatch = [ - Math.ceil(outputSize / 64), - 1, - 1, - ]; - LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + // TODO Enable isVec4 for performance + // Disabled due to weight matrix layout issue + // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0; + const dispatch = [Math.ceil(outputSize / 64), 1, 1]; + LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const isChannelsLast = attributes.format === 'NHWC'; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; - const strides = [attributes.strides[0], attributes.strides[1]]; - const filterDims = - [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; - const dilations = [attributes.dilations[0], attributes.dilations[1]]; - const effectiveFilterDims = [ - filterDims[0] + - (attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), - filterDims[1] + - (attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) - ]; - const pads = [ - effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), - effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 - ]; + const isChannelsLast = attributes.format === 'NHWC'; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const strides = [attributes.strides[0], attributes.strides[1]]; + const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const dilations = [attributes.dilations[0], attributes.dilations[1]]; + const effectiveFilterDims = [ + filterDims[0] + + (attributes.dilations[0] <= 1 + ? 0 + : (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + + (attributes.dilations[1] <= 1 + ? 0 + : (attributes.kernelShape[isChannelsLast ? 
2 : 3] - 1) * (attributes.dilations[1] - 1)), + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2, + ]; - const isVec4 = false; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; + const isVec4 = false; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: strides}, - {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, - {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, - {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, - ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims) - ]; - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: strides }, + { type: DataType.uint32, data: filterDims }, + { type: DataType.uint32, data: dilations }, + { type: DataType.uint32, data: effectiveFilterDims }, + { type: DataType.int32, data: pads }, + { type: DataType.uint32, data: inputChannelsPerGroup }, + { type: DataType.uint32, data: outputChannelsPerGroup }, + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims), + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; - const getShaderSource = (shaderHelper: ShaderHelper) => { - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, - {name: 'filter_dims', type: 'u32', length: filterDims.length}, - {name: 'dilations', type: 'u32', length: filterDims.length}, - {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, - {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, - {name: 'output_channels_per_group', type: 'u32'} - ]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - return `${ - createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, - isChannelsLast)}`; - }; - return { - name: 'ConvTranspose2D', - shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, - getRunData: () => ({ - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - outputs: [{ - dims: squeezeOutputShapeFunction ? 
squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType - }], - programUniforms - }), - getShaderSource - }; - }; + const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'strides', type: 'u32', length: strides.length }, + { name: 'filter_dims', type: 'u32', length: filterDims.length }, + { name: 'dilations', type: 'u32', length: filterDims.length }, + { name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length }, + { name: 'pads', type: 'i32', length: pads.length }, + { name: 'input_channels_per_group', type: 'u32' }, + { name: 'output_channels_per_group', type: 'u32' }, + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return `${createConvTranspose2DOpProgramShaderSource( + shaderHelper, + inputs, + outputShape, + hasBias, + is1DimensionDispatch, + isVec4, + dataType, + uniforms, + isChannelsLast, + )}`; + }; + return { + name: 'ConvTranspose2D', + shaderCache: { hint: `${attributes.cacheKey};`, inputDependencies }, + getRunData: () => ({ + dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, + outputs: [ + { + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts index 6f2c0231104dc..9bf9dda7c3b8a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts @@ -19,7 +19,7 @@ // // modified to fit the needs of the project -export const utilFunctions = (strideStr: string) => (` +export const utilFunctions = (strideStr: string) => ` fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 { return dot(coords, vec4( shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1)); @@ -28,4 +28,4 @@ fn getOutputIndexFromCoords(coords : vec4) -> i32 { return dot(coords, vec4( i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1)); } -`); +`; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 9b37247167bab..f9bc015055c9f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -19,14 +19,29 @@ // // modified to fit the needs of the project -import {DataType} from '../../../../wasm-common'; -import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; - -import {typeSnippet} from './activation_util'; +import { DataType } from '../../../../wasm-common'; +import { TensorView } from '../../../tensor-view'; +import { ShapeUtil } from '../../../util'; +import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types'; +import { + 
createTensorShapeVariables, + getBroadcastDims, + IndicesHelper, + inputVariable, + internalVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from '../common'; +import { + appendActivationUniforms, + appendActivationUniformsData, + getActivationSnippet, + InternalActivationAttributes, +} from '../fuse-utils'; + +import { typeSnippet } from './activation_util'; const writeDataToSubAVec4Snippet = (transpose: boolean, batchDims?: IndicesHelper) => { if (transpose) { @@ -35,7 +50,6 @@ const writeDataToSubAVec4Snippet = (transpose: boolean, batchDims?: IndicesHelpe kStart + inputRow, globalRowStart / innerElementSize + inputCol${batchDims ? ', batchIndices' : ''}); `; - } else { return ` mm_Asub[inputRow][inputCol] = mm_readA(batch, @@ -70,27 +84,41 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = } }; -export const makeMatMulPackedVec4Source = - (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, - transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { - const tileAOuter = workgroupSize[1] * workPerThread[1]; - const tileBOuter = workgroupSize[0] * workPerThread[0]; - const tileAWidth = transposeA ? tileAOuter : tileInner; - const tileAHight = transposeA ? tileInner : tileAOuter; - const innerElementSize = tileAWidth / workgroupSize[0]; - const rowPerThreadB = tileInner / workgroupSize[1]; - - if (!(((transposeA && innerElementSize === 4 && workPerThread[1] === 4) || - (!transposeA && (innerElementSize === 3 || innerElementSize === 4))) && - tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0 && workPerThread[0] === 4)) { - throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${ - innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4. +export const makeMatMulPackedVec4Source = ( + workPerThread: number[], + workgroupSize: [number, number, number], + type = 'f32', + batchDims?: IndicesHelper, + transposeA = false, + tileInner = 32, + splitK = false, + splitedDimInner = 32, +): string => { + const tileAOuter = workgroupSize[1] * workPerThread[1]; + const tileBOuter = workgroupSize[0] * workPerThread[0]; + const tileAWidth = transposeA ? tileAOuter : tileInner; + const tileAHight = transposeA ? tileInner : tileAOuter; + const innerElementSize = tileAWidth / workgroupSize[0]; + const rowPerThreadB = tileInner / workgroupSize[1]; + + if ( + !( + ((transposeA && innerElementSize === 4 && workPerThread[1] === 4) || + (!transposeA && (innerElementSize === 3 || innerElementSize === 4))) && + tileAWidth % workgroupSize[0] === 0 && + tileInner % workgroupSize[1] === 0 && + workPerThread[0] === 4 + ) + ) { + throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${ + innerElementSize + } and workPerThread[1] ${workPerThread[1]} must be 4. Otherwise, innerElementSize ${innerElementSize} must be 3 or 4. tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${ - tileInner} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${ - workPerThread[0]} must be 4.`); - } - return ` + tileInner + } must be divisible by workgroupSize[1] ${workgroupSize[1]}. 
colPerThread ${workPerThread[0]} must be 4.`); + } + return ` var mm_Asub: array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; var mm_Bsub: array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; @@ -133,7 +161,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, let inputRow = tileRowB + innerRow; let inputCol = tileCol; mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol${ - batchDims ? ', batchIndices' : ''}); + batchDims ? ', batchIndices' : '' + }); } kStart = kStart + tileInner; workgroupBarrier(); @@ -155,7 +184,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]); } }`; - }; +}; const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) => { if (transpose) { @@ -164,7 +193,6 @@ const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) = kStart + inputRow, globalRowStart + inputCol${batchDims ? ', batchIndices' : ''}); `; - } else { return ` mm_Asub[inputRow][inputCol] = mm_readA(batch, @@ -175,30 +203,42 @@ const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) = }; const readDataFromSubASnippet = (transposeA: boolean) => - transposeA ? 'let ACached = mm_Asub[k][tileRow + innerRow];' : 'let ACached = mm_Asub[tileRow + innerRow][k];'; + transposeA ? 'let ACached = mm_Asub[k][tileRow + innerRow];' : 'let ACached = mm_Asub[tileRow + innerRow][k];'; // sequentialAccessByThreads means sequential data in memory is accessed by // threads, instead of a single thread (default behavior). -export const makeMatMulPackedSource = - (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, - transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, - sequentialAccessByThreads = false): string => { - const tileAOuter = workPerThread[1] * workgroupSize[1]; - const tileBOuter = workPerThread[0] * workgroupSize[0]; - const tileAWidth = transposeA ? tileAOuter : tileInner; - const tileAHight = transposeA ? tileInner : tileAOuter; - - if (!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && - tileInner % workgroupSize[1] === 0)) { - throw new Error(`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${ - workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${ - workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`); - } - const rowPerThreadA = tileAHight / workgroupSize[1]; - const colPerThreadA = tileAWidth / workgroupSize[0]; - const rowPerThreadB = tileInner / workgroupSize[1]; - const matmulSnippet = sequentialAccessByThreads ? - ` +export const makeMatMulPackedSource = ( + workPerThread: number[], + workgroupSize: [number, number, number], + type = 'f32', + batchDims?: IndicesHelper, + transposeA = false, + tileInner = 32, + splitK = false, + splitedDimInner = 32, + sequentialAccessByThreads = false, +): string => { + const tileAOuter = workPerThread[1] * workgroupSize[1]; + const tileBOuter = workPerThread[0] * workgroupSize[0]; + const tileAWidth = transposeA ? tileAOuter : tileInner; + const tileAHight = transposeA ? 
tileInner : tileAOuter; + + if ( + !(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0) + ) { + throw new Error( + `tileAHight ${tileAHight} must be divisible by workgroupSize[1]${ + workgroupSize[1] + }, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${ + workgroupSize[0] + }, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`, + ); + } + const rowPerThreadA = tileAHight / workgroupSize[1]; + const colPerThreadA = tileAWidth / workgroupSize[0]; + const rowPerThreadB = tileInner / workgroupSize[1]; + const matmulSnippet = sequentialAccessByThreads + ? ` let localRow = i32(localId.y); let localCol = i32(localId.x); let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; @@ -231,8 +271,10 @@ export const makeMatMulPackedSource = } for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let ACached = ${ - transposeA ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` : - `mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];`} + transposeA + ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` + : `mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];` + } for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; @@ -248,8 +290,8 @@ export const makeMatMulPackedSource = mm_write(batch, gRow, gCol, acc[innerRow][innerCol]); } } - ` : - ` + ` + : ` let tileRow = i32(localId.y) * rowPerThread; let tileCol = i32(localId.x) * colPerThread; @@ -310,7 +352,7 @@ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { } `; - return ` + return ` var mm_Asub : array, ${tileAHight}>; var mm_Bsub : array, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; @@ -324,54 +366,62 @@ fn main(@builtin(local_invocation_id) localId : vec3, let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let num_tiles = ${ - splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1' + }; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; ${matmulSnippet} } `; - }; +}; -const matMulReadWriteFnSource = - (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], - batchShapes: Array, isChannelsLast = false): string => { - const [batchAShape, batchBShape, batchShape] = batchShapes; - const [batchVariable, aVariable, bVariable, outputVariable] = variables; - const broadCastADims = getBroadcastDims(batchAShape, batchShape); - const broadCastBDims = getBroadcastDims(batchBShape, batchShape); - const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); - const getAIndices = () => { - const aRank = aVariable.rank; - const batchRank = batchVariable.rank; - let resStr = `var aIndices: ${aVariable.type.indices};`; - for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\naIndices[${i}] = ${batchRank > 1 ? 
`batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastADims.forEach(i => { - resStr += `\naIndices[${i}] = 0;`; - }); - resStr += `\naIndices[${aRank - 2}] = u32(row); +const matMulReadWriteFnSource = ( + component: number, + hasBias: boolean, + applyActivation: string, + variables: IndicesHelper[], + batchShapes: Array, + isChannelsLast = false, +): string => { + const [batchAShape, batchBShape, batchShape] = batchShapes; + const [batchVariable, aVariable, bVariable, outputVariable] = variables; + const broadCastADims = getBroadcastDims(batchAShape, batchShape); + const broadCastBDims = getBroadcastDims(batchBShape, batchShape); + const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); + const getAIndices = () => { + const aRank = aVariable.rank; + const batchRank = batchVariable.rank; + let resStr = `var aIndices: ${aVariable.type.indices};`; + for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { + resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; + } + broadCastADims.forEach((i) => { + resStr += `\naIndices[${i}] = 0;`; + }); + resStr += `\naIndices[${aRank - 2}] = u32(row); aIndices[${aRank - 1}] = u32(colIn);`; - return resStr; - }; - const getBIndices = () => { - const bRank = bVariable.rank; - const batchRank = batchVariable.rank; - let resStr = `var bIndices: ${bVariable.type.indices};`; - for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastBDims.forEach(i => { - resStr += `\nbIndices[${i}] = 0;`; - }); - resStr += `\nbIndices[${bRank - 2}] = u32(row); + return resStr; + }; + const getBIndices = () => { + const bRank = bVariable.rank; + const batchRank = batchVariable.rank; + let resStr = `var bIndices: ${bVariable.type.indices};`; + for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { + resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; + } + broadCastBDims.forEach((i) => { + resStr += `\nbIndices[${i}] = 0;`; + }); + resStr += `\nbIndices[${bRank - 2}] = u32(row); bIndices[${bRank - 1}] = u32(colIn);`; - return resStr; - }; - const source = ` - fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component, dataType)} { + return resStr; + }; + const source = ` + fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet( + component, + dataType, + )} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) @@ -382,8 +432,10 @@ const matMulReadWriteFnSource = return value; } - fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component, dataType)} { + fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet( + component, + dataType, + )} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) @@ -400,104 +452,120 @@ const matMulReadWriteFnSource = var value = valueIn; let coords = vec3(batch, row, colIn); ${ - hasBias ? - `value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` : - '' } + hasBias + ? `value = value + ${isChannelsLast ? 
'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` + : '' + } ${applyActivation} ${outputVariable.setByIndices('vec3(coords)', 'value')} } } `; - return source; - }; + return source; +}; -export const createMatmulProgramInfo = - (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], - reshapedOutputShape?: readonly number[], - isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { - const aShape = inputs[0].dims; - const bShape = inputs[1].dims; - const outerDimsA = aShape.slice(0, -2); - const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchSize = ShapeUtil.size(outerDims); - const dimAOuter = aShape[aShape.length - 2]; - const dimInner = aShape[aShape.length - 1]; - const dimBOuter = bShape[bShape.length - 1]; - const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - - // TODO: fine tune size - const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; - const workgroupSize: [number, number, number] = [8, 8, 1]; - const dispatch = [ - Math.ceil(dimBOuter / workgroupSize[0] / elementsPerThread[0]), - Math.ceil(dimAOuter / workgroupSize[1] / elementsPerThread[1]), - Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) - ]; - - const components = isVec4 ? 4 : 1; - const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const aRank = aShapeTemp.length; - const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const bRank = bShapeTemp.length; - const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - const programUniforms: ProgramUniform[] = [ - {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, - {type: DataType.int32, data: dimInner} - ]; - appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp)); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; - - const hasBias = inputs.length > 2; - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); - } - programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const batchRank = outerDims.length; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - - const A = inputVariable('a', inputs[0].dataType, aRank, components); - const B = inputVariable('b', inputs[1].dataType, bRank, components); - const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - const inputVariables = [A, B]; - if (hasBias) { - const biasComponents = isChannelsLast ? 
components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); - } - const uniforms: UniformsArrayType = - [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; - appendActivationUniforms(activationAttributes, uniforms); - const baseType = tensorTypeToWsglStorageType(output.type.tensor); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); - const declareFunctions = matMulReadWriteFnSource( - components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], - isChannelsLast); - return ` - ${ - shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( - ...inputVariables, output)} +export const createMatmulProgramInfo = ( + inputs: readonly TensorView[], + activationAttributes: InternalActivationAttributes, + outputShape: readonly number[], + reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/, +): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + const outerDimsA = aShape.slice(0, -2); + const outerDimsB = bShape.slice(0, -2); + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchSize = ShapeUtil.size(outerDims); + const dimAOuter = aShape[aShape.length - 2]; + const dimInner = aShape[aShape.length - 1]; + const dimBOuter = bShape[bShape.length - 1]; + const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; + + // TODO: fine tune size + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; + const workgroupSize: [number, number, number] = [8, 8, 1]; + const dispatch = [ + Math.ceil(dimBOuter / workgroupSize[0] / elementsPerThread[0]), + Math.ceil(dimAOuter / workgroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]), + ]; + + const components = isVec4 ? 
4 : 1; + const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; + const aRank = aShapeTemp.length; + const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; + const bRank = bShapeTemp.length; + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + const programUniforms: ProgramUniform[] = [ + { type: DataType.int32, data: dimAOuter }, + { type: DataType.int32, data: dimBOuter }, + { type: DataType.int32, data: dimInner }, + ]; + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + + const hasBias = inputs.length > 2; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const A = inputVariable('a', inputs[0].dataType, aRank, components); + const B = inputVariable('b', inputs[1].dataType, bRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); + const inputVariables = [A, B]; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + } + const uniforms: UniformsArrayType = [ + { name: 'dim_a_outer', type: 'i32' }, + { name: 'dim_b_outer', type: 'i32' }, + { name: 'dim_inner', type: 'i32' }, + ]; + appendActivationUniforms(activationAttributes, uniforms); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); + const declareFunctions = matMulReadWriteFnSource( + components, + hasBias, + applyActivation, + [batchDims, A, B, output], + [outerDimsA, outerDimsB, outerDims], + isChannelsLast, + ); + return ` + ${shaderHelper + .registerUniforms(uniforms) + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 + ? 
makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) + : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims) + } `; - }; - return { - name: 'MatMul', - shaderCache: { - hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, - inputDependencies - }, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms - }), - getShaderSource, - }; - }; + }; + return { + name: 'MatMul', + shaderCache: { + hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, + inputDependencies, + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts index 1f27525f370f3..efec6eaa207c7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -5,12 +5,12 @@ // performance limitations when the reduced axis is long. Need to add // a optimized codepath for this. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext } from '../types'; -import {createReduceProgramInfo, ReduceOp} from './reduce'; +import { createReduceProgramInfo, ReduceOp } from './reduce'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length === 0 || inputs.length > 2) { @@ -33,24 +33,33 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes) const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`input_indices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ - `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, + `${idxZero.join('\n')}`, + `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? 
'<=' : '<'} value) { value = ${input.getByIndices('input_indices')}; best_index = i32(last_index); }`, - '', output.setByOffset('global_idx', 'best_index') + '', + output.setByOffset('global_idx', 'best_index'), ]; }; context.compute( - createReduceProgramInfo( - 'ArgMin', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp, - [attributes.axis], DataType.int64, attributes.keepDims), - {inputs: [0]}); + createReduceProgramInfo( + 'ArgMin', + { hint: attributes.cacheKey, inputDependencies: ['rank'] }, + [context.inputs[0]], + argMinMaxOp, + [attributes.axis], + DataType.int64, + attributes.keepDims, + ), + { inputs: [0] }, + ); }; export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { @@ -59,25 +68,34 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes) const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`input_indices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } return [ - `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, + `${idxZero.join('\n')}`, + `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`, `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) { value = ${input.getByIndices('input_indices')}; best_index = i32(last_index); }`, - '', output.setByOffset('global_idx', 'best_index') + '', + output.setByOffset('global_idx', 'best_index'), ]; }; context.compute( - createReduceProgramInfo( - 'argMax', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp, - [attributes.axis], DataType.int64, attributes.keepDims), - {inputs: [0]}); + createReduceProgramInfo( + 'argMax', + { hint: attributes.cacheKey, inputDependencies: ['rank'] }, + [context.inputs[0]], + argMinMaxOp, + [attributes.axis], + DataType.int64, + attributes.keepDims, + ), + { inputs: [0] }, + ); }; export const parseArgMinMaxAttributes = (attributes: Record): ArgMinMaxAttributes => - createAttributeWithCacheKey(attributes as Omit); + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 30a406cd21230..0008fd1aff62e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,35 +1,44 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; - -import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; + +import { + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + tensorTypeToWsglValueType, + UniformDataElementType, + UniformsArrayType, +} from './common'; export const enum AttentionQkvFormat { - unknown, // enum value not set, or depends on qkv projection implementation details - qkvBNSH, // for non-packed qkv, permuted - qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention - qkvBSN3H, // for TRT fused attention, qkv are packed - qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) - qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed - qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. - qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed + unknown, // enum value not set, or depends on qkv projection implementation details + qkvBNSH, // for non-packed qkv, permuted + qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + qkvBSN3H, // for TRT fused attention, qkv are packed + qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed + qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed } export const enum AttentionMaskType { - none, // No mask - mask1dKeySeqLen, // [batch_size], key sequence length - mask1dEndStart, // [2 * batch_size] with end positions and start positions - mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], - // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., - // key_start[batch_size - 1], key_end[batch_size - 1]] - mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. - mask2dKeyPadding, // [batch_size, total_sequence_length] - mask3dAttention, // [batch_size, sequence_length, total_sequence_length] - mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] - maskUnknown + none, // No mask + mask1dKeySeqLen, // [batch_size], key sequence length + mask1dEndStart, // [2 * batch_size] with end positions and start positions + mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. 
+ mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown, } export interface AttentionParameters { @@ -243,8 +252,9 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor } const elementsPerThread = Math.ceil(d / components / WG); const programUniforms: ProgramUniform[] = [ - {type: DataType.float, data: 1 / d}, {type: DataType.uint32, data: dComp}, - {type: DataType.uint32, data: elementsPerThread} + { type: DataType.float, data: 1 / d }, + { type: DataType.uint32, data: dComp }, + { type: DataType.uint32, data: elementsPerThread }, ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const f32Type = tensorTypeToWsglValueType(DataType.float, components); @@ -252,16 +262,17 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); const elemValueType = tensorTypeToWsglValueType(input.dataType); - const uniforms: UniformsArrayType = - [{name: 'd_inv', type: 'f32'}, {name: 'd_comp', type: 'u32'}, {name: 'elements_per_thread', type: 'u32'}]; + const uniforms: UniformsArrayType = [ + { name: 'd_inv', type: 'f32' }, + { name: 'd_comp', type: 'u32' }, + { name: 'elements_per_thread', type: 'u32' }, + ]; return ` var thread_max: array; var thread_sum: array; ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} - ${shaderHelper.mainStart([ - WG, 1, 1 - ])} + ${shaderHelper.mainStart([WG, 1, 1])} let local_offset = local_idx * uniforms.elements_per_thread; let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset; @@ -326,100 +337,110 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor return { name: 'AttentionProbsSoftmax', - shaderCache: {hint: `${WG};${dataType};${components}`}, + shaderCache: { hint: `${WG};${dataType};${components}` }, getShaderSource, - getRunData: () => ({outputs: [], dispatchGroup: {x: n}, programUniforms}), + getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }), }; }; -const createAttentionProbsProgramInfo = - (context: ComputeContext, q: TensorView, key: TensorView, pastKey: TensorView|undefined, - relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs, - pastSequenceLength: number) => { - const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; - const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; - const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1; - const presentKeyShape = presentKey ? - [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] : - undefined; - - // TODO: handle mask - - const alpha = attributes.scale === 0 ? 
1.0 / Math.sqrt(parameters.headSize) : attributes.scale; - const components = getMaxComponents(parameters.headSize); - const vectorizedHeadSize = parameters.headSize / components; - const TILE_SIZE = 12; - const dispatch = { - x: Math.ceil(totalSequenceLength / TILE_SIZE), - y: Math.ceil(parameters.sequenceLength / TILE_SIZE), - z: parameters.batchSize * parameters.numHeads - }; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, - {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads}, - {type: DataType.float, data: alpha}, {type: DataType.uint32, data: pastSequenceLength}, - {type: DataType.uint32, data: parameters.kvSequenceLength} - ]; - - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; - if (pastKey) { - inputDependencies.push('type'); - } - if (relativePositionBias) { - inputDependencies.push('type'); - } - const outputs = [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}]; - if (presentKey) { - outputs.push({dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default}); - } - const getShaderSource = (shaderHelper: ShaderHelper) => { - const qInput = inputVariable('q', q.dataType, q.dims, components); - const kInput = inputVariable('key', key.dataType, key.dims, components); - const inputVars = [qInput, kInput]; - if (pastKey) { - const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components); - inputVars.push(pastKeyInput); - } - if (relativePositionBias) { - inputVars.push( - inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims)); - } - const output = outputVariable('output', q.dataType, probsShape); - const outputVars = [output]; - if (presentKey) { - outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components)); - } - const f32Type = tensorTypeToWsglValueType(DataType.float, components); +const createAttentionProbsProgramInfo = ( + context: ComputeContext, + q: TensorView, + key: TensorView, + pastKey: TensorView | undefined, + relativePositionBias: TensorView | undefined, + parameters: AttentionParameters, + attributes: AttentionAttrs, + pastSequenceLength: number, +) => { + const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; + const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; + const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1; + const presentKeyShape = presentKey + ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] + : undefined; + + // TODO: handle mask + + const alpha = attributes.scale === 0 ? 
1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + const components = getMaxComponents(parameters.headSize); + const vectorizedHeadSize = parameters.headSize / components; + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(totalSequenceLength / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads, + }; + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: parameters.sequenceLength }, + { type: DataType.uint32, data: vectorizedHeadSize }, + { type: DataType.uint32, data: totalSequenceLength }, + { type: DataType.uint32, data: parameters.numHeads }, + { type: DataType.float, data: alpha }, + { type: DataType.uint32, data: pastSequenceLength }, + { type: DataType.uint32, data: parameters.kvSequenceLength }, + ]; - const uniforms: UniformsArrayType = [ - {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType}, - {name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'} - ]; - return ` + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + if (pastKey) { + inputDependencies.push('type'); + } + if (relativePositionBias) { + inputDependencies.push('type'); + } + const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }]; + if (presentKey) { + outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default }); + } + const getShaderSource = (shaderHelper: ShaderHelper) => { + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const inputVars = [qInput, kInput]; + if (pastKey) { + const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components); + inputVars.push(pastKeyInput); + } + if (relativePositionBias) { + inputVars.push(inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims)); + } + const output = outputVariable('output', q.dataType, probsShape); + const outputVars = [output]; + if (presentKey) { + outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components)); + } + const f32Type = tensorTypeToWsglValueType(DataType.float, components); + + const uniforms: UniformsArrayType = [ + { name: 'M', type: 'u32' }, + { name: 'K', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'num_heads', type: 'u32' }, + { name: 'alpha', type: 'f32' as UniformDataElementType }, + { name: 'past_sequence_length', type: 'u32' }, + { name: 'kv_sequence_length', type: 'u32' }, + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} - ${shaderHelper.mainStart([ - TILE_SIZE, TILE_SIZE, 1 - ])} + ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} // x holds the N and y holds the M let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; ${(() => { - if (pastKey && presentKey) { - return ` + if (pastKey && presentKey) { + return ` let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx; let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`; - } 
else { - return ` + } else { + return ` let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`; - } - })()} + } + })()} ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''} var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { @@ -429,22 +450,21 @@ const createAttentionProbsProgramInfo = if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) { var idx = TILE_SIZE * local_id.y + local_id.x; ${(() => { - if (pastKey && presentKey) { - return ` + if (pastKey && presentKey) { + return ` if (n + local_id.y < uniforms.past_sequence_length) { tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; } else { tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x]; }`; - } else { - return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];'; - } - })()} + } else { + return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];'; + } + })()} ${ - presentKey ? - 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : - ''} + presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : '' + } } workgroupBarrier(); @@ -459,105 +479,115 @@ const createAttentionProbsProgramInfo = if (global_id.y < uniforms.M && global_id.x < uniforms.N) { let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x; var sum: f32 = ${(() => { - switch (components) { - case 1: - return 'value'; - case 2: - return 'value.x + value.y'; - case 4: - return 'value.x + value.y + value.z + value.w'; - default: - throw new Error(`Unsupported components: ${components}`); - } - })()}; + switch (components) { + case 1: + return 'value'; + case 2: + return 'value.x + value.y'; + case 4: + return 'value.x + value.y + value.z + value.w'; + default: + throw new Error(`Unsupported components: ${components}`); + } + })()}; output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${ - relativePositionBias ? 'relative_position_bias[outputIdx]' : '0.0'}; + relativePositionBias ? 'relative_position_bias[outputIdx]' : '0.0' + }; } }`; - }; - return { - name: 'AttentionProbs', - shaderCache: { - hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, - inputDependencies - }, - getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}), - getShaderSource, - }; - }; - - -const createVxAttentionScoreProgramInfo = - (context: ComputeContext, probs: TensorView, v: TensorView, pastValue: TensorView|undefined, - params: AttentionParameters, pastSequenceLength: number) => { - const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; - const nReps = params.nReps ? params.nReps : 1; - const repeatedVHiddenSize = params.vHiddenSize * nReps; - const presentValue = params.kvNumHeads == null && context.outputCount > 1; - const presentValueShape = - presentValue ? 
[params.batchSize, params.numHeads, totalSequenceLength, params.headSize] : undefined; - const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize]; - const TILE_SIZE = 12; - const dispatch = { - x: Math.ceil(params.vHeadSize / TILE_SIZE), - y: Math.ceil(params.sequenceLength / TILE_SIZE), - z: params.batchSize * params.numHeads - }; - - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength}, - {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, - {type: DataType.uint32, data: repeatedVHiddenSize}, {type: DataType.uint32, data: pastSequenceLength}, - {type: DataType.uint32, data: params.kvSequenceLength} - ]; - const inputDependencies: ProgramInputTensorInfoDependency[] = - pastValue ? ['type', 'type', 'type'] : ['type', 'type']; - const outputs = [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}]; - if (presentValue) { - outputs.push({dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default}); - } - const getShaderSource = (shaderHelper: ShaderHelper) => { - const probsHelper = inputVariable('probs', probs.dataType, probs.dims); - const vHelper = inputVariable('v', v.dataType, v.dims); - const inputVars = [probsHelper, vHelper]; - if (pastValue) { - inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); - } - const output = outputVariable('output', probs.dataType, outputShape); - const outputVars = [output]; - if (presentValue) { - outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!)); - } - const uniforms: UniformsArrayType = [ - {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'}, - {name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'} - ]; - return ` + }; + return { + name: 'AttentionProbs', + shaderCache: { + hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, + inputDependencies, + }, + getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }), + getShaderSource, + }; +}; + +const createVxAttentionScoreProgramInfo = ( + context: ComputeContext, + probs: TensorView, + v: TensorView, + pastValue: TensorView | undefined, + params: AttentionParameters, + pastSequenceLength: number, +) => { + const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; + const nReps = params.nReps ? params.nReps : 1; + const repeatedVHiddenSize = params.vHiddenSize * nReps; + const presentValue = params.kvNumHeads == null && context.outputCount > 1; + const presentValueShape = presentValue + ? 
[params.batchSize, params.numHeads, totalSequenceLength, params.headSize] + : undefined; + const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize]; + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(params.vHeadSize / TILE_SIZE), + y: Math.ceil(params.sequenceLength / TILE_SIZE), + z: params.batchSize * params.numHeads, + }; + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: params.sequenceLength }, + { type: DataType.uint32, data: totalSequenceLength }, + { type: DataType.uint32, data: params.vHeadSize }, + { type: DataType.uint32, data: params.numHeads }, + { type: DataType.uint32, data: repeatedVHiddenSize }, + { type: DataType.uint32, data: pastSequenceLength }, + { type: DataType.uint32, data: params.kvSequenceLength }, + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = pastValue ? ['type', 'type', 'type'] : ['type', 'type']; + const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }]; + if (presentValue) { + outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default }); + } + const getShaderSource = (shaderHelper: ShaderHelper) => { + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const inputVars = [probsHelper, vHelper]; + if (pastValue) { + inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); + } + const output = outputVariable('output', probs.dataType, outputShape); + const outputVars = [output]; + if (presentValue) { + outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!)); + } + const uniforms: UniformsArrayType = [ + { name: 'M', type: 'u32' }, + { name: 'K', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'num_heads', type: 'u32' }, + { name: 'v_hidden_size', type: 'u32' }, + { name: 'past_sequence_length', type: 'u32' }, + { name: 'kv_sequence_length', type: 'u32' }, + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} - ${shaderHelper.mainStart([ - TILE_SIZE, TILE_SIZE, 1 - ])} + ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} let headIdx = workgroup_id.z; let m = global_id.y; let n = global_id.x; let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; ${(() => { - if (pastValue && presentValue) { - return ` + if (pastValue && presentValue) { + return ` let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n; let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n; `; - } else { - return ` + } else { + return ` let offsetB = headIdx * uniforms.N * uniforms.K + n; `; - } - })()} + } + })()} ${presentValue ? 
'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''} var value = ${probsHelper.type.storage}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { @@ -599,60 +629,82 @@ const createVxAttentionScoreProgramInfo = output[outputIdx] = value; } }`; - }; - - return { - name: 'AttentionScore', - shaderCache: {hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies}, - getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}), - getShaderSource, - }; - }; - -export const applyAttention = - (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, - _past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined, - relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { - const outputCount = context.outputCount; - const pastSequenceLength = - parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; - const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; - - const inputsK = (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey) ? [q, k, pastKey] : [q, k]; - if (relativePositionBias) { - inputsK.push(relativePositionBias); - } + }; + + return { + name: 'AttentionScore', + shaderCache: { hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies }, + getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }), + getShaderSource, + }; +}; - // Run AttentionProbs - const probs = context.compute( - createAttentionProbsProgramInfo( - context, q, k, outputCount > 1 ? pastKey : undefined, relativePositionBias, parameters, attributes, - pastSequenceLength), - {inputs: inputsK, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [-1, 1] : [-1]})[0]; - - // Run Softmax - context.compute( - createInPlaceSoftmaxProgramInfo( - context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, - totalSequenceLength), - {inputs: [probs], outputs: []}); - - // Run AttrionScore - const inputsV = - (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue) ? [probs, v, pastValue] : [probs, v]; - context.compute( - createVxAttentionScoreProgramInfo( - context, probs, v, outputCount > 1 && pastValue ? pastValue : undefined, parameters, pastSequenceLength), - {inputs: inputsV, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [0, 2] : [0]}); - }; +export const applyAttention = ( + context: ComputeContext, + q: TensorView, + k: TensorView, + v: TensorView, + _maskIndex: TensorView | undefined, + _past: TensorView | undefined, + pastKey: TensorView | undefined, + pastValue: TensorView | undefined, + relativePositionBias: TensorView | undefined, + parameters: AttentionParameters, + attributes: AttentionAttrs, +) => { + const outputCount = context.outputCount; + const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; + const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; + + const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k]; + if (relativePositionBias) { + inputsK.push(relativePositionBias); + } + + // Run AttentionProbs + const probs = context.compute( + createAttentionProbsProgramInfo( + context, + q, + k, + outputCount > 1 ? 
pastKey : undefined, + relativePositionBias, + parameters, + attributes, + pastSequenceLength, + ), + { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] }, + )[0]; + + // Run Softmax + context.compute( + createInPlaceSoftmaxProgramInfo( + context, + probs, + parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + totalSequenceLength, + ), + { inputs: [probs], outputs: [] }, + ); + + // Run AttrionScore + const inputsV = + parameters.kvNumHeads === undefined && outputCount > 1 && pastValue ? [probs, v, pastValue] : [probs, v]; + context.compute( + createVxAttentionScoreProgramInfo( + context, + probs, + v, + outputCount > 1 && pastValue ? pastValue : undefined, + parameters, + pastSequenceLength, + ), + { inputs: inputsV, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0] }, + ); +}; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { - const outputShape = [ - parameters.batchSize, - parameters.numHeads, - parameters.sequenceLength, - parameters.headSize, - ]; + const outputShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.headSize]; const M = parameters.sequenceLength; const K = parameters.inputHiddenSize; const N = parameters.headSize; @@ -660,14 +712,17 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { const dispatch = { x: Math.ceil(parameters.headSize / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), - z: parameters.batchSize * parameters.numHeads + z: parameters.batchSize * parameters.numHeads, }; const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, - {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, - {type: DataType.uint32, data: parameters.hiddenSize}, - {type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + { type: DataType.uint32, data: M }, + { type: DataType.uint32, data: K }, + { type: DataType.uint32, data: N }, + { type: DataType.uint32, data: parameters.numHeads }, + { type: DataType.uint32, data: parameters.headSize }, + { type: DataType.uint32, data: parameters.hiddenSize }, + { type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize }, ]; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -680,8 +735,13 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { const dataType = input.type.storage; const uniforms: UniformsArrayType = [ - {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'num_heads', type: 'u32'}, - {name: 'head_size', type: 'u32'}, {name: 'hidden_size', type: 'u32'}, {name: 'ldb', type: 'u32'} + { name: 'M', type: 'u32' }, + { name: 'K', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, + { name: 'hidden_size', type: 'u32' }, + { name: 'ldb', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; @@ -690,9 +750,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, 
bias, outputQ, outputK, outputV)} - ${shaderHelper.mainStart([ - TILE_SIZE, TILE_SIZE, 1 - ])} + ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} let batchIndex = workgroup_id.z / uniforms.num_heads; let headNumber = workgroup_id.z % uniforms.num_heads; let m = global_id.y; @@ -744,21 +802,22 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { }; return context.compute( - { - name: 'AttentionPrepare', - shaderCache: {inputDependencies: ['type', 'type', 'type']}, - getRunData: () => ({ - outputs: [ - {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, - {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, - {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, - ], - dispatchGroup: dispatch, - programUniforms - }), - getShaderSource, - }, - {inputs, outputs: [-1, -1, -1]}); + { + name: 'AttentionPrepare', + shaderCache: { inputDependencies: ['type', 'type', 'type'] }, + getRunData: () => ({ + outputs: [ + { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default }, + { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default }, + { dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default }, + ], + dispatchGroup: dispatch, + programUniforms, + }), + getShaderSource, + }, + { inputs, outputs: [-1, -1, -1] }, + ); }; export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => { @@ -767,5 +826,16 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs): const [q, k, v] = prepare(context, params); return applyAttention( - context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes); + context, + q, + k, + v, + context.inputs[4], + undefined, + undefined, + undefined, + context.inputs[5], + params, + attributes, + ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 39b932375891b..b0d21297a1b24 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -1,22 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
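// --- Editor's note (illustrative, not part of this patch) ----------------------------------
// The BatchNormalization inference shader reformatted in this file computes, per element,
//   y = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias,
// with the channel offset taken from the last axis for NHWC or from axis 1 for NCHW.
// A minimal CPU reference sketch in TypeScript, assuming spatial mode and NHWC layout
// (the function and parameter names below are illustrative only):
const batchNormNhwcReference = (
  x: Float32Array, // flattened input, channels on the last axis
  scale: Float32Array, // [C]
  bias: Float32Array, // [C]
  mean: Float32Array, // [C]
  variance: Float32Array, // [C]
  channels: number,
  epsilon = 1e-5,
): Float32Array => {
  const y = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) {
    const c = i % channels; // NHWC: the channel index is the innermost coordinate, mirroring the shader's cOffset
    y[i] = ((x[i] - mean[c]) / Math.sqrt(variance[c] + epsilon)) * scale[c] + bias[c];
  }
  return y;
};
// --------------------------------------------------------------------------------------------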
-import {env} from 'onnxruntime-common'; +import { env } from 'onnxruntime-common'; -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo } from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper } from './common'; export interface BatchNormAttributes extends AttributeWithCacheKey { readonly epsilon: number; readonly momentum: number; readonly spatial: boolean; readonly trainingMode: boolean; - readonly format: 'NHWC'|'NCHW'; + readonly format: 'NHWC' | 'NCHW'; readonly outputCount: number; } @@ -38,10 +38,12 @@ const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttr }; if (inputs[0].dims.length > 1) { - const shape = attributes.format === 'NHWC' ? - (attributes.spatial ? inputs[0].dims.slice(-1) : - inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))) : - inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined); + const shape = + attributes.format === 'NHWC' + ? attributes.spatial + ? inputs[0].dims.slice(-1) + : inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1)) + : inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined); checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale'); checkShapeEqual(inputs[2].dims, shape, 'Invalid input B'); checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean'); @@ -54,50 +56,55 @@ const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttr } }; -const createBatchNormInferenceProgramInfo = - (inputs: readonly TensorView[], attributes: BatchNormAttributes): ProgramInfo => { - const {epsilon, spatial, format} = attributes; - const yShape = inputs[0].dims; - const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1; - const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; - const outputSize = ShapeUtil.size(yShape) / components; - // Only support uniforms for opset version >= 9 (spatial = true). - const useShapesUniforms = spatial; - const shapeOrRank = useShapesUniforms ? yShape.length : yShape; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents); - const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents); - const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents); - const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components); - // TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type. - // Otherwise, the shader compilation will fail. - const calcCOffset = (): string => { - let cOffset = ''; - if (spatial) { - cOffset = `let cOffset = ${ - yShape.length === 1 ? 
'0u' : - format === 'NHWC' ? `outputIndices[${yShape.length - 1}] / ${components}` : - 'outputIndices[1]'};`; - } else { - if (format === 'NCHW') { - cOffset = ` +const createBatchNormInferenceProgramInfo = ( + inputs: readonly TensorView[], + attributes: BatchNormAttributes, +): ProgramInfo => { + const { epsilon, spatial, format } = attributes; + const yShape = inputs[0].dims; + const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1; + const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; + const outputSize = ShapeUtil.size(yShape) / components; + // Only support uniforms for opset version >= 9 (spatial = true). + const useShapesUniforms = spatial; + const shapeOrRank = useShapesUniforms ? yShape.length : yShape; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents); + const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents); + const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents); + const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components); + // TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type. + // Otherwise, the shader compilation will fail. + const calcCOffset = (): string => { + let cOffset = ''; + if (spatial) { + cOffset = `let cOffset = ${ + yShape.length === 1 + ? '0u' + : format === 'NHWC' + ? `outputIndices[${yShape.length - 1}] / ${components}` + : 'outputIndices[1]' + };`; + } else { + if (format === 'NCHW') { + cOffset = ` ${y.indicesSet('outputIndices', '0', '0')} let cOffset = ${y.indicesToOffset('outputIndices')};`; - } else { - // update C channel. - cOffset = `var cIndices = ${scale.type.indices}(0); + } else { + // update C channel. + cOffset = `var cIndices = ${scale.type.indices}(0); cIndices[0] = outputIndices[${yShape.length - 1}];`; - // update D1 x ... x Dn channels. - for (let i = 1; i < scale.rank; i++) { - cOffset += `cIndices[${i}] = outputIndices[${i}];`; - } - cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`; - } + // update D1 x ... x Dn channels. + for (let i = 1; i < scale.rank; i++) { + cOffset += `cIndices[${i}] = outputIndices[${i}];`; } - return cOffset; - }; - const getInferenceModeShaderSource = (helper: ShaderHelper) => ` + cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`; + } + } + return cOffset; + }; + const getInferenceModeShaderSource = (helper: ShaderHelper) => ` const epsilon = ${epsilon}; ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)} ${helper.mainStart()} @@ -112,34 +119,29 @@ const createBatchNormInferenceProgramInfo = let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias; ${y.setByOffset('global_idx', 'value')} }`; - return { - name: 'BatchNormalization', - shaderCache: { - hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`, - inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined, - }, - getShaderSource: getInferenceModeShaderSource, - getRunData: () => ({ - outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: useShapesUniforms ? 
- [ - {type: DataType.uint32, data: outputSize}, - ...createTensorShapeVariables(yShape), - ] : - [ - {type: DataType.uint32, data: outputSize}, - ], - }), - }; - }; + return { + name: 'BatchNormalization', + shaderCache: { + hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`, + inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined, + }, + getShaderSource: getInferenceModeShaderSource, + getRunData: () => ({ + outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms: useShapesUniforms + ? [{ type: DataType.uint32, data: outputSize }, ...createTensorShapeVariables(yShape)] + : [{ type: DataType.uint32, data: outputSize }], + }), + }; +}; export const parseBatchNormAttributes = (attributes: Record): BatchNormAttributes => - createAttributeWithCacheKey(attributes as Omit); + createAttributeWithCacheKey(attributes as Omit); export const batchNorm = (context: ComputeContext, attributes: Record): void => { - const {inputs, outputCount} = context; - const updatedAttributes = parseBatchNormAttributes({...attributes, outputCount}); + const { inputs, outputCount } = context; + const updatedAttributes = parseBatchNormAttributes({ ...attributes, outputCount }); if (env.webgpu.validateInputContent) { validateInputs(inputs, updatedAttributes); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts index e2b8412000ef9..dd59d5f03d47d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo } from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import { inputVariable, outputVariable, ShaderHelper } from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims.length !== 3) { @@ -52,8 +52,8 @@ const createBiasAddProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => return { name: 'BiasAdd', getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index 089fecd758e30..78de2d91d89ad 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
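// --- Editor's note (illustrative, not part of this patch) ----------------------------------
// In the binary-op.ts changes further below, createBinaryOpProgramInfo first derives the
// broadcast output shape from the two input shapes; when the shapes differ, it enables the
// vectorized (vec4) path only if the shared trailing dimension size is divisible by 4, either
// operand is a single element, or an operand's last dimension is divisible by 4, while
// identical shapes always take the vectorized path. A simplified sketch of the output-shape
// step, standing in for BroadcastUtil.calcShape (the helper name and signature here are
// illustrative, not the library API):
const broadcastShape = (a: readonly number[], b: readonly number[]): number[] | undefined => {
  const rank = Math.max(a.length, b.length);
  const out = new Array<number>(rank);
  for (let i = 1; i <= rank; i++) {
    const dimA = a[a.length - i] ?? 1; // missing leading dimensions broadcast as 1
    const dimB = b[b.length - i] ?? 1;
    if (dimA !== dimB && dimA !== 1 && dimB !== 1) {
      return undefined; // not broadcastable; the op reports an error in this case
    }
    out[rank - i] = Math.max(dimA, dimB);
  }
  return out;
};
// Example: broadcastShape([2, 3, 4], [3, 1]) -> [2, 3, 4]; broadcastShape([2, 3], [4]) -> undefined.
// --------------------------------------------------------------------------------------------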
-import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo } from '../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; -import {erfImpl} from './unary-op'; +import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType } from './common'; +import { erfImpl } from './unary-op'; const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims.length !== 3) { @@ -60,8 +60,8 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI return { name: 'BiasSplitGelu', getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index a094fffe239c4..53c2ca2fa47d6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -1,82 +1,100 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {BroadcastUtil, ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { BroadcastUtil, ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo } from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; -type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ - scalar: BinaryCustomExpression; - vector: BinaryCustomExpression; -}; +type BinaryFunctionCall = + | BuiltinFunctionName + | BinaryCustomExpression + | { + scalar: BinaryCustomExpression; + vector: BinaryCustomExpression; + }; -const createBinaryOpProgramShader = - (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], - vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, - typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => { - let expressionScalar: BinaryCustomExpression; - let expressionVector: BinaryCustomExpression; - if (typeof funcCall === 'string') { - expressionScalar = expressionVector = (a, b) => `${funcCall}((${a}),(${b}))`; - } else if (typeof funcCall === 'function') { - expressionScalar = expressionVector = funcCall; - } else { - expressionScalar = funcCall.scalar; - expressionVector = funcCall.vector; - } +const createBinaryOpProgramShader = ( + shaderHelper: ShaderHelper, + dimsA: readonly number[], + dimsB: readonly number[], + dimsOutput: readonly number[], + vectorize: boolean, + doBroadcast: boolean, + sharedDimensionDivisibleBy4: boolean, + funcCall: BinaryFunctionCall, + 
typeA: number, + typeB: number, + typeOutput: number, + additionalImplementation?: string, +) => { + let expressionScalar: BinaryCustomExpression; + let expressionVector: BinaryCustomExpression; + if (typeof funcCall === 'string') { + expressionScalar = expressionVector = (a, b) => `${funcCall}((${a}),(${b}))`; + } else if (typeof funcCall === 'function') { + expressionScalar = expressionVector = funcCall; + } else { + expressionScalar = funcCall.scalar; + expressionVector = funcCall.vector; + } - const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4); - const a = inputVariable('aData', typeA, dimsA.length, 4); - const b = inputVariable('bData', typeB, dimsB.length, 4); + const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4); + const a = inputVariable('aData', typeA, dimsA.length, 4); + const b = inputVariable('bData', typeB, dimsB.length, 4); - let assignment: string; - if (vectorize) { - if (doBroadcast) { - const isAOneElement = ShapeUtil.size(dimsA) === 1; - const isBOneElement = ShapeUtil.size(dimsB) === 1; - const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0; - const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0; - if (isAOneElement || isBOneElement) { - assignment = output.setByOffset( - 'global_idx', - expressionVector( - isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'), - isBOneElement ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'))); - } else { - assignment = ` + let assignment: string; + if (vectorize) { + if (doBroadcast) { + const isAOneElement = ShapeUtil.size(dimsA) === 1; + const isBOneElement = ShapeUtil.size(dimsB) === 1; + const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0; + if (isAOneElement || isBOneElement) { + assignment = output.setByOffset( + 'global_idx', + expressionVector( + isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'), + isBOneElement ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'), + ), + ); + } else { + assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; let offsetA = ${a.broadcastedIndicesToOffset('outputIndices', output)}; let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)}; - ${ - output.setByOffset( - 'global_idx', - expressionVector( - sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 ? - a.getByOffset('offsetA / 4u') : - `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`, - sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 ? - b.getByOffset('offsetB / 4u') : - `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`))} + ${output.setByOffset( + 'global_idx', + expressionVector( + sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 + ? a.getByOffset('offsetA / 4u') + : `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`, + sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 + ? 
b.getByOffset('offsetB / 4u') + : `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`, + ), + )} `; - } - } else { - assignment = output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx'))); - } - } else { - if (!doBroadcast) { - throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.'); - } + } + } else { + assignment = output.setByOffset( + 'global_idx', + expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx')), + ); + } + } else { + if (!doBroadcast) { + throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.'); + } - const singleAssignment = (resStr: string, x: number, typeCast = '') => { - const expressionA = `aData[indexA${x}][componentA${x}]`; - const expressionB = `bData[indexB${x}][componentB${x}]`; - return ` + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; @@ -86,26 +104,26 @@ const createBinaryOpProgramShader = let componentB${x} = offsetB${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)}); `; - }; - if (typeOutput === DataType.bool) { - assignment = ` + }; + if (typeOutput === DataType.bool) { + assignment = ` var data = vec4(0); ${singleAssignment('data', 0, 'u32')} ${singleAssignment('data', 1, 'u32')} ${singleAssignment('data', 2, 'u32')} ${singleAssignment('data', 3, 'u32')} outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; - } else { - assignment = ` + } else { + assignment = ` ${singleAssignment('outputData[global_idx]', 0)} ${singleAssignment('outputData[global_idx]', 1)} ${singleAssignment('outputData[global_idx]', 2)} ${singleAssignment('outputData[global_idx]', 3)} `; - } - } + } + } - return ` + return ` ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(a, b, output)} ${additionalImplementation ?? ''} @@ -114,85 +132,116 @@ const createBinaryOpProgramShader = ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} ${assignment} }`; - }; +}; -const createBinaryOpProgramInfo = - (name: string, cacheKey: string, a: TensorView, b: TensorView, funcCall: BinaryFunctionCall, - additionalImplementation?: string, outputDataType: number = a.dataType): ProgramInfo => { - const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); - let outputShape = a.dims; - let outputSize = ShapeUtil.size(a.dims); +const createBinaryOpProgramInfo = ( + name: string, + cacheKey: string, + a: TensorView, + b: TensorView, + funcCall: BinaryFunctionCall, + additionalImplementation?: string, + outputDataType: number = a.dataType, +): ProgramInfo => { + const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); + let outputShape = a.dims; + let outputSize = ShapeUtil.size(a.dims); - let vectorize = false; - let sharedDimensionDivisibleBy4 = false; + let vectorize = false; + let sharedDimensionDivisibleBy4 = false; - // TODO: deal with zero-sized tensors (eg. 
dims=[1,0]) - const cacheKeyAux = [isBroadcast]; - if (isBroadcast) { - const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); - if (!calculatedShape) { - throw new Error('Can\'t perform binary op on the given tensors'); - } - outputShape = calculatedShape; - outputSize = ShapeUtil.size(outputShape); - const isAOneElement = ShapeUtil.size(a.dims) === 1; - const isBOneElement = ShapeUtil.size(b.dims) === 1; - const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; - const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; - cacheKeyAux.push(isAOneElement); - cacheKeyAux.push(isBOneElement); - cacheKeyAux.push(aLastDimDivisibleBy4); - cacheKeyAux.push(bLastDimDivisibleBy4); - // check whether vectorize can be enabled - let sharedDimension = 1; - for (let i = 1; i < outputShape.length; i++) { - const dimA = a.dims[a.dims.length - i] ?? 1; - const dimB = b.dims[b.dims.length - i] ?? 1; - if (dimA === dimB) { - sharedDimension *= dimA; - } else { - break; - } - } - if (sharedDimension % 4 === 0) { - sharedDimensionDivisibleBy4 = true; - vectorize = true; - } else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) { - vectorize = true; - } + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + const cacheKeyAux = [isBroadcast]; + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); + if (!calculatedShape) { + throw new Error("Can't perform binary op on the given tensors"); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + const isAOneElement = ShapeUtil.size(a.dims) === 1; + const isBOneElement = ShapeUtil.size(b.dims) === 1; + const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; + cacheKeyAux.push(isAOneElement); + cacheKeyAux.push(isBOneElement); + cacheKeyAux.push(aLastDimDivisibleBy4); + cacheKeyAux.push(bLastDimDivisibleBy4); + // check whether vectorize can be enabled + let sharedDimension = 1; + for (let i = 1; i < outputShape.length; i++) { + const dimA = a.dims[a.dims.length - i] ?? 1; + const dimB = b.dims[b.dims.length - i] ?? 
1; + if (dimA === dimB) { + sharedDimension *= dimA; } else { - // element-wise - vectorize = true; + break; } - cacheKeyAux.push(vectorize); + } + if (sharedDimension % 4 === 0) { + sharedDimensionDivisibleBy4 = true; + vectorize = true; + } else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) { + vectorize = true; + } + } else { + // element-wise + vectorize = true; + } + cacheKeyAux.push(vectorize); - return { - name, - shaderCache: { - hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), - inputDependencies: ['rank', 'rank'], - }, - getShaderSource: (shaderHelper) => createBinaryOpProgramShader( - shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, - a.dataType, b.dataType, outputDataType, additionalImplementation), - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, - programUniforms: [ - {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims, b.dims, outputShape) - ], - }), - }; - }; + return { + name, + shaderCache: { + hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), + inputDependencies: ['rank', 'rank'], + }, + getShaderSource: (shaderHelper) => + createBinaryOpProgramShader( + shaderHelper, + a.dims, + b.dims, + outputShape, + vectorize, + isBroadcast, + sharedDimensionDivisibleBy4, + funcCall, + a.dataType, + b.dataType, + outputDataType, + additionalImplementation, + ), + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: outputDataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */) }, + programUniforms: [ + { type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4) }, + ...createTensorShapeVariables(a.dims, b.dims, outputShape), + ], + }), + }; +}; -const runBinaryOp = - (context: ComputeContext, name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string, - cacheKey?: string, outputDataType?: number): void => { - context.compute(createBinaryOpProgramInfo( - name, cacheKey ?? '', context.inputs[0], context.inputs[1], funcCall, additionalImplementation, - outputDataType)); - }; +const runBinaryOp = ( + context: ComputeContext, + name: string, + funcCall: BinaryFunctionCall, + additionalImplementation?: string, + cacheKey?: string, + outputDataType?: number, +): void => { + context.compute( + createBinaryOpProgramInfo( + name, + cacheKey ?? '', + context.inputs[0], + context.inputs[1], + funcCall, + additionalImplementation, + outputDataType, + ), + ); +}; export const add = (context: ComputeContext): void => { runBinaryOp(context, 'Add', (a, b) => `${a}+${b}`); @@ -204,8 +253,13 @@ export const div = (context: ComputeContext): void => { export const equal = (context: ComputeContext): void => { runBinaryOp( - context, 'Equal', ({scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4(${a}==${b})`}), undefined, - undefined, DataType.bool); + context, + 'Equal', + { scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4(${a}==${b})` }, + undefined, + undefined, + DataType.bool, + ); }; export const mul = (context: ComputeContext): void => { @@ -216,8 +270,10 @@ export const pow = (context: ComputeContext): void => { const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value; const roundStr = type === 'i32' ? 
'round' : ''; runBinaryOp( - context, 'Pow', ({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}), - ` + context, + 'Pow', + { scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})` }, + ` fn pow_custom(a : ${type}, b : ${type}) -> ${type} { if (b == ${type}(0.0)) { return ${type}(1.0); @@ -225,13 +281,15 @@ export const pow = (context: ComputeContext): void => { return ${type}(pow(f32(a), f32(b))); // NaN } return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${ - roundStr}(pow(f32(abs(a)), f32(b)))); + roundStr + }(pow(f32(abs(a)), f32(b)))); } fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> { // TODO: implement vectorized pow return vec4<${type}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w)); } - `); + `, + ); }; export const sub = (context: ComputeContext): void => { @@ -240,24 +298,44 @@ export const sub = (context: ComputeContext): void => { export const greater = (context: ComputeContext): void => { runBinaryOp( - context, 'Greater', ({scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4(${a}>${b})`}), undefined, - undefined, DataType.bool); + context, + 'Greater', + { scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4(${a}>${b})` }, + undefined, + undefined, + DataType.bool, + ); }; export const less = (context: ComputeContext): void => { runBinaryOp( - context, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4(${a}<${b})`}), undefined, - undefined, DataType.bool); + context, + 'Less', + { scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4(${a}<${b})` }, + undefined, + undefined, + DataType.bool, + ); }; export const greaterOrEqual = (context: ComputeContext): void => { runBinaryOp( - context, 'GreaterOrEqual', ({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4(${a}>=${b})`}), - undefined, undefined, DataType.bool); + context, + 'GreaterOrEqual', + { scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4(${a}>=${b})` }, + undefined, + undefined, + DataType.bool, + ); }; export const lessOrEqual = (context: ComputeContext): void => { runBinaryOp( - context, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4(${a}<=${b})`}), - undefined, undefined, DataType.bool); + context, + 'LessOrEqual', + { scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4(${a}<=${b})` }, + undefined, + undefined, + DataType.bool, + ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index ec2831a3cca04..7696f22d44abd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {ShapeUtil} from '../../util'; -import {ProgramUniform, ProgramUniformVariableInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { ShapeUtil } from '../../util'; +import { ProgramUniform, ProgramUniformVariableInfo } from '../types'; /** * constant value for a workgroup size. @@ -119,7 +119,7 @@ export interface IndicesHelper { * * @param init - initial value. */ - readonly indices: (...init: ReadonlyArray) => string; + readonly indices: (...init: ReadonlyArray) => string; /** * WGSL code of a statement for setting indices. 
@@ -130,7 +130,7 @@ export interface IndicesHelper { * * @returns a WGSL statement */ - readonly indicesSet: (varIndices: string, idx: number|string, value: number|string) => void; + readonly indicesSet: (varIndices: string, idx: number | string, value: number | string) => void; /** * WGSL code of an `u32` expression for getting indices. @@ -140,7 +140,7 @@ export interface IndicesHelper { * * @returns an `u32` expression */ - readonly indicesGet: (varIndices: string, idx: number|string) => string; + readonly indicesGet: (varIndices: string, idx: number | string) => string; /** * WGSL code for a statement for setting data at the given indices. @@ -148,7 +148,7 @@ export interface IndicesHelper { * @param indicesAndValue - an array of numbers or strings (WGSL `u32` expression) representing the indices, followed * by the value to set. This array should have exactly `shape.length + 1` elements. */ - readonly set: (...indicesAndValue: ReadonlyArray) => string; + readonly set: (...indicesAndValue: ReadonlyArray) => string; /** * WGSL code for a statement for setting data at the given indices variable. @@ -164,14 +164,14 @@ export interface IndicesHelper { * @param offset - a number or a string (WGSL `u32` expression) representing the offset. * @param value - the value to set. should be a WGSL expression. */ - readonly setByOffset: (offset: number|string, value: string) => string; + readonly setByOffset: (offset: number | string, value: string) => string; /** * WGSL code for an expression for getting data at the given indices. * * @param indices - an array of numbers or strings (WGSL `u32` expression) representing the indices. */ - readonly get: (...indices: ReadonlyArray) => string; + readonly get: (...indices: ReadonlyArray) => string; /** * WGSL code for an expression for getting data at the given indices variable. @@ -185,7 +185,7 @@ export interface IndicesHelper { * * @param offset - a number or a string (WGSL `u32` expression) representing the offset. */ - readonly getByOffset: (offset: number|string) => string; + readonly getByOffset: (offset: number | string) => string; /** * name of the data variable @@ -195,7 +195,7 @@ export interface IndicesHelper { /** * whether the helper is for an input, an output or an internal variable. */ - readonly usage: 'input'|'output'|'internal'; + readonly usage: 'input' | 'output' | 'internal'; /** * the rank of the input or output. @@ -213,7 +213,7 @@ export interface IndicesHelper { readonly strides: string; } -const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { +const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [string, string] => { if (components === 3) { throw new Error('vec3 has same alignment as vec4, use vec4 instead'); } @@ -249,22 +249,24 @@ const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, s } }; -export const tensorTypeToWsglStorageType = (type: DataType, components: 1|2|3|4 = 1) => { +export const tensorTypeToWsglStorageType = (type: DataType, components: 1 | 2 | 3 | 4 = 1) => { const mappedType = getWgslMappedType(type, components); return typeof mappedType === 'string' ? mappedType : mappedType[0]; }; -export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = 1) => { +export const tensorTypeToWsglValueType = (type: DataType, components: 1 | 2 | 3 | 4 = 1) => { const mappedType = getWgslMappedType(type, components); return typeof mappedType === 'string' ? 
mappedType : mappedType[1]; }; export const createTensorShapeVariables = (...dims: ReadonlyArray): ProgramUniform[] => { const programUniforms: ProgramUniform[] = []; - dims.forEach(dim => { + dims.forEach((dim) => { if (dim.length !== 0) { programUniforms.push( - {type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)}); + { type: DataType.uint32, data: dim }, + { type: DataType.uint32, data: ShapeUtil.computeStrides(dim) }, + ); } }); return programUniforms; @@ -340,26 +342,30 @@ export const sumVector = (name: string, components: number) => { * @param length - the length of variable. * @param type - the type of variable, optional. */ -export const getElementAt = - (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { - if (name.startsWith('uniforms.') && length > 4) { - if (typeof (index) === 'string') { - if (type === 'f16') { - return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; - } else { - return `${name}[(${index}) / 4][(${index}) % 4]`; - } - } else { - if (type === 'f16') { - return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; - } else { - return `${name}[${Math.floor(index / 4)}][${index % 4}]`; - } - } +export const getElementAt = ( + name: string, + index: number | string, + length: number, + type?: UniformDataElementType, +): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof index === 'string') { + if (type === 'f16') { + return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; } else { - return length > 1 ? `${name}[${index}]` : name; + return `${name}[(${index}) / 4][(${index}) % 4]`; } - }; + } else { + if (type === 'f16') { + return `${name}[${Math.floor(index / 8)}][${Math.floor((index % 8) / 4)}][${(index % 8) % 4}]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } +}; /** * A helper function to get a IndicesHelper for a given input or output. @@ -371,46 +377,53 @@ export const getElementAt = * @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for * vec4. */ -const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'], - components: 1|2|3|4): IndicesHelper => { - const useUniform = typeof shapeOrRank === 'number'; - const rank = useUniform ? shapeOrRank : shapeOrRank.length; - const rankIdentity = [...new Array(rank).keys()]; - const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; - const mappedType = getWgslMappedType(tensorType, components); - const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1]; - const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0]; - const type = {indices: indicesType, value: valueType, storage: storageType, tensor: tensorType}; - - const normalizeDim = (dim: number|string): string => typeof dim === 'string' ? dim : `${dim}u`; - - const implementationUsed = { - offsetToIndices: false, - indicesToOffset: false, - broadcastedIndicesToOffset: false, - set: false, - setByIndices: false, - get: false, - getByIndices: false, - }; - - const uniformPrefix = useUniform ? 'uniforms.' 
: ''; - const shape = `${uniformPrefix}${name}_shape`; - const strides = `${uniformPrefix}${name}_strides`; - - let o2iSnippet = ''; - for (let i = 0; i < rank - 1; i++) { - o2iSnippet += ` +const createIndicesHelper = ( + name: string, + tensorType: number, + shapeOrRank: number | readonly number[], + usage: IndicesHelper['usage'], + components: 1 | 2 | 3 | 4, +): IndicesHelper => { + const useUniform = typeof shapeOrRank === 'number'; + const rank = useUniform ? shapeOrRank : shapeOrRank.length; + const rankIdentity = [...new Array(rank).keys()]; + const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; + const mappedType = getWgslMappedType(tensorType, components); + const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1]; + const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0]; + const type = { indices: indicesType, value: valueType, storage: storageType, tensor: tensorType }; + + const normalizeDim = (dim: number | string): string => (typeof dim === 'string' ? dim : `${dim}u`); + + const implementationUsed = { + offsetToIndices: false, + indicesToOffset: false, + broadcastedIndicesToOffset: false, + set: false, + setByIndices: false, + get: false, + getByIndices: false, + }; + + const uniformPrefix = useUniform ? 'uniforms.' : ''; + const shape = `${uniformPrefix}${name}_shape`; + const strides = `${uniformPrefix}${name}_strides`; + + let o2iSnippet = ''; + for (let i = 0; i < rank - 1; i++) { + o2iSnippet += ` let dim${i} = current / ${getElementAt(strides, i, rank)}; let rest${i} = current % ${getElementAt(strides, i, rank)}; indices[${i}] = dim${i}; current = rest${i}; `; - } - o2iSnippet += `indices[${rank - 1}] = current;`; + } + o2iSnippet += `indices[${rank - 1}] = current;`; - const offsetToIndicesImplementation = rank < 2 ? '' : ` + const offsetToIndicesImplementation = + rank < 2 + ? '' + : ` fn o2i_${name}(offset: u32) -> ${type.indices} { var indices: ${type.indices}; var current = offset; @@ -418,254 +431,272 @@ const createIndicesHelper = return indices; }`; - const offsetToIndices = (varOffset: string) => { - implementationUsed.offsetToIndices = true; - return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; - }; + const offsetToIndices = (varOffset: string) => { + implementationUsed.offsetToIndices = true; + return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; + }; - const offsets: string[] = []; - if (rank >= 2) { - for (let i = rank - 1; i >= 0; i--) { - offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`); - } - } + const offsets: string[] = []; + if (rank >= 2) { + for (let i = rank - 1; i >= 0; i--) { + offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`); + } + } - const indicesToOffsetImplementation = rank < 2 ? '' : ` + const indicesToOffsetImplementation = + rank < 2 + ? '' + : ` fn i2o_${name}(indices: ${type.indices}) -> u32 { return ${offsets.join('+')}; }`; - const indicesToOffset = (varIndices: string) => { - implementationUsed.indicesToOffset = true; - return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; - }; + const indicesToOffset = (varIndices: string) => { + implementationUsed.indicesToOffset = true; + return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; + }; - const indices = (...init: ReadonlyArray) => - rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; + const indices = (...init: ReadonlyArray) => + rank === 0 ? 
'0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; - const indicesGet = (varIndices: string, idx: number|string) => { - if (rank < 2) { - return `${varIndices}`; - } else { - return `${getElementAt(varIndices, idx, rank)}`; - } - }; + const indicesGet = (varIndices: string, idx: number | string) => { + if (rank < 2) { + return `${varIndices}`; + } else { + return `${getElementAt(varIndices, idx, rank)}`; + } + }; - const indicesSet = (varIndices: string, idx: number|string, value: string) => { - if (rank < 2) { - return `${varIndices}=${value};`; - } else { - return `${getElementAt(varIndices, idx, rank)}=${value};`; - } - }; - - const broadcastedIndicesToOffsetImplementation: {[key: string]: string} = {}; - const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { - implementationUsed.broadcastedIndicesToOffset = true; - const implKey = `${output.name}broadcastedIndicesTo${name}Offset`; - if (implKey in broadcastedIndicesToOffsetImplementation) { - return `${implKey}(${varIndices})`; - } - const offsets = []; - for (let i = rank - 1; i >= 0; i--) { - const idx = output.indicesGet('outputIndices', i + output.rank - rank); - offsets.push(`${indicesGet(strides, i)} * (${idx} % ${indicesGet(shape, i)})`); - } - broadcastedIndicesToOffsetImplementation[implKey] = - `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 { + const indicesSet = (varIndices: string, idx: number | string, value: string) => { + if (rank < 2) { + return `${varIndices}=${value};`; + } else { + return `${getElementAt(varIndices, idx, rank)}=${value};`; + } + }; + + const broadcastedIndicesToOffsetImplementation: { [key: string]: string } = {}; + const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + implementationUsed.broadcastedIndicesToOffset = true; + const implKey = `${output.name}broadcastedIndicesTo${name}Offset`; + if (implKey in broadcastedIndicesToOffsetImplementation) { + return `${implKey}(${varIndices})`; + } + const offsets = []; + for (let i = rank - 1; i >= 0; i--) { + const idx = output.indicesGet('outputIndices', i + output.rank - rank); + offsets.push(`${indicesGet(strides, i)} * (${idx} % ${indicesGet(shape, i)})`); + } + broadcastedIndicesToOffsetImplementation[implKey] = `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 { return ${offsets.length > 0 ? 
offsets.join('+') : '0u'}; }`; - return `${implKey}(${varIndices})`; - }; - - const setByOffset = (offset: number|string, value: string) => (() => { - if (type.storage === type.value) { - return `${name}[${offset}]=${value};`; - } else if (type.storage === 'vec2' && type.value === 'i32') { - // int64, components === 1 - return `${name}[${offset}]=vec2(u32(${value}), select(0u, 0xFFFFFFFFu, ${value} < 0));`; - } else if (type.storage === 'vec2' && type.value === 'u32') { - // uint64, components === 1 - return `${name}[${offset}]=vec2(u32(${value}), 0u);`; - } else if (type.storage === 'u32' && type.value === 'vec4') { - // bool, components === 4 - return `${name}[${offset}]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(${value}));`; - } else { - throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`); - } - })(); - - const getByOffset = (offset: number|string) => (() => { - if (type.storage === type.value) { - return `${name}[${offset}]`; - } else if (type.storage === 'vec2' && type.value === 'i32') { - // int64, components === 1 - return `i32(${name}[${offset}].x)`; - } else if (type.storage === 'vec2' && type.value === 'u32') { - // uint64, components === 1 - return `u32(${name}[${offset}].x)`; - } else if (type.storage === 'u32' && type.value === 'vec4') { - // bool, components === 4 - return `vec4(bool(${name}[${offset}] & 0xFFu), bool(${name}[${offset}] & 0xFF00u), bool(${name}[${ - offset}] & 0xFF0000u), bool(${name}[${offset}] & 0xFF000000u))`; - } else { - throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`); - } - })(); + return `${implKey}(${varIndices})`; + }; + + const setByOffset = (offset: number | string, value: string) => + (() => { + if (type.storage === type.value) { + return `${name}[${offset}]=${value};`; + } else if (type.storage === 'vec2' && type.value === 'i32') { + // int64, components === 1 + return `${name}[${offset}]=vec2(u32(${value}), select(0u, 0xFFFFFFFFu, ${value} < 0));`; + } else if (type.storage === 'vec2' && type.value === 'u32') { + // uint64, components === 1 + return `${name}[${offset}]=vec2(u32(${value}), 0u);`; + } else if (type.storage === 'u32' && type.value === 'vec4') { + // bool, components === 4 + return `${name}[${offset}]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(${value}));`; + } else { + throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`); + } + })(); + + const getByOffset = (offset: number | string) => + (() => { + if (type.storage === type.value) { + return `${name}[${offset}]`; + } else if (type.storage === 'vec2' && type.value === 'i32') { + // int64, components === 1 + return `i32(${name}[${offset}].x)`; + } else if (type.storage === 'vec2' && type.value === 'u32') { + // uint64, components === 1 + return `u32(${name}[${offset}].x)`; + } else if (type.storage === 'u32' && type.value === 'vec4') { + // bool, components === 4 + return `vec4(bool(${name}[${offset}] & 0xFFu), bool(${name}[${offset}] & 0xFF00u), bool(${name}[${ + offset + }] & 0xFF0000u), bool(${name}[${offset}] & 0xFF000000u))`; + } else { + throw new Error(`not supported combination of storage type ${type.storage} and value type ${type.value} yet`); + } + })(); - const getByIndicesImplementation = rank < 2 ? '' : ` + const getByIndicesImplementation = + rank < 2 + ? 
''
+ : `
 fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
 return ${getByOffset(`i2o_${name}(indices)`)};
 }`;
- const getImplementation = rank < 2 ? '' : (() => {
- const functionParams = rankIdentity.map(i => `d${i}: u32`).join(', ');
- const dimsParams = rankIdentity.map(i => `d${i}`).join(', ');
- return `
+ const getImplementation =
+ rank < 2
+ ? ''
+ : (() => {
+ const functionParams = rankIdentity.map((i) => `d${i}: u32`).join(', ');
+ const dimsParams = rankIdentity.map((i) => `d${i}`).join(', ');
+ return `
 fn get_${name}(${functionParams}) -> ${valueType} {
 return get_${name}ByIndices(${indices(dimsParams)});
 }`;
- })();
+ })();
- const get = (...indices: ReadonlyArray<number|string>) => {
- if (indices.length !== rank) {
- throw new Error(`indices length must be ${rank}`);
- }
-
- const normalizedIndices = indices.map(normalizeDim).join(',');
-
- if (rank === 0) {
- return getByOffset('0u');
- } else if (rank === 1) {
- return getByOffset(normalizedIndices[0]);
- } else {
- implementationUsed.get = true;
- implementationUsed.getByIndices = true;
- implementationUsed.indicesToOffset = true;
- return `get_${name}(${normalizedIndices})`;
- }
- };
+ const get = (...indices: ReadonlyArray<number | string>) => {
+ if (indices.length !== rank) {
+ throw new Error(`indices length must be ${rank}`);
+ }
- const getByIndices = (varIndices: string) => {
- if (rank < 2) {
- return getByOffset(varIndices);
- } else {
- implementationUsed.getByIndices = true;
- implementationUsed.indicesToOffset = true;
- return `get_${name}ByIndices(${varIndices})`;
- }
- };
+ const normalizedIndices = indices.map(normalizeDim).join(',');
+
+ if (rank === 0) {
+ return getByOffset('0u');
+ } else if (rank === 1) {
+ return getByOffset(normalizedIndices[0]);
+ } else {
+ implementationUsed.get = true;
+ implementationUsed.getByIndices = true;
+ implementationUsed.indicesToOffset = true;
+ return `get_${name}(${normalizedIndices})`;
+ }
+ };
+
+ const getByIndices = (varIndices: string) => {
+ if (rank < 2) {
+ return getByOffset(varIndices);
+ } else {
+ implementationUsed.getByIndices = true;
+ implementationUsed.indicesToOffset = true;
+ return `get_${name}ByIndices(${varIndices})`;
+ }
+ };
- const setByIndicesImplementation = rank < 2 ? '' : `
+ const setByIndicesImplementation =
+ rank < 2
+ ? ''
+ : `
 fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
 ${setByOffset(`i2o_${name}(indices)`, 'value')}
 }`;
- const setImplementation = rank < 2 ? '' : (() => {
- const functionParams = rankIdentity.map(i => `d${i}: u32`).join(', ');
- const dimsParams = rankIdentity.map(i => `d${i}`).join(', ');
- return `
+ const setImplementation =
+ rank < 2
+ ?
''
+ : (() => {
+ const functionParams = rankIdentity.map((i) => `d${i}: u32`).join(', ');
+ const dimsParams = rankIdentity.map((i) => `d${i}`).join(', ');
+ return `
 fn set_${name}(${functionParams}, value: ${valueType}) {
 set_${name}ByIndices(${indices(dimsParams)}, value);
 }`;
- })();
-
- const set = (...indicesAndValue: ReadonlyArray<number|string>) => {
- if (indicesAndValue.length !== rank + 1) {
- throw new Error(`indices length must be ${rank}`);
- }
- const value = indicesAndValue[rank];
- if (typeof value !== 'string') {
- throw new Error('value must be string');
- }
-
- const normalizedIndices = indicesAndValue.slice(0, rank).map(normalizeDim).join(',');
+ })();
- if (rank === 0) {
- return setByOffset('0u', value);
- } else if (rank === 1) {
- return setByOffset(normalizedIndices[0], value);
- } else {
- implementationUsed.set = true;
- implementationUsed.setByIndices = true;
- implementationUsed.indicesToOffset = true;
- return `set_${name}(${normalizedIndices}, ${value})`;
- }
- };
+ const set = (...indicesAndValue: ReadonlyArray<number | string>) => {
+ if (indicesAndValue.length !== rank + 1) {
+ throw new Error(`indices length must be ${rank}`);
+ }
+ const value = indicesAndValue[rank];
+ if (typeof value !== 'string') {
+ throw new Error('value must be string');
+ }
- const setByIndices = (varIndices: string, value: string) => {
- if (rank < 2) {
- return setByOffset(varIndices,
value); + } else { + implementationUsed.setByIndices = true; + implementationUsed.indicesToOffset = true; + return `set_${name}ByIndices(${varIndices}, ${value});`; + } + }; + + const impl = () => { + const impls = []; + let needShapeStrides = false; + if (implementationUsed.offsetToIndices) { + impls.push(offsetToIndicesImplementation); + needShapeStrides = true; + } + if (implementationUsed.indicesToOffset) { + impls.push(indicesToOffsetImplementation); + needShapeStrides = true; + } + if (implementationUsed.broadcastedIndicesToOffset) { + Object.values(broadcastedIndicesToOffsetImplementation).forEach((impl) => impls.push(impl)); + needShapeStrides = true; + } + if (implementationUsed.set) { + impls.push(setImplementation); + needShapeStrides = true; + } + if (implementationUsed.setByIndices) { + impls.push(setByIndicesImplementation); + needShapeStrides = true; + } + if (implementationUsed.get) { + impls.push(getImplementation); + needShapeStrides = true; + } + if (implementationUsed.getByIndices) { + impls.push(getByIndicesImplementation); + needShapeStrides = true; + } + if (!useUniform && needShapeStrides) { + impls.unshift( + `const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`, + `const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`, + ); + } + return impls.join('\n'); + }; + + return { + impl, + type, + offsetToIndices, + indicesToOffset, + broadcastedIndicesToOffset, + indices, + indicesGet, + indicesSet, + set, + setByOffset, + setByIndices, + get, + getByOffset, + getByIndices, + // isVec4, + usage, + name, + strides, + shape, + rank, + }; +}; /** * Create a IndicesHelper for an input. @@ -676,9 +707,12 @@ const createIndicesHelper = * @param components - the number of components of the input. available values are 1, 2, 3, 4. default is 1. * @returns an IndicesHelper for the input. */ -export const inputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, 'input', components); +export const inputVariable = ( + name: string, + type: number, + shapeOrRank: number | readonly number[], + components: 1 | 2 | 3 | 4 = 1, +): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'input', components); /** * Create a IndicesHelper for an output. @@ -689,9 +723,12 @@ export const inputVariable = * @param components - the number of components of the output. available values are 1, 2, 3, 4. default is 1. * @returns an IndicesHelper for the output. */ -export const outputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, 'output', components); +export const outputVariable = ( + name: string, + type: number, + shapeOrRank: number | readonly number[], + components: 1 | 2 | 3 | 4 = 1, +): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'output', components); /** * Create a IndicesHelper for an internal variable. @@ -702,12 +739,15 @@ export const outputVariable = * @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1. * @returns an IndicesHelper for the variable. 
*/ -export const internalVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, 'internal', components); +export const internalVariable = ( + name: string, + type: number, + shapeOrRank: number | readonly number[], + components: 1 | 2 | 3 | 4 = 1, +): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; -export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; +export type UniformDataElementType = 'u32' | 'f16' | 'f32' | 'i32'; +export type UniformsArrayType = Array<{ name: string; type: UniformDataElementType; length?: number }>; /** * A ShaderHelper is a helper class for generating WGSL code. @@ -728,7 +768,7 @@ export interface ShaderHelper { * * @param workgroupSize - an optional workgroup size. default is WORKGROUP_SIZE. */ - mainStart(workgroupSize?: number|[number, number, number]): string; + mainStart(workgroupSize?: number | [number, number, number]): string; /** * A helper function to generate the code snippet for guarding against out-of-bounds size. @@ -783,47 +823,60 @@ export interface ShaderHelper { } class ShaderHelperImpl implements ShaderHelper { - constructor(private normalizedDispatchGroup: [number, number, number], private limits: GPUSupportedLimits) {} + constructor( + private normalizedDispatchGroup: [number, number, number], + private limits: GPUSupportedLimits, + ) {} - guardAgainstOutOfBoundsWorkgroupSizes(size: number|string): string { + guardAgainstOutOfBoundsWorkgroupSizes(size: number | string): string { // Guard against out-of-bounds work group sizes const sizeInCode = typeof size === 'number' ? `${size}u` : size; return `if (global_idx >= ${sizeInCode}) { return; }`; } - mainStart(workgroupSize: number|[number, number, number] = WORKGROUP_SIZE) { + mainStart(workgroupSize: number | [number, number, number] = WORKGROUP_SIZE) { const workgroupSizeX = typeof workgroupSize === 'number' ? workgroupSize : workgroupSize[0]; const workgroupSizeY = typeof workgroupSize === 'number' ? 1 : workgroupSize[1]; const workgroupSizeZ = typeof workgroupSize === 'number' ? 
1 : workgroupSize[2];
- if (workgroupSizeX > this.limits.maxComputeWorkgroupSizeX ||
- workgroupSizeY > this.limits.maxComputeWorkgroupSizeY ||
- workgroupSizeZ > this.limits.maxComputeWorkgroupSizeZ) {
- throw new Error(`workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
- workgroupSizeZ}] exceeds the maximum workgroup size [${this.limits.maxComputeWorkgroupSizeX}, ${
- this.limits.maxComputeWorkgroupSizeY}, ${this.limits.maxComputeWorkgroupSizeZ}].`);
+ if (
+ workgroupSizeX > this.limits.maxComputeWorkgroupSizeX ||
+ workgroupSizeY > this.limits.maxComputeWorkgroupSizeY ||
+ workgroupSizeZ > this.limits.maxComputeWorkgroupSizeZ
+ ) {
+ throw new Error(
+ `workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
+ workgroupSizeZ
+ }] exceeds the maximum workgroup size [${this.limits.maxComputeWorkgroupSizeX}, ${
+ this.limits.maxComputeWorkgroupSizeY
+ }, ${this.limits.maxComputeWorkgroupSizeZ}].`,
+ );
 }
 if (workgroupSizeX * workgroupSizeY * workgroupSizeZ > this.limits.maxComputeInvocationsPerWorkgroup) {
- throw new Error(`workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
- workgroupSizeZ}] exceeds the maximum workgroup invocations ${
- this.limits.maxComputeInvocationsPerWorkgroup}.`);
+ throw new Error(
+ `workgroup size [${workgroupSizeX}, ${workgroupSizeY}, ${
+ workgroupSizeZ
+ }] exceeds the maximum workgroup invocations ${this.limits.maxComputeInvocationsPerWorkgroup}.`,
+ );
 }
 const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
- const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
+ const paramList = is1DimensionDispatch
+ ? `@builtin(global_invocation_id) global_id : vec3<u32>,
 @builtin(workgroup_id) workgroup_id : vec3<u32>,
- @builtin(local_invocation_id) local_id : vec3<u32>` :
- `@builtin(global_invocation_id) global_id : vec3<u32>,
+ @builtin(local_invocation_id) local_id : vec3<u32>`
+ : `@builtin(global_invocation_id) global_id : vec3<u32>,
 @builtin(local_invocation_id) local_id : vec3<u32>,
 @builtin(local_invocation_index) local_idx : u32,
 @builtin(workgroup_id) workgroup_id : vec3<u32>,
 @builtin(num_workgroups) num_workgroups : vec3<u32>`;
- const globalIdxDefinition = is1DimensionDispatch ?
- 'let global_idx = global_id.x; let local_idx = local_id.x;' :
- `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
+ const globalIdxDefinition = is1DimensionDispatch
+ ?
'let global_idx = global_id.x; let local_idx = local_id.x;' + : `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ - workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`; + workgroupSizeX * workgroupSizeY * workgroupSizeZ + }u + local_idx;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) fn main(${paramList}) { @@ -834,10 +887,10 @@ class ShaderHelperImpl implements ShaderHelper { private appendVariableUniforms(variable: IndicesHelper): void { if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { - this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank}); + this.uniforms.push({ name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank }); } if (variable.strides.startsWith('uniforms.')) { - this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank}); + this.uniforms.push({ name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank }); } } } @@ -855,13 +908,14 @@ class ShaderHelperImpl implements ShaderHelper { } declareVariables(...variables: IndicesHelper[]): string { - return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); + return variables.map((v) => this.declareVariable(v, this.variableIndex++)).join('\n'); } private registerInternalVariable(variable: IndicesHelper): void { if (variable.usage !== 'internal') { throw new Error( - 'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.'); + 'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.', + ); } this.internalVariables.push(variable); @@ -869,12 +923,12 @@ class ShaderHelperImpl implements ShaderHelper { } registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper { - variables.forEach(v => this.registerInternalVariable(v)); + variables.forEach((v) => this.registerInternalVariable(v)); return this; } registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper { - this.uniforms.push({name, type, length}); + this.uniforms.push({ name, type, length }); return this; } @@ -892,7 +946,7 @@ class ShaderHelperImpl implements ShaderHelper { } const uniformSnippets: string[] = []; - for (const {name, type, length} of this.uniforms) { + for (const { name, type, length } of this.uniforms) { if (length && length > 4) { if (type === 'f16') { uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); @@ -915,27 +969,29 @@ class ShaderHelperImpl implements ShaderHelper { * Get additional implementation that needs to be added to the shader source. */ get additionalImplementations(): string { - return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') + - this.internalVariables.map(i => i.impl()).join('\n'); + return ( + this.uniformDeclaration() + + this.variables.map((i) => i.impl()).join('\n') + + this.internalVariables.map((i) => i.impl()).join('\n') + ); } /** * Get the variable info of the shader program. 
*/ - get variablesInfo(): ProgramUniformVariableInfo[]|undefined { + get variablesInfo(): ProgramUniformVariableInfo[] | undefined { if (this.uniforms.length === 0) { return undefined; } const uniformWgslTypeToDataType = (type: UniformDataElementType) => - ([DataType.uint32, DataType.float16, DataType.float, - DataType.int32][['u32', 'f16', 'f32', 'i32'].indexOf(type)]); - return this.uniforms.map(u => ([uniformWgslTypeToDataType(u.type), u.length ?? 1])); + [DataType.uint32, DataType.float16, DataType.float, DataType.int32][['u32', 'f16', 'f32', 'i32'].indexOf(type)]; + return this.uniforms.map((u) => [uniformWgslTypeToDataType(u.type), u.length ?? 1]); } } export const createShaderHelper = (dispatchGroup: [number, number, number], limits: GPUSupportedLimits) => - new ShaderHelperImpl(dispatchGroup, limits); + new ShaderHelperImpl(dispatchGroup, limits); /** * This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40 diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 010ee589c44fa..ec690720268ca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -71,43 +71,48 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const createConcatProgramInfo = - (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => { - const outputSize = ShapeUtil.size(outputShape); - - const sizeInConcatAxis = new Array(inputs.length); - const inputVars = new Array(inputs.length); - - let previousSum = 0; - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputRanks = []; - const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; - for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[adjustedAxis]; - sizeInConcatAxis[i] = previousSum; - inputRanks.push(inputs[i].dims.length); - inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); - inputDependencies.push('rank'); - programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); - } - for (let i = 0; i < inputs.length; ++i) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); +const createConcatProgramInfo 
= ( + inputs: readonly TensorView[], + adjustedAxis: number, + outputShape: number[], + dataType: DataType, +): ProgramInfo => { + const outputSize = ShapeUtil.size(outputShape); + + const sizeInConcatAxis = new Array(inputs.length); + const inputVars = new Array(inputs.length); + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputRanks = []; + const programUniforms: ProgramUniform[] = [{ type: DataType.uint32, data: outputSize }]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); + programUniforms.push({ type: DataType.uint32, data: sizeInConcatAxis[i] }); + } + for (let i = 0; i < inputs.length; ++i) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const output = outputVariable('output', dataType, outputShape.length); - const indicesAxis = output.indicesGet('indices', adjustedAxis); - const sizeInConcatAxisStr = - Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const output = outputVariable('output', dataType, outputShape.length); + const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = Array.from(Array(sizeInConcatAxis.length).keys()) + .map((i) => `uniforms.sizeInConcatAxis${i}`) + .join(','); + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${(() => { - shaderHelper.registerUniform('outputSize', 'u32'); - for (let i = 0; i < inputs.length; i++) { - shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); - } - return shaderHelper.declareVariables(...inputVars, output); - })()} + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} @@ -125,17 +130,17 @@ const createConcatProgramInfo = ${assignOutputData(inputVars, output)} }`; - return { - name: 'Concat', - shaderCache: {hint: `${adjustedAxis}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms, - }), - getShaderSource, - }; - }; + return { + name: 'Concat', + shaderCache: { hint: `${adjustedAxis}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { const inputs = context.inputs; @@ -143,13 +148,16 @@ export const concat = (context: ComputeContext, attributes: ConcatAttributes): v const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); validateInputs(inputs, adjustedAxis); const outputShape = inputShape.slice(); - outputShape[adjustedAxis] = - inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); + outputShape[adjustedAxis] = inputs.reduce( + (sum, input) => sum + (input.dims.length > adjustedAxis ? 
input.dims[adjustedAxis] : 0), + 0, + ); // 0 length tensors are valid for concat, remove them - const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0); - context.compute( - createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs}); + const nonEmptyInputs = inputs.filter((input) => ShapeUtil.size(input.dims) > 0); + context.compute(createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), { + inputs: nonEmptyInputs, + }); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => - createAttributeWithCacheKey({axis: attributes.axis as number}); + createAttributeWithCacheKey({ axis: attributes.axis as number }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 924030125c420..dbe0e0c9647bd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -1,66 +1,85 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; - -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; -import {calculateOutputShape, ConvAttributes} from './conv'; -import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; + +import { + createTensorShapeVariables, + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; +import { calculateOutputShape, ConvAttributes } from './conv'; +import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv * @param squeezeOutputShapeFunction - an optional function to squeeze the output shape, only used in conv1d */ -export const createGroupedConvProgramInfo = - (inputs: readonly TensorView[], attributes: ConvAttributes, - squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => { - const hasBias = inputs.length > 2; - const processBias = hasBias ? 
'value += b[output_channel];' : ''; - const xShape = inputs[0].dims; - const wShape = inputs[1].dims; - const outputChannelsPerGroup = wShape[0] / attributes.group; - - const isChannelLast = attributes.format === 'NHWC'; - const outputShape = calculateOutputShape( - xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); - const outputSize = ShapeUtil.size(outputShape); - - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations}, - {type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]}, - {type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]}, - {type: DataType.uint32, data: outputChannelsPerGroup} - ]; - appendActivationUniformsData(attributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(xShape, wShape)); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const output = outputVariable('output', inputs[0].dataType, outputShape.length); - const baseType = tensorTypeToWsglStorageType(output.type.tensor); - const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); - const x = inputVariable('x', inputs[0].dataType, xShape.length); - const w = inputVariable('w', inputs[1].dataType, wShape.length); - const inputVars = [x, w]; - if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims.length)); - } +export const createGroupedConvProgramInfo = ( + inputs: readonly TensorView[], + attributes: ConvAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], +): ProgramInfo => { + const hasBias = inputs.length > 2; + const processBias = hasBias ? 
'value += b[output_channel];' : ''; + const xShape = inputs[0].dims; + const wShape = inputs[1].dims; + const outputChannelsPerGroup = wShape[0] / attributes.group; + + const isChannelLast = attributes.format === 'NHWC'; + const outputShape = calculateOutputShape( + xShape, + wShape, + attributes.dilations, + attributes.pads, + attributes.strides, + isChannelLast, + ); + const outputSize = ShapeUtil.size(outputShape); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: attributes.dilations }, + { type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]] }, + { type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]] }, + { type: DataType.uint32, data: outputChannelsPerGroup }, + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(xShape, wShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); + const x = inputVariable('x', inputs[0].dataType, xShape.length); + const w = inputVariable('w', inputs[1].dataType, wShape.length); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims.length)); + } - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, - {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, - {name: 'output_channels_per_group', type: 'u32'} - ]; - appendActivationUniforms(attributes, uniforms); - return ` + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'dilations', type: 'u32', length: attributes.dilations.length }, + { name: 'strides', type: 'u32', length: 2 }, + { name: 'pads', type: 'u32', length: 2 }, + { name: 'output_channels_per_group', type: 'u32' }, + ]; + appendActivationUniforms(attributes, uniforms); + return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} @@ -70,7 +89,8 @@ export const createGroupedConvProgramInfo = let batch: u32 = outputIndices[0]; let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ - isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; + isChannelLast ? 2 : 3 + }]) * uniforms.strides - uniforms.pads; let group_id: u32 = output_channel / uniforms.output_channels_per_group; var value: ${output.type.value} = ${output.type.value}(0); @@ -90,8 +110,10 @@ export const createGroupedConvProgramInfo = } let xVal = ${ - isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : - x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + isChannelLast + ? 
x.get('batch', 'xHeight', 'xWidth', 'input_channel') + : x.get('batch', 'input_channel', 'xHeight', 'xWidth') + }; let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; value += xVal*wVal; } @@ -101,58 +123,63 @@ export const createGroupedConvProgramInfo = ${applyActivation} ${output.setByOffset('global_idx', 'value')} }`; - }; - return { - name: 'GroupedConv', - shaderCache: {hint: attributes.cacheKey, inputDependencies}, - getRunData: () => ({ - outputs: [{ - dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, - dataType: inputs[0].dataType - }], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource, - }; - }; - -export const createGroupedConvVectorizeProgramInfo = - (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[]): ProgramInfo => { - const hasBias = inputs.length > 2; - const components = getMaxComponents(outputShape[3]); - const outputNumber = getMaxComponents(outputShape[2]); - const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; - const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components]; - const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components]; - const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; - - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, - {type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]}, - {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} - ]; - appendActivationUniformsData(attributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader)); - const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; - const getShaderSource = (shaderHelper: ShaderHelper) => { - const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const baseType = tensorTypeToWsglStorageType(output.type.tensor); - const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); - const x = inputVariable('x', inputs[0].dataType, xShape.length, components); - const w = inputVariable('w', inputs[1].dataType, wShape.length, components); - const inputVars = [x, w]; - if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); - } - const processBias = hasBias ? 'value += b[output_channel];' : ''; - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, - {name: 'strides', type: 'i32', length: 2}, - {name: 'pads', type: 'i32', length: 2}, - ]; - appendActivationUniforms(attributes, uniforms); - return ` + }; + return { + name: 'GroupedConv', + shaderCache: { hint: attributes.cacheKey, inputDependencies }, + getRunData: () => ({ + outputs: [ + { + dims: squeezeOutputShapeFunction ? 
squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; + +export const createGroupedConvVectorizeProgramInfo = ( + inputs: readonly TensorView[], + attributes: ConvAttributes, + outputShape: readonly number[], +): ProgramInfo => { + const hasBias = inputs.length > 2; + const components = getMaxComponents(outputShape[3]); + const outputNumber = getMaxComponents(outputShape[2]); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components]; + const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components]; + const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]] }, + { type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]] }, + ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader)); + const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); + const x = inputVariable('x', inputs[0].dataType, xShape.length, components); + const w = inputVariable('w', inputs[1].dataType, wShape.length, components); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); + } + const processBias = hasBias ? 'value += b[output_channel];' : ''; + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'strides', type: 'i32', length: 2 }, + { name: 'pads', type: 'i32', length: 2 }, + ]; + appendActivationUniforms(attributes, uniforms); + return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} @@ -198,19 +225,19 @@ export const createGroupedConvVectorizeProgramInfo = ${output.set('batch', 'row', 'col + i', 'output_channel', 'value')}; } }`; - }; - - return { - name: 'GroupedConv-Vectorize', - shaderCache: { - hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, - inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] - }, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource, - }; - }; + }; + + return { + name: 'GroupedConv-Vectorize', + shaderCache: { + hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + inputDependencies: hasBias ? 
['rank', 'rank', 'type'] : ['rank', 'rank'], + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 41bd1d5326dc1..ece2e1b7c7dcd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,18 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor-view'; -import {ComputeContext} from '../types'; - -import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; -import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; -import {ConvAttributes} from './conv'; -import {parseInternalActivationAttributes} from './fuse-utils'; -import {createTransposeProgramInfo} from './transpose'; - -const computeTotalPad = - (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => - (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize; +import { TensorView } from '../../tensor-view'; +import { ComputeContext } from '../types'; + +import { createConv2DTransposeMatMulProgramInfo } from './3rd-party/conv_backprop_mm_webgpu'; +import { createConvTranspose2DProgramInfo } from './3rd-party/conv_backprop_webgpu'; +import { ConvAttributes } from './conv'; +import { parseInternalActivationAttributes } from './fuse-utils'; +import { createTransposeProgramInfo } from './transpose'; + +const computeTotalPad = ( + inDim: number, + stride: number, + adj: number, + kernel: number, + dilation: number, + outSize: number, +) => (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize; const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => { const smallPad = Math.floor(totalPad / 2); @@ -25,86 +30,110 @@ const distributePadding = (totalPad: number, autoPad: string, pads: number[], he } }; -const calculateOutputShapeAndPads = - (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], autoPad: string, - group: number, pads: number[], strides: readonly number[], isChannelLast: boolean, outputPadding: number[], - outputShape: number[]) => { - const spatialRank = inputShape.length - 2; - const updateOutputShape = outputShape.length === 0; - if (outputPadding.length === 0) { - for (let i = 0; i < spatialRank; ++i) { - outputPadding.push(0); - } - } - const batchSize = inputShape[0]; - const outChannels = kernelShape[isChannelLast ? 3 : 1] * group; - for (let i = 0, j = inputShape.length - spatialRank - (isChannelLast ? 1 : 0); i < spatialRank; ++i, ++j) { - const inSize = inputShape[j]; - const outSize = updateOutputShape ? inSize * strides[i] : outputShape[i]; - const totalPad = computeTotalPad(inSize, strides[i], pads[i], kernelShape[j], dilations[i], outSize); - distributePadding(totalPad, autoPad, pads, i, i + spatialRank); - if (updateOutputShape) { - outputShape.push( - strides[i] * (inSize - 1) + outputPadding[i] + (kernelShape[j] - 1) * dilations[i] + 1 - pads[i] - - pads[i + spatialRank]); - } - } - outputShape.splice(0, 0, batchSize); - outputShape.splice(isChannelLast ? 
3 : 1, 0, outChannels); - }; +const calculateOutputShapeAndPads = ( + inputShape: readonly number[], + kernelShape: readonly number[], + dilations: readonly number[], + autoPad: string, + group: number, + pads: number[], + strides: readonly number[], + isChannelLast: boolean, + outputPadding: number[], + outputShape: number[], +) => { + const spatialRank = inputShape.length - 2; + const updateOutputShape = outputShape.length === 0; + if (outputPadding.length === 0) { + for (let i = 0; i < spatialRank; ++i) { + outputPadding.push(0); + } + } + const batchSize = inputShape[0]; + const outChannels = kernelShape[isChannelLast ? 3 : 1] * group; + for (let i = 0, j = inputShape.length - spatialRank - (isChannelLast ? 1 : 0); i < spatialRank; ++i, ++j) { + const inSize = inputShape[j]; + const outSize = updateOutputShape ? inSize * strides[i] : outputShape[i]; + const totalPad = computeTotalPad(inSize, strides[i], pads[i], kernelShape[j], dilations[i], outSize); + distributePadding(totalPad, autoPad, pads, i, i + spatialRank); + if (updateOutputShape) { + outputShape.push( + strides[i] * (inSize - 1) + + outputPadding[i] + + (kernelShape[j] - 1) * dilations[i] + + 1 - + pads[i] - + pads[i + spatialRank], + ); + } + } + outputShape.splice(0, 0, batchSize); + outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels); +}; export interface ConvTransposeAttributes extends ConvAttributes { readonly outputPadding: readonly number[]; readonly outputShape: readonly number[]; } -const getAdjustedConvTransposeAttributes = - (attributes: T, inputs: readonly TensorView[]): T => { - const kernelShape = attributes.kernelShape.slice(); - // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { - kernelShape.length = 0; - for (let i = 2; i < inputs[1].dims.length; ++i) { - kernelShape.push(inputs[1].dims[i]); - } - } - const isChannelsLast = attributes.format === 'NHWC'; - kernelShape.splice(0, 0, inputs[1].dims[0]); - kernelShape.splice(isChannelsLast ? 
3 : 1, 0, inputs[1].dims[1]); - - const pads = attributes.pads.slice(); - const outputShape = attributes.outputShape.slice(); - const outputPadding = attributes.outputPadding.slice(); - const inputShape = inputs[0].dims; - let dilations = attributes.dilations.slice(); - if (dilations.reduce((a, b) => a + b, 0) === 0) { - const spatialRank = inputs[0].dims.length - 2; - dilations = new Array(spatialRank).fill(1); - } - let strides = attributes.strides.slice(); - if (strides.reduce((a, b) => a + b, 0) === 0) { - const spatialRank = inputs[0].dims.length - 2; - strides = new Array(spatialRank).fill(1); - } - // If outputShape is not specified in the attributes of this op, infer it from the parameters - // Similarly, automatically infer pads if not specified - calculateOutputShapeAndPads( - inputShape, kernelShape, dilations, attributes.autoPad, attributes.group, pads, strides, isChannelsLast, - outputPadding, outputShape); - - // always return a new object so does not modify the original attributes - const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); - return newAttributes; - }; +const getAdjustedConvTransposeAttributes = ( + attributes: T, + inputs: readonly TensorView[], +): T => { + const kernelShape = attributes.kernelShape.slice(); + // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { + kernelShape.length = 0; + for (let i = 2; i < inputs[1].dims.length; ++i) { + kernelShape.push(inputs[1].dims[i]); + } + } + const isChannelsLast = attributes.format === 'NHWC'; + kernelShape.splice(0, 0, inputs[1].dims[0]); + kernelShape.splice(isChannelsLast ? 3 : 1, 0, inputs[1].dims[1]); + + const pads = attributes.pads.slice(); + const outputShape = attributes.outputShape.slice(); + const outputPadding = attributes.outputPadding.slice(); + const inputShape = inputs[0].dims; + let dilations = attributes.dilations.slice(); + if (dilations.reduce((a, b) => a + b, 0) === 0) { + const spatialRank = inputs[0].dims.length - 2; + dilations = new Array(spatialRank).fill(1); + } + let strides = attributes.strides.slice(); + if (strides.reduce((a, b) => a + b, 0) === 0) { + const spatialRank = inputs[0].dims.length - 2; + strides = new Array(spatialRank).fill(1); + } + // If outputShape is not specified in the attributes of this op, infer it from the parameters + // Similarly, automatically infer pads if not specified + calculateOutputShapeAndPads( + inputShape, + kernelShape, + dilations, + attributes.autoPad, + attributes.group, + pads, + strides, + isChannelsLast, + outputPadding, + outputShape, + ); + + // always return a new object so does not modify the original attributes + const newAttributes: T = Object.assign({}, attributes); + Object.assign(newAttributes, { kernelShape, pads, outputPadding, outputShape, dilations, strides }); + return newAttributes; +}; export const parseConvTransposeAttributes = (attributes: Record): ConvTransposeAttributes => { const activationAttributes = parseInternalActivationAttributes(attributes); // TODO : Make this generic enough to compute default attributes for multi-dimensional conv const format = attributes.format as 'NHWC' | 'NCHW'; - const autoPad = - ['NOTSET', 'VALID', 'SAME_UPPER', - 'SAME_LOWER'][typeof attributes.autoPad == 'undefined' ? 
0 : attributes.autoPad as number]; + const autoPad = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][ + typeof attributes.autoPad == 'undefined' ? 0 : (attributes.autoPad as number) + ]; const dilations = attributes.dilations as [number, number]; const group = attributes.group as number; const kernelShape = attributes.kernelShape as [number, number]; @@ -125,7 +154,7 @@ export const parseConvTransposeAttributes = (attributes: Record strides, wIsConst, ...activationAttributes, - cacheKey: `${attributes.format};${activationAttributes.activation};` + cacheKey: `${attributes.format};${activationAttributes.activation};`, }; }; @@ -186,8 +215,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose // if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor // (the first 2 dims are batch_size and channels) const kernelShapeSet = attributes.kernelShape.reduce((a, b) => a + b, 0) > 0; - if (kernelShapeSet && attributes.kernelShape.length !== 0 && - attributes.kernelShape.length !== inputs[1].dims.length - 2) { + if ( + kernelShapeSet && + attributes.kernelShape.length !== 0 && + attributes.kernelShape.length !== inputs[1].dims.length - 2 + ) { throw new Error('invalid kernel shape'); } @@ -200,59 +232,71 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose // for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] const weightTransposePerm = [2, 3, 1, 0]; -const convTranspose2d = - (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { - const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); - const isChannelsLast = attributes.format === 'NHWC'; - const outputShape = adjustedAttributes.outputShape; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's - // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit - // utilization rate is very low. - if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { - context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); - return; - } - const outHeight = outputShape[isChannelsLast ? 1 : 2]; - const outWidth = outputShape[isChannelsLast ? 2 : 3]; - const weightHeight = inputs[1].dims[2]; - const weightWidth = inputs[1].dims[3]; - - const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; - const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; - const dimInner = weightHeight * weightWidth * inputChannels; - - const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; - - - // STEP.1: transpose weight - const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute( - createTransposeProgramInfo(inputs[1], weightTransposePerm), - {inputs: [1], outputs: [attributes.wIsConst ? 
-2 : -1]})[0]; - if (attributes.wIsConst && !context.kernelCustomData.wT) { - context.kernelCustomData.wT = transposedWeight; - } - - // STEP.2: prepare reshaped inputs - const convTransposeInputs = [inputs[0], transposedWeight]; - const hasBias = inputs.length === 3; - if (hasBias) { - if (!isChannelsLast && inputs[2].dims.length === 1) { - convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); - } else { - convTransposeInputs.push(inputs[2]); - } - } - - // STEP.3: compute matmul - context.compute( - createConv2DTransposeMatMulProgramInfo( - convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, - sequentialAccessByThreads), - {inputs: convTransposeInputs}); - }; +const convTranspose2d = ( + context: ComputeContext, + inputs: readonly TensorView[], + attributes: ConvTransposeAttributes, +): void => { + const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + const isChannelsLast = attributes.format === 'NHWC'; + const outputShape = adjustedAttributes.outputShape; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's + // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit + // utilization rate is very low. + if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { + context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); + return; + } + const outHeight = outputShape[isChannelsLast ? 1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; + + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + + // STEP.1: transpose weight + const transposedWeight = + (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute(createTransposeProgramInfo(inputs[1], weightTransposePerm), { + inputs: [1], + outputs: [attributes.wIsConst ? -2 : -1], + })[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convTransposeInputs = [inputs[0], transposedWeight]; + const hasBias = inputs.length === 3; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convTransposeInputs.push(inputs[2]); + } + } + + // STEP.3: compute matmul + context.compute( + createConv2DTransposeMatMulProgramInfo( + convTransposeInputs, + adjustedAttributes, + outputShape, + dimAOuter, + dimBOuter, + dimInner, + hasBias, + sequentialAccessByThreads, + ), + { inputs: convTransposeInputs }, + ); +}; const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension @@ -260,13 +304,14 @@ const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttri const inputs = [ context.inputs[0].reshape( - isChannelLast ? 
- // [N, W, C] -> [N, H=1, W, C] - [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] : - // [N, C, W] -> [N, C, H=1, W] - [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]]), + isChannelLast + ? // [N, W, C] -> [N, H=1, W, C] + [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] + : // [N, C, W] -> [N, C, H=1, W] + [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]], + ), //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW] - context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]) + context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]), ]; if (context.inputs.length === 3) { inputs.push(context.inputs[2]); @@ -291,12 +336,17 @@ const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttri strides = [1].concat(strides); dilations = [1].concat(dilations); kernelShape = [1].concat(kernelShape); - const adjustedAttributes = - getAdjustedConvTransposeAttributes({...attributes, pads, strides, dilations, kernelShape}, inputs); - context.compute(createConvTranspose2DProgramInfo( - inputs, adjustedAttributes, - outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : - [outputShape[0], outputShape[1], outputShape[3]])); + const adjustedAttributes = getAdjustedConvTransposeAttributes( + { ...attributes, pads, strides, dilations, kernelShape }, + inputs, + ); + context.compute( + createConvTranspose2DProgramInfo(inputs, adjustedAttributes, (outputShape) => + isChannelLast + ? [outputShape[0], outputShape[2], outputShape[3]] + : [outputShape[0], outputShape[1], outputShape[3]], + ), + ); }; export const convTranspose = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 52bd69130e617..f1469d4ce67be 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,40 +1,46 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor-view'; -import {PoolConvUtil} from '../../util'; -import {AttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext} from '../types'; - -import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; -import {computeConv3DInfo, createConv3DNaiveProgramInfo} from './3rd-party/conv3d_naive_webgpu'; -import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo} from './conv-grouped'; -import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; -import {createNaiveMatmulProgramInfo} from './matmul'; -import {createTransposeProgramInfo} from './transpose'; - -export const calculateOutputShape = - (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], - adjustPads: readonly number[], strides: readonly number[], isChannelLast: boolean): number[] => { - const batchSize = inputShape[0]; - const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 
3 : 4); - const spatialRank = inputSpatialShape.length; - const outChannels = kernelShape[0]; - const kernelSpatialShape = kernelShape.slice(2); - const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); - const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); - const outputShape = - inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i])); - outputShape.splice(0, 0, batchSize); - outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels); - return outputShape; - }; +import { TensorView } from '../../tensor-view'; +import { PoolConvUtil } from '../../util'; +import { AttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext } from '../types'; + +import { createConv2DMatMulProgramInfo } from './3rd-party/conv2d_mm_webgpu'; +import { computeConv3DInfo, createConv3DNaiveProgramInfo } from './3rd-party/conv3d_naive_webgpu'; +import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu'; +import { createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo } from './conv-grouped'; +import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils'; +import { createNaiveMatmulProgramInfo } from './matmul'; +import { createTransposeProgramInfo } from './transpose'; + +export const calculateOutputShape = ( + inputShape: readonly number[], + kernelShape: readonly number[], + dilations: readonly number[], + adjustPads: readonly number[], + strides: readonly number[], + isChannelLast: boolean, +): number[] => { + const batchSize = inputShape[0]; + const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 3 : 4); + const spatialRank = inputSpatialShape.length; + const outChannels = kernelShape[0]; + const kernelSpatialShape = kernelShape.slice(2); + const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); + const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); + const outputShape = inputSpatialShapeWithPad.map((v, i) => + Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]), + ); + outputShape.splice(0, 0, batchSize); + outputShape.splice(isChannelLast ? 
3 : 1, 0, outChannels); + return outputShape; +}; export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey { readonly autoPad: string; readonly dilations: readonly number[]; - readonly format: 'NHWC'|'NCHW'; + readonly format: 'NHWC' | 'NCHW'; readonly group: number; readonly kernelShape: readonly number[]; readonly pads: readonly number[]; @@ -105,12 +111,18 @@ const getAdjustedConvAttributes = (attributes: T, inpu } const pads = attributes.pads.slice(); PoolConvUtil.adjustPadsBasedOnAutoPad( - inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.format === 'NHWC', - attributes.autoPad); + inputs[0].dims, + attributes.strides, + attributes.dilations, + kernelShape, + pads, + attributes.format === 'NHWC', + attributes.autoPad, + ); // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads}); + Object.assign(newAttributes, { kernelShape, pads }); return newAttributes; }; @@ -136,7 +148,7 @@ export const parseConvAttributes = (attributes: Record): ConvAt strides, wIsConst, ...activationAttributes, - cacheKey: `${attributes.format};${activationAttributes.activation};` + cacheKey: `${attributes.format};${activationAttributes.activation};`, }; }; @@ -153,15 +165,28 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // [webgpu]Conv - conv - vectorize group - B // [webgpu]Conv - conv - vectorize group - D const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere'); - if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && - inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { + if ( + enableGroupedConvVectorize && + isChannelsLast && + inputs[1].dims[0] === attributes.group && + inputs[1].dims[1] === 1 && + attributes.dilations[0] === 1 && + attributes.dilations[1] === 1 + ) { const outputShape = calculateOutputShape( - inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, - isChannelsLast); - const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute( - createTransposeProgramInfo(inputs[1], weightTransposeAttribute), - {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + inputs[0].dims, + inputs[1].dims, + attributes.dilations, + adjustedAttributes.pads, + attributes.strides, + isChannelsLast, + ); + const transposedWeight = + (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), { + inputs: [1], + outputs: [attributes.wIsConst ? 
-2 : -1], + })[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; } @@ -169,8 +194,9 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (inputs.length === 3) { convInputs.push(inputs[2]); } - context.compute( - createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {inputs: convInputs}); + context.compute(createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), { + inputs: convInputs, + }); } else { context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); } @@ -185,27 +211,45 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut const weightWidth = inputs[1].dims[3]; const outputShape = calculateOutputShape( - inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, - isChannelsLast); + inputs[0].dims, + inputs[1].dims, + attributes.dilations, + adjustedAttributes.pads, + attributes.strides, + isChannelsLast, + ); const outHeight = outputShape[isChannelsLast ? 1 : 2]; const outWidth = outputShape[isChannelsLast ? 2 : 3]; const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const sameSize = isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && - attributes.pads[0] === 0 && attributes.pads[1] === 0; - if (sameSize || - (weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 && - attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 && - attributes.pads[1] === 0)) { + const sameSize = + isChannelsLast && + weightHeight === inputHeight && + weightWidth === inputWidth && + attributes.pads[0] === 0 && + attributes.pads[1] === 0; + if ( + sameSize || + (weightHeight === 1 && + weightWidth === 1 && + attributes.dilations[0] === 1 && + attributes.dilations[1] === 1 && + attributes.strides[0] === 1 && + attributes.strides[1] === 1 && + attributes.pads[0] === 0 && + attributes.pads[1] === 0) + ) { // conv2dByMatMul const batch = outputShape[0]; let xReshaped, wReshaped, matmulOutputShape; const matmulInputs = []; if (isChannelsLast) { - const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute( - createTransposeProgramInfo(inputs[1], weightTransposeAttribute), - {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + const transposedWeight = + (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), { + inputs: [1], + outputs: [attributes.wIsConst ? -2 : -1], + })[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; } @@ -236,13 +280,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // Tune the threshold. 
if (N < 8 && K < 8) { context.compute( - createNaiveMatmulProgramInfo( - matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), - {inputs: matmulInputs}); + createNaiveMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + { inputs: matmulInputs }, + ); } else { context.compute( - createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), - {inputs: matmulInputs}); + createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + { inputs: matmulInputs }, + ); } return; } @@ -252,10 +297,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; // STEP.1: transpose weight - const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute( - createTransposeProgramInfo(inputs[1], weightTransposeAttribute), - {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + const transposedWeight = + (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), { + inputs: [1], + outputs: [attributes.wIsConst ? -2 : -1], + })[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; } @@ -271,10 +318,18 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; const dimInner = weightHeight * weightWidth * inputChannels; context.compute( - createConv2DMatMulProgramInfo( - convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, - sequentialAccessByThreads), - {inputs: convInputs}); + createConv2DMatMulProgramInfo( + convInputs, + adjustedAttributes, + outputShape, + dimAOuter, + dimBOuter, + dimInner, + hasBias, + sequentialAccessByThreads, + ), + { inputs: convInputs }, + ); }; const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { @@ -282,13 +337,14 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { const isChannelLast = attributes.format === 'NHWC'; const inputs = [ context.inputs[0].reshape( - isChannelLast ? - // [N, W, C] -> [N, H=1, W, C] - [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] : - // [N, C, W] -> [N, C, H=1, W] - [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]]), + isChannelLast + ? 
// [N, W, C] -> [N, H=1, W, C] + [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] + : // [N, C, W] -> [N, C, H=1, W] + [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]], + ), //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW] - context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]) + context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]), ]; if (context.inputs.length === 3) { inputs.push(context.inputs[2]); @@ -297,10 +353,15 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { const strides = [1].concat(attributes.strides); const dilations = [1].concat(attributes.dilations); const kernelShape = [1].concat(attributes.kernelShape); - const adjustedAttributes = getAdjustedConvAttributes({...attributes, pads, strides, dilations, kernelShape}, inputs); - context.compute(createGroupedConvProgramInfo( - inputs, adjustedAttributes, - outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [])); + const adjustedAttributes = getAdjustedConvAttributes( + { ...attributes, pads, strides, dilations, kernelShape }, + inputs, + ); + context.compute( + createGroupedConvProgramInfo(inputs, adjustedAttributes, (outputShape) => + isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [], + ), + ); }; const conv3d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { @@ -308,14 +369,24 @@ const conv3d = (context: ComputeContext, inputs: readonly TensorView[], attribut const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs); const pads = attributes.autoPad === 'NOTSET' ? attributes.pads : attributes.autoPad; const convInfo = computeConv3DInfo( - inputs[0].dims as [number, number, number, number, number], - inputs[1].dims as [number, number, number, number, number], - attributes.strides as number | [number, number, number], - attributes.dilations as number | [number, number, number], pads as string | number[], false, format); - context.compute(createConv3DNaiveProgramInfo( - inputs, adjustedAttributes, convInfo.outShape, + inputs[0].dims as [number, number, number, number, number], + inputs[1].dims as [number, number, number, number, number], + attributes.strides as number | [number, number, number], + attributes.dilations as number | [number, number, number], + pads as string | number[], + false, + format, + ); + context.compute( + createConv3DNaiveProgramInfo( + inputs, + adjustedAttributes, + convInfo.outShape, [convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth], - [convInfo.padInfo.front, convInfo.padInfo.top, convInfo.padInfo.left], format)); + [convInfo.padInfo.front, convInfo.padInfo.top, convInfo.padInfo.left], + format, + ), + ); }; export const conv = (context: ComputeContext, attributes: ConvAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index b8b50b35653a2..b8a7336f77cb6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -1,39 +1,41 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; - -import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo } from '../types'; +import { createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper } from './common'; export interface CumSumAttributes extends AttributeWithCacheKey { readonly exclusive: boolean; readonly reverse: boolean; } -const createCumsumProgramInfo = - (inputType: number, inputShape: readonly number[], axisInput: TensorView, attributes: CumSumAttributes): - ProgramInfo => { - const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape. - const rank = inputShape.length; // input/output rank - const input = inputVariable('input', inputType, rank); - const output = outputVariable('output', inputType, rank); - const axisValue = axisInput.dataType === DataType.int32 ? axisInput.getInt32Array()[0] : - Number(axisInput.getBigInt64Array()[0]); - const axis = ShapeUtil.normalizeAxis(axisValue, rank); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `; - const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank); - const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0'; - const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1'); - return ` - ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('axis', 'u32') - .declareVariables(input, output)} +const createCumsumProgramInfo = ( + inputType: number, + inputShape: readonly number[], + axisInput: TensorView, + attributes: CumSumAttributes, +): ProgramInfo => { + const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape. + const rank = inputShape.length; // input/output rank + const input = inputVariable('input', inputType, rank); + const output = outputVariable('output', inputType, rank); + const axisValue = + axisInput.dataType === DataType.int32 ? axisInput.getInt32Array()[0] : Number(axisInput.getBigInt64Array()[0]); + const axis = ShapeUtil.normalizeAxis(axisValue, rank); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `; + const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank); + const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0'; + const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? 
'' : ' + 1');
+    return `
+        ${shaderHelper
+          .registerUniform('outputSize', 'u32')
+          .registerUniform('axis', 'u32')
+          .declareVariables(input, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
    var inputIndices = ${output.offsetToIndices('global_idx')};
@@ -46,33 +48,32 @@ const createCumsumProgramInfo =
              }
              ${output.setByOffset('global_idx', 'sum')};
            }`;
-          };
-          return {
-            name: 'CumSum',
-            shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']},
-            getRunData: () => ({
-              outputs: [{dims: inputShape, dataType: inputType}],
-              dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
-              programUniforms: [
-                {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: axis},
-                ...createTensorShapeVariables(inputShape, inputShape)
-              ]
-
-            }),
-            getShaderSource
-          };
-        };
-
+  };
+  return {
+    name: 'CumSum',
+    shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] },
+    getRunData: () => ({
+      outputs: [{ dims: inputShape, dataType: inputType }],
+      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+      programUniforms: [
+        { type: DataType.uint32, data: outputSize },
+        { type: DataType.uint32, data: axis },
+        ...createTensorShapeVariables(inputShape, inputShape),
+      ],
+    }),
+    getShaderSource,
+  };
+};

 export const cumsum = (context: ComputeContext, attributes: CumSumAttributes): void => {
   const inputShape = context.inputs[0].dims;
   const inputType = context.inputs[0].dataType;
   const axis = context.inputs[1];
-  context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), {inputs: [0]});
+  context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), { inputs: [0] });
 };

 export const parseCumSumAttributes = (attributes: Record<string, unknown>): CumSumAttributes => {
-  const exclusive = attributes.exclusive as number === 1;
-  const reverse = attributes.reverse as number === 1;
-  return createAttributeWithCacheKey({exclusive, reverse});
+  const exclusive = (attributes.exclusive as number) === 1;
+  const reverse = (attributes.reverse as number) === 1;
+  return createAttributeWithCacheKey({ exclusive, reverse });
 };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/depth-to-space.ts b/js/web/lib/wasm/jsep/webgpu/ops/depth-to-space.ts
index 83809b3d5de6c..52ce8fc11e094 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/depth-to-space.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/depth-to-space.ts
@@ -1,16 +1,16 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo } from '../types'; -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common'; export interface FormatAttributes { - readonly format: 'NHWC'|'NCHW'; + readonly format: 'NHWC' | 'NCHW'; } export interface DepthToSpaceAttributes extends FormatAttributes, AttributeWithCacheKey { @@ -47,13 +47,15 @@ const createDepthToSpaceProgramInfo = (inputTensor: TensorView, attributes: Dept const isDCRmode = attributes.mode === 'DCR'; if (isChannelLast) { [n, h, w, c] = inputTensor.dims; - shape = isDCRmode ? [n, h, w, blocksize, blocksize, c / (blocksize ** 2)] : - [n, h, w, c / (blocksize ** 2), blocksize, blocksize]; + shape = isDCRmode + ? [n, h, w, blocksize, blocksize, c / blocksize ** 2] + : [n, h, w, c / blocksize ** 2, blocksize, blocksize]; perm = isDCRmode ? [0, 1, 3, 2, 4, 5] : [0, 1, 4, 2, 5, 3]; } else { [n, h, w, c] = [inputTensor.dims[0], inputTensor.dims[2], inputTensor.dims[3], inputTensor.dims[1]]; - shape = isDCRmode ? [n, blocksize, blocksize, c / (blocksize ** 2), h, w] : - [n, c / (blocksize ** 2), blocksize, blocksize, h, w]; + shape = isDCRmode + ? [n, blocksize, blocksize, c / blocksize ** 2, h, w] + : [n, c / blocksize ** 2, blocksize, blocksize, h, w]; perm = isDCRmode ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3]; } const reshapedInputTensor = inputTensor.reshape(shape); @@ -79,18 +81,24 @@ const createDepthToSpaceProgramInfo = (inputTensor: TensorView, attributes: Dept return { name: 'DepthToSpace', - shaderCache: {hint: `${inputTensor.dims};${attributes.blocksize};${attributes.mode}`, inputDependencies: ['rank']}, + shaderCache: { + hint: `${inputTensor.dims};${attributes.blocksize};${attributes.mode}`, + inputDependencies: ['rank'], + }, getRunData: (inputs) => { - const outputShape = isChannelLast ? [n, h * blocksize, w * blocksize, c / (blocksize ** 2)] : - [n, c / (blocksize ** 2), h * blocksize, w * blocksize]; + const outputShape = isChannelLast + ? 
[n, h * blocksize, w * blocksize, c / blocksize ** 2]
+          : [n, c / blocksize ** 2, h * blocksize, w * blocksize];
       const outputSize = ShapeUtil.size(outputShape);
       const shapeBeforePerm = reshapedInputTensor.dims;
       const shapeAfterPerm = ShapeUtil.sortBasedOnPerm(shapeBeforePerm, perm);
       return {
-        outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
-        dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
-        programUniforms:
-            [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(shapeBeforePerm, shapeAfterPerm)],
+        outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+        dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+        programUniforms: [
+          { type: DataType.uint32, data: outputSize },
+          ...createTensorShapeVariables(shapeBeforePerm, shapeAfterPerm),
+        ],
       };
     },
     getShaderSource,
@@ -103,8 +111,8 @@ export const depthToSpace = (context: ComputeContext, attributes: DepthToSpaceAt
 };

 export const parseDepthToSpaceAttributes = (attributes: Record<string, unknown>): DepthToSpaceAttributes =>
-    createAttributeWithCacheKey({
-      blocksize: attributes.blocksize as number,
-      mode: attributes.mode as string,
-      format: attributes.format as 'NHWC' | 'NCHW'
-    });
+  createAttributeWithCacheKey({
+    blocksize: attributes.blocksize as number,
+    mode: attributes.mode as string,
+    format: attributes.format as 'NHWC' | 'NCHW',
+  });
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
index 19a009c2eb79b..48da675193ad8 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
@@ -1,13 +1,13 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-import {DataType} from '../../../wasm-common';
-import {TensorView} from '../../tensor-view';
-import {ShapeUtil} from '../../util';
-import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
-import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
+import { DataType } from '../../../wasm-common';
+import { TensorView } from '../../tensor-view';
+import { ShapeUtil } from '../../util';
+import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
+import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
+import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';

 export interface EinsumAttributes extends AttributeWithCacheKey {
   readonly equation: string;
@@ -20,17 +20,16 @@ export interface EinsumAttributes extends AttributeWithCacheKey {
 // Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
 // 'Z' or '...' to represent arbitrary dimensions.
-const symbolPattern =
-    '[a-zA-Z]|\\.\\.\\.';  // The pattern each symbol in each term in the symbolic equation should match
-const termPattern = '(' + symbolPattern + ')+';  // The pattern each term in the symbolic equation should match
-const termPatternOnly = '^' + termPattern + '$';  // The patterns only matchs a term begin to end.
-const lhsPattern = '(' + termPattern + ',)*' + termPattern;  // The pattern the LHS should match
-const lhsPatternOnly = '^' + lhsPattern + '$';  // The patterns only matchs a LHS begin to end.
+const symbolPattern = '[a-zA-Z]|\\.\\.\\.'; // The pattern each symbol in each term in the symbolic equation should match +const termPattern = '(' + symbolPattern + ')+'; // The pattern each term in the symbolic equation should match +const termPatternOnly = '^' + termPattern + '$'; // The patterns only matchs a term begin to end. +const lhsPattern = '(' + termPattern + ',)*' + termPattern; // The pattern the LHS should match +const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matchs a LHS begin to end. interface SymbolInfo { - count: number; // Symbol corresponding to a dimmension of an input - inputIndices: number[]; // Number of input variables the symbol corresponds to - dimValue: number; // Number of dimensions the symbol corresponds to + count: number; // Symbol corresponding to a dimmension of an input + inputIndices: number[]; // Number of input variables the symbol corresponds to + dimValue: number; // Number of dimensions the symbol corresponds to } class EinsumTerm { @@ -50,12 +49,15 @@ class EinsumTerm { this.symbolToIndices.set(symbol, value); } - symbolToIndices: Map; // Map from symbol to dimensions of the input corresponding to the term - inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs + symbolToIndices: Map; // Map from symbol to dimensions of the input corresponding to the term + inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs } class EinsumEquation { - constructor(inputs: readonly TensorView[], public readonly equation: string) { + constructor( + inputs: readonly TensorView[], + public readonly equation: string, + ) { this.hasEllipsis = false; this.symbolToInfo = new Map(); this.lhs = new Array(); @@ -80,9 +82,9 @@ class EinsumEquation { if (rhs === '') { // Construct RHS from LHS terms/symbols rhs += [...this.symbolToInfo.entries()] - .filter(([sym, info]) => (info.count === 1 || sym === '...')) - .map(([sym]) => sym) - .join(''); + .filter(([sym, info]) => info.count === 1 || sym === '...') + .map(([sym]) => sym) + .join(''); } else { if (!rhs.match(RegExp(termPattern))) { throw new Error('Invalid RHS'); @@ -103,7 +105,7 @@ class EinsumEquation { } }); this.rhs = this.processTerm(rhs, false, this.outputDims); - } // End of EinsumEqation constructor + } // End of EinsumEqation constructor // Add a symbol to the equation addSymbol(symbol: string, dimValue: number, inputIndex: number) { @@ -116,7 +118,7 @@ class EinsumEquation { info.inputIndices.push(inputIndex); } } else { - info = {count: 1, dimValue, inputIndices: [inputIndex]}; + info = { count: 1, dimValue, inputIndices: [inputIndex] }; } this.symbolToInfo.set(symbol, info); } @@ -128,7 +130,7 @@ class EinsumEquation { let ellipsisDims = []; let nextDim = 0; // For output empty string is allowed because the output may be reduced to a scalar value - if (!term.match(RegExp(termPatternOnly)) && (!isInput && term !== '')) { + if (!term.match(RegExp(termPatternOnly)) && !isInput && term !== '') { throw new Error('Invalid LHS term'); } const indexSymbols = term.match(RegExp(symbolPattern, 'g')); @@ -146,8 +148,10 @@ class EinsumEquation { } ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength); if (this.hasEllipsis) { - if (this.ellipsisDims.length !== ellipsisDims.length || - this.ellipsisDims.toString() !== ellipsisDims.toString()) { + if ( + this.ellipsisDims.length !== ellipsisDims.length || + this.ellipsisDims.toString() !== ellipsisDims.toString() + ) { throw new Error('Ellipsis dimensions mismatch'); } } else if (isInput) { @@ -170,92 +174,100 @@ class 
EinsumEquation { return einsumTerm; } - symbolToInfo: Map; // All symbols in the equation - hasEllipsis: boolean; // The equation has ellipsis or not - ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to. - lhs: EinsumTerm[]; // Terms on the left-hand side of the equation - rhs: EinsumTerm; // Term on the right-hand side of the equation - outputDims: number[]; // Output dimensions of the equation -} // End of class EinsumEquation + symbolToInfo: Map; // All symbols in the equation + hasEllipsis: boolean; // The equation has ellipsis or not + ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to. + lhs: EinsumTerm[]; // Terms on the left-hand side of the equation + rhs: EinsumTerm; // Term on the right-hand side of the equation + outputDims: number[]; // Output dimensions of the equation +} // End of class EinsumEquation const appendMax = (name: string): string => name + '_max'; -const createEinsumProgramInfo = - (inputShapes: Array, dataType: number, einsumEquation: EinsumEquation, - outputShape: readonly number[]): ProgramInfo => { - const ranks = inputShapes.map((dims) => dims.length); - const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank)); - const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', dataType, outputShape.length); - const uniformsSymbols = - [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const idxCopy: string[] = []; - const initProd = 'var prod = 1.0;'; - const initSum = 'var sum = 0.0;'; - const updateSum = 'sum += prod;'; - const reduceOpsSetIndices: string[] = []; - const reduceOpsLoopHeaders: string[] = []; - const reduceOpsLoopFooters: string[] = []; - const reduceOpCompute: string[] = []; - const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size; - einsumEquation.symbolToInfo.forEach((info, symbol) => { - if (einsumEquation.rhs.symbolToIndices.has(symbol)) { - const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0]; - if (outputIndex !== undefined) { - einsumEquation.lhs.forEach((term, i) => { - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); - } - indices.forEach((index) => { - idxCopy.push(`${ - inputVars[i].indicesSet( - `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); - }); - } +const createEinsumProgramInfo = ( + inputShapes: Array, + dataType: number, + einsumEquation: EinsumEquation, + outputShape: readonly number[], +): ProgramInfo => { + const ranks = inputShapes.map((dims) => dims.length); + const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank)); + const outputSize = ShapeUtil.size(outputShape); + const output = outputVariable('output', dataType, outputShape.length); + const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter( + (symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol), + ); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const idxCopy: string[] = []; + const initProd = 'var prod = 1.0;'; + const initSum = 'var sum = 0.0;'; + const updateSum = 'sum += prod;'; + const reduceOpsSetIndices: string[] = []; + const reduceOpsLoopHeaders: string[] = []; + const reduceOpsLoopFooters: string[] = []; + const 
reduceOpCompute: string[] = []; + const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size; + einsumEquation.symbolToInfo.forEach((info, symbol) => { + if (einsumEquation.rhs.symbolToIndices.has(symbol)) { + const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0]; + if (outputIndex !== undefined) { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + idxCopy.push( + `${inputVars[i].indicesSet( + `input${i}Indices`, + index, + output.indicesGet('outputIndices', outputIndex), + )}`, + ); }); } - } else { - einsumEquation.lhs.forEach((term, i) => { - if (info.inputIndices.includes(i)) { - const indices = term.symbolToIndices.get(symbol); - if (indices === undefined) { - throw new Error('Invalid symbol error'); - } - indices.forEach((index) => { - reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); - }); - reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); - } + }); + } + } else { + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); }); - reduceOpsLoopHeaders.push( - `for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`); - reduceOpsLoopFooters.push('}'); + reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); } }); - const reduceOps = isReduceOpsWithoutLoop ? - [ - ...idxCopy, - `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` - ] : - [ - ...idxCopy, - initSum, - ...reduceOpsLoopHeaders, - ...reduceOpsSetIndices, - initProd, - ...reduceOpCompute, - updateSum, - ...reduceOpsLoopFooters, - ]; - return ` - ${ - shaderHelper - .registerUniforms(uniformsSymbols.map((symbol) => ({name: `${appendMax(symbol)}`, type: 'u32'}))) - .registerUniform('outputSize', 'u32') - .declareVariables(...inputVars, output)} + reduceOpsLoopHeaders.push( + `for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`, + ); + reduceOpsLoopFooters.push('}'); + } + }); + const reduceOps = isReduceOpsWithoutLoop + ? 
[ + ...idxCopy, + `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`, + ] + : [ + ...idxCopy, + initSum, + ...reduceOpsLoopHeaders, + ...reduceOpsSetIndices, + initProd, + ...reduceOpCompute, + updateSum, + ...reduceOpsLoopFooters, + ]; + return ` + ${shaderHelper + .registerUniforms(uniformsSymbols.map((symbol) => ({ name: `${appendMax(symbol)}`, type: 'u32' }))) + .registerUniform('outputSize', 'u32') + .declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} @@ -264,32 +276,30 @@ const createEinsumProgramInfo = ${reduceOps.join('\n')}; ${output.setByOffset('global_idx', 'sum')}; }`; - }; + }; + return { + name: 'Einsum', + shaderCache: { hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank') }, + getRunData: () => { + // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The + // filter is added to make sure that dimValue is never 0. + const programUniformsInit: ProgramUniform[] = uniformsSymbols + .filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) + .map((symbol) => ({ type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0 })); + programUniformsInit.push({ type: DataType.uint32, data: outputSize }); + const programUniforms: ProgramUniform[] = inputShapes + .map((dims, _) => [...createTensorShapeVariables(dims)]) + .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); + programUniforms.push(...createTensorShapeVariables(outputShape)); return { - name: 'Einsum', - shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')}, - getRunData: () => { - // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The - // filter is added to make sure that dimValue is never 0. 
- const programUniformsInit: ProgramUniform[] = - uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) - .map( - (symbol) => - ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); - programUniformsInit.push({type: DataType.uint32, data: outputSize}); - const programUniforms: ProgramUniform[] = - inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) - .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); - programUniforms.push(...createTensorShapeVariables(outputShape)); - return ({ - outputs: [{dims: outputShape, dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }); - }, - getShaderSource, + outputs: [{ dims: outputShape, dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, }; - }; + }, + getShaderSource, + }; +}; export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); @@ -300,5 +310,5 @@ export const einsum = (context: ComputeContext, attributes: EinsumAttributes): v export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { const equation = (attributes.equation as string).replace(/\s+/g, ''); - return createAttributeWithCacheKey({equation}); + return createAttributeWithCacheKey({ equation }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 80ee906423e19..4e2bfa9d89924 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -18,8 +18,11 @@ const validateInputs = (inputs: readonly TensorView[]): void => { let shapeIndex = shape.length < inputShape.length ? 0 : shape.length - inputShape.length; let inputShapeIndex = inputShape.length < shape.length ? 0 : inputShape.length - shape.length; for (; shapeIndex < shape.length && inputShapeIndex < inputShape.length; ++shapeIndex, ++inputShapeIndex) { - if (shape[shapeIndex] !== inputShape[inputShapeIndex] && shape[shapeIndex] !== 1 && - inputShape[inputShapeIndex] !== 1) { + if ( + shape[shapeIndex] !== inputShape[inputShapeIndex] && + shape[shapeIndex] !== 1 && + inputShape[inputShapeIndex] !== 1 + ) { throw new Error('Expand requires shape to be broadcastable to input'); } } @@ -38,8 +41,7 @@ const getAdjustedShape = (shape1: readonly number[], shape2: readonly number[]): }; const calculateOutputShape = (inputShape: readonly number[], shape: readonly number[]): number[] => - (inputShape.length > shape.length) ? 
getAdjustedShape(inputShape, shape) : getAdjustedShape(shape, inputShape); - + inputShape.length > shape.length ? getAdjustedShape(inputShape, shape) : getAdjustedShape(shape, inputShape); const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { const inputShape = inputs[0].dims; @@ -84,21 +86,23 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ${assignment}`; }; - const programUniforms: ProgramUniform[] = - [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]; + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputShape, outputShape), + ]; return { name: 'Expand', - shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, + shaderCache: { hint: `${outputShape.length}`, inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }) + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), }; }; export const expand = (context: ComputeContext): void => { validateInputs(context.inputs); - context.compute(createExpandProgramInfo(context.inputs), {inputs: [0]}); + context.compute(createExpandProgramInfo(context.inputs), { inputs: [0] }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts index f50a6a3f011fe..aedb700e73844 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts @@ -1,12 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo } from '../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common'; +import { + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglValueType, + UniformsArrayType, + WORKGROUP_SIZE, +} from './common'; import * as unary from './unary-op'; // GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias. @@ -22,15 +29,18 @@ const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): Program const bias = inputVariable('bias', dataType, [1], 4); const y = outputVariable('y', dataType, [1], 4); - const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}]; + const uniforms: UniformsArrayType = [ + { name: 'output_vec_size', type: 'u32' }, + { name: 'bias_size', type: 'u32' }, + ]; - const singleElementBias = (i: 0|1|2|3) => ` + const singleElementBias = (i: 0 | 1 | 2 | 3) => ` let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size; let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`; - const biasGetExpression = useVec4 ? 
- ` - let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` : - `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)} + const biasGetExpression = useVec4 + ? ` + let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` + : `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)} let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`; return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)} @@ -49,14 +59,16 @@ const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): Program return { name: 'FastGeluWithBias', - shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']}, + shaderCache: { hint: `${useVec4}`, inputDependencies: ['type', 'type'] }, getShaderSource, getRunData: (inputs) => ({ - outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], - programUniforms: - [{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}], - dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)} - }) + outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }], + programUniforms: [ + { type: DataType.uint32, data: Math.ceil(outputSize / 4) }, + { type: DataType.uint32, data: biasLength }, + ], + dispatchGroup: { x: Math.ceil(outputSize / WORKGROUP_SIZE / 4) }, + }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index cfa0b42ef9eeb..8c19ecae280bc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {MAX_CLIP, MIN_CLIP} from '../../util'; -import {ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { MAX_CLIP, MIN_CLIP } from '../../util'; +import { ProgramUniform } from '../types'; -import {UniformsArrayType} from './common'; +import { UniformsArrayType } from './common'; export interface InternalActivationAttributes { readonly activation: string; @@ -15,68 +15,80 @@ export interface InternalActivationAttributes { readonly beta?: number; } -export const getActivationSnippet = - (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => { - switch (attributes.activation) { - case 'Relu': - return `value = max(value, ${valueType}(0.0));`; - case 'Sigmoid': - return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; - case 'Clip': - return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${ - baseType}(uniforms.clip_max)));`; - case 'HardSigmoid': - return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${ - baseType}(uniforms.beta)));`; - case 'LeakyRelu': - return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; - case 'Tanh': - return `let e2x = exp(-2.0 * abs(value)); +export const getActivationSnippet = ( + attributes: InternalActivationAttributes, + valueType: string, + baseType = 'f32', +): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, 
${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${ + baseType + }(uniforms.clip_max)));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${ + baseType + }(uniforms.beta)));`; + case 'LeakyRelu': + return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case 'Tanh': + return `let e2x = exp(-2.0 * abs(value)); value = sign(value) * (1.0 - e2x) / (1.0 + e2x); `; - case '': - return ''; - // TODO: adding other activations that can be fused. - default: - throw new Error(`Unsupported activation ${attributes.activation}`); - } - }; + case '': + return ''; + // TODO: adding other activations that can be fused. + default: + throw new Error(`Unsupported activation ${attributes.activation}`); + } +}; -export const appendActivationUniformsData = - (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { - if (attributes.activation === 'Clip') { - programUniform.push( - {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!}); - } else if (attributes.activation === 'HardSigmoid') { - programUniform.push( - {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); - } else if (attributes.activation === 'LeakyRelu') { - programUniform.push({type: DataType.float, data: attributes.alpha!}); - } - }; +export const appendActivationUniformsData = ( + attributes: InternalActivationAttributes, + programUniform: ProgramUniform[], +) => { + if (attributes.activation === 'Clip') { + programUniform.push( + { type: DataType.float, data: attributes.clipMax! }, + { type: DataType.float, data: attributes.clipMin! }, + ); + } else if (attributes.activation === 'HardSigmoid') { + programUniform.push( + { type: DataType.float, data: attributes.alpha! }, + { type: DataType.float, data: attributes.beta! }, + ); + } else if (attributes.activation === 'LeakyRelu') { + programUniform.push({ type: DataType.float, data: attributes.alpha! 
}); + } +}; export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => { if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + uniforms.push({ name: 'clip_max', type: 'f32' }, { name: 'clip_min', type: 'f32' }); } else if (attributes.activation === 'HardSigmoid') { - uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); + uniforms.push({ name: 'alpha', type: 'f32' }, { name: 'beta', type: 'f32' }); } else if (attributes.activation === 'LeakyRelu') { - uniforms.push({name: 'alpha', type: 'f32'}); + uniforms.push({ name: 'alpha', type: 'f32' }); } }; -export const parseInternalActivationAttributes = - (attributes: Record|undefined): InternalActivationAttributes => { - const activation = attributes?.activation as string || ''; - if (activation === 'HardSigmoid') { - const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; - return {activation, alpha, beta}; - } else if (activation === 'Clip') { - const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; - return {activation, clipMax, clipMin}; - } else if (activation === 'LeakyRelu') { - const [alpha] = attributes?.activation_params as [number] || [0.01]; - return {activation, alpha}; - } - return {activation}; - }; +export const parseInternalActivationAttributes = ( + attributes: Record | undefined, +): InternalActivationAttributes => { + const activation = (attributes?.activation as string) || ''; + if (activation === 'HardSigmoid') { + const [alpha, beta] = (attributes?.activation_params as [number, number]) || [0.2, 0.5]; + return { activation, alpha, beta }; + } else if (activation === 'Clip') { + const [clipMin, clipMax] = (attributes?.activation_params as [number, number]) || [MIN_CLIP, MAX_CLIP]; + return { activation, clipMax, clipMin }; + } else if (activation === 'LeakyRelu') { + const [alpha] = (attributes?.activation_params as [number]) || [0.01]; + return { activation, alpha }; + } + return { activation }; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index 4ab6c175a67e2..b3ad61bc3af43 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common'; export interface GatherElementsAttributes extends AttributeWithCacheKey { axis: number; @@ -28,41 +28,43 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const createGatherElementsProgramInfo = - (inputs: readonly TensorView[], attributes: GatherElementsAttributes): ProgramInfo => { - const inputShape = inputs[0].dims; - const inputOutputDataType = inputs[0].dataType; - const inputRank = inputShape.length; - - const indicesShape = inputs[1].dims; - const indicesDataType = inputs[1].dataType; - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); - const axisDimLimit = inputShape[axis]; - - const outputShape = indicesShape.slice(0); - const outputSize = ShapeUtil.size(outputShape); - - const input = inputVariable('input', inputOutputDataType, inputRank); - const indices = inputVariable('indicesInput', indicesDataType, indicesShape.length); - const output = outputVariable('output', inputOutputDataType, outputShape.length); - - - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, - {type: DataType.uint32, data: axis} - ]; - programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape)); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; - - // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits - // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor - // Input data will be treated as u32 or two u32 for 8-byte tensors - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('axisDimLimit', 'i32') - .registerUniform('axis', 'u32') - .declareVariables(input, indices, output)} +const createGatherElementsProgramInfo = ( + inputs: readonly TensorView[], + attributes: GatherElementsAttributes, +): ProgramInfo => { + const inputShape = inputs[0].dims; + const inputOutputDataType = inputs[0].dataType; + const inputRank = inputShape.length; + + const indicesShape = inputs[1].dims; + const indicesDataType = inputs[1].dataType; + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); + const axisDimLimit = inputShape[axis]; + + const outputShape = indicesShape.slice(0); + const outputSize = ShapeUtil.size(outputShape); + + const input = inputVariable('input', inputOutputDataType, inputRank); + const indices = inputVariable('indicesInput', indicesDataType, indicesShape.length); + const output = outputVariable('output', inputOutputDataType, outputShape.length); + + const programUniforms: 
ProgramUniform[] = [
+    { type: DataType.uint32, data: outputSize },
+    { type: DataType.int32, data: axisDimLimit },
+    { type: DataType.uint32, data: axis },
+  ];
+  programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape));
+  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
+
+  // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
+  // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
+  // Input data will be treated as u32 or two u32 for 8-byte tensors
+  const getShaderSource = (shaderHelper: ShaderHelper) => `
+      ${shaderHelper
+        .registerUniform('outputSize', 'u32')
+        .registerUniform('axisDimLimit', 'i32')
+        .registerUniform('axis', 'u32')
+        .declareVariables(input, indices, output)}
      ${shaderHelper.mainStart()}
      ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
@@ -79,20 +81,20 @@ const createGatherElementsProgramInfo =
      ${output.setByOffset('global_idx', 'value')};
  }`;

-      return {
-        name: 'GatherElements',
-        shaderCache: {inputDependencies},
-        getRunData: () => ({
-          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
-          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
-          programUniforms
-        }),
-        getShaderSource,
-      };
-    };
+  return {
+    name: 'GatherElements',
+    shaderCache: { inputDependencies },
+    getRunData: () => ({
+      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+      programUniforms,
+    }),
+    getShaderSource,
+  };
+};

 export const parseGatherElementsAttributes = (attributes: Record<string, unknown>): GatherElementsAttributes =>
-    createAttributeWithCacheKey({axis: attributes.axis as number});
+  createAttributeWithCacheKey({ axis: attributes.axis as number });

 export const gatherElements = (context: ComputeContext, attributes: GatherElementsAttributes): void => {
   const inputs = context.inputs;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
index d48bb909f7f8f..2492f3986863f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -1,13 +1,13 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -34,8 +34,10 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, - {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape) + { type: DataType.uint32, data: outputSize }, + { type: DataType.int32, data: axisDimLimit }, + { type: DataType.uint32, data: axis }, + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape), ]; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -43,12 +45,13 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length); const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); - const calcDataIndices = (x: number|string): string => { + const calcDataIndices = (x: number | string): string => { const indicesRank = indicesShape.length; let calcStr = `var indicesIndices${x} = ${indices.type.indices}(0);`; for (let i = 0; i < indicesRank; i++) { calcStr += `${indicesRank > 1 ? `indicesIndices${x}[${i}]` : `indicesIndices${x}`} = ${ - outputShape.length > 1 ? `outputIndices${x}[uniforms.axis + ${i}]` : `outputIndices${x}`};`; + outputShape.length > 1 ? `outputIndices${x}[uniforms.axis + ${i}]` : `outputIndices${x}` + };`; } calcStr += ` var idx${x} = ${indices.getByIndices(`indicesIndices${x}`)}; @@ -63,7 +66,8 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath j += indicesRank; } else { calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = ${ - outputShape.length > 1 ? `outputIndices${x}[${j}]` : `outputIndices${x}`};`; + outputShape.length > 1 ? 
`outputIndices${x}[${j}]` : `outputIndices${x}` + };`; j++; } } @@ -97,11 +101,11 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath `; } return ` - ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('axisDimLimit', 'i32') - .registerUniform('axis', 'u32') - .declareVariables(data, indices, output)} + ${shaderHelper + .registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(data, indices, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} ${assignment} @@ -109,20 +113,18 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }; return { name: 'Gather', - shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']}, + shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] }, getRunData: () => ({ - outputs: [ - {dims: outputShape, dataType: inputs[0].dataType}, - ], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, }), getShaderSource, }; }; export const parseGatherAttributes = (attributes: Record): GatherAttributes => - createAttributeWithCacheKey({axis: attributes.axis as number}); + createAttributeWithCacheKey({ axis: attributes.axis as number }); export const gather = (context: ComputeContext, attributes: GatherAttributes): void => { const inputs = context.inputs; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 76302e1af2e53..7f2469d95e1c1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -1,13 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {GemmUtil, ShapeUtil} from '../../util'; -import {AttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; - -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { GemmUtil, ShapeUtil } from '../../util'; +import { AttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; + +import { + createTensorShapeVariables, + IndicesHelper, + inputVariable, + outputVariable, + ShaderHelper, + UniformsArrayType, +} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs) { @@ -22,8 +29,7 @@ const validateInputs = (inputs: readonly TensorView[]): void => { throw new Error('Invalid input shape of C'); } - if ((inputs[0].dataType !== inputs[1].dataType) || - (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) { + if (inputs[0].dataType !== inputs[1].dataType || (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) { throw new Error('Input types are mismatched'); } }; @@ -39,16 +45,24 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt const aShape = inputs[0].dims.slice(); const bShape = inputs[1].dims.slice(); const [M, N, K] = GemmUtil.getShapeOfGemmResult( - aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined); + aShape, + attributes.transA, + bShape, + attributes.transB, + inputs.length === 3 ? 
inputs[2].dims : undefined, + ); const outputShape = [M, N]; if (!outputShape) { - throw new Error('Can\'t use gemm on the given tensors'); + throw new Error("Can't use gemm on the given tensors"); } const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, - {type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha}, - {type: DataType.float, data: attributes.beta} + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: M }, + { type: DataType.uint32, data: N }, + { type: DataType.uint32, data: K }, + { type: DataType.float, data: attributes.alpha }, + { type: DataType.float, data: attributes.beta }, ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; if (inputs.length === 3) { @@ -73,7 +87,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt const a = inputVariable('a', inputs[0].dataType, inputs[0].dims); const b = inputVariable('b', inputs[1].dataType, inputs[1].dims); const dataType = a.type.value; - let c: IndicesHelper|null = null; + let c: IndicesHelper | null = null; const variables = [a, b]; if (inputs.length === 3) { c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length); @@ -82,8 +96,12 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt const output = outputVariable('output', inputs[0].dataType, outputShape.length); variables.push(output); const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'}, - {name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'} + { name: 'output_size', type: 'u32' }, + { name: 'M', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'K', type: 'u32' }, + { name: 'alpha', type: 'f32' }, + { name: 'beta', type: 'f32' }, ]; return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} @@ -103,7 +121,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${(() => { if (c != null) { return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${ - dataType}(uniforms.beta) * ${c.getByOffset('cOffset')};`; + dataType + }(uniforms.beta) * ${c.getByOffset('cOffset')};`; } return ''; })()} @@ -113,11 +132,11 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt return { name: 'Gemm', - shaderCache: {hint: `${attributes.cacheKey}`, inputDependencies}, + shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies }, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, }), getShaderSource, }; @@ -128,7 +147,13 @@ export const parseGemmAttributes = (attributes: Record): GemmAt const transB = attributes.transB as boolean; const alpha = attributes.alpha as number; const beta = attributes.beta as number; - return {transA, transB, alpha, beta, cacheKey: `${attributes.transA};${attributes.transB};${attributes.alpha === 1}`}; + return { + transA, + transB, + alpha, + beta, + cacheKey: `${attributes.transA};${attributes.transB};${attributes.alpha === 1}`, + }; }; export const gemm = (context: 
ComputeContext, attributes: GemmAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 0558d1caf76a6..56291c037b7da 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -1,17 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; - -import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; -import {maybeTransposeToBNSHAndAddBias} from './multihead-attention'; -import {createTileProgramInfo} from './tile'; -import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; + +import { + applyAttention, + AttentionAttrs, + AttentionMaskType, + AttentionParameters, + AttentionQkvFormat, +} from './attention'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; +import { maybeTransposeToBNSHAndAddBias } from './multihead-attention'; +import { createTileProgramInfo } from './tile'; +import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { const query = inputs[0]; @@ -56,8 +62,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const dmmhaPacking = false; const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : - attributes.numHeads * query.dims[4]; + const hiddenSize = + query.dims.length === 3 ? (dmmhaPacking ? 
query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4]; let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; @@ -114,7 +120,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent } qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; kvSequenceLength = key.dims[1]; - } else { // key_dims.size() == 4 (cross-attention with past_key) + } else { + // key_dims.size() == 4 (cross-attention with past_key) if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } @@ -122,7 +129,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent qkvFormat = AttentionQkvFormat.unknown; kvSequenceLength = key.dims[2]; } - } else { // packed QKV + } else { + // packed QKV if (query.dims.length !== 3 && query.dims.length !== 5) { throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); } @@ -186,69 +194,77 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent }; }; -const createConcatProgramInfo = - (a: TensorView, b: TensorView|undefined, dataType: DataType, params: AttentionParameters): ProgramInfo => { - const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize]; - const component = 4; - const outputSize = ShapeUtil.size(outputShape) / component; - const presentSequenceLength = params.totalSequenceLength; - const output = outputVariable('present_kv', dataType, outputShape.length, component); - const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component); - const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined; - - const H = Math.ceil(params.headSize / component); - const dispatch = {x: presentSequenceLength, y: a.dims[0], z: 1}; - - const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank']; - - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: params.pastSequenceLength}, - {type: DataType.uint32, data: params.kvSequenceLength}, - {type: DataType.uint32, data: params.totalSequenceLength} - ]; - - const inputs = [inputA]; - if (inputB) { - programUniforms.push( - ...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(b!.dims), - ...createTensorShapeVariables(outputShape)); - inputs.push(inputB); - } else { - programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape)); - } - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'past_seqlen', type: 'u32'}, {name: 'new_seqlen', type: 'u32'}, - {name: 'present_seqlen', type: 'u32'} - ]; - - const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H; +const createConcatProgramInfo = ( + a: TensorView, + b: TensorView | undefined, + dataType: DataType, + params: AttentionParameters, +): ProgramInfo => { + const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize]; + const component = 4; + const outputSize = ShapeUtil.size(outputShape) / component; + const presentSequenceLength = params.totalSequenceLength; + const output = outputVariable('present_kv', dataType, outputShape.length, component); + const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component); + const inputB = b ? 
inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined; + + const H = Math.ceil(params.headSize / component); + const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 }; + + const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank']; + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: params.pastSequenceLength }, + { type: DataType.uint32, data: params.kvSequenceLength }, + { type: DataType.uint32, data: params.totalSequenceLength }, + ]; + + const inputs = [inputA]; + if (inputB) { + programUniforms.push( + ...createTensorShapeVariables(a.dims), + ...createTensorShapeVariables(b!.dims), + ...createTensorShapeVariables(outputShape), + ); + inputs.push(inputB); + } else { + programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape)); + } + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'past_seqlen', type: 'u32' }, + { name: 'new_seqlen', type: 'u32' }, + { name: 'present_seqlen', type: 'u32' }, + ]; + + const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H; var past_head_stride = uniforms.past_seqlen * H; if (is_bsnh) { past_head_stride = H; } let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset];`; - const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H; + const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H; let new_row_stride = num_heads * H; let new_head_stride = H; let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; present_kv[out_offset] = new_kv[in_offset];`; - const concatStr = b ? `if (s < past_seqlen) { + const concatStr = b + ? `if (s < past_seqlen) { ${pastStr} } else if (s < past_seqlen + uniforms.new_seqlen) { ${newStr} - }` : - `if (s < past_seqlen + uniforms.new_seqlen) { + }` + : `if (s < past_seqlen + uniforms.new_seqlen) { ${newStr} }`; - // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit. - const getShaderSource = (shaderHelper: ShaderHelper) => ` + // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit. 
+ const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)} - ${shaderHelper.mainStart([ - H, params.kvNumHeads!, 1 - ])} + ${shaderHelper.mainStart([H, params.kvNumHeads!, 1])} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} var indices = ${output.offsetToIndices('global_idx')}; let h = local_id.x; @@ -277,53 +293,66 @@ const createConcatProgramInfo = ${concatStr} }`; - return { - name: 'ConcatPastNew', - shaderCache: {hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType}], - dispatchGroup: dispatch, - programUniforms, - }), - getShaderSource, - }; - }; + return { + name: 'ConcatPastNew', + shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType }], + dispatchGroup: dispatch, + programUniforms, + }), + getShaderSource, + }; +}; export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({...attributes}); - -const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); - -const maybeExpandAndTransposeToBNSH = - (context: ComputeContext, input: TensorView, pastKV: TensorView|undefined, params: AttentionParameters, - outputIndex: number) => { - let reshapedInput = input; - const numHeads = params.kvNumHeads!; - const nReps = params.nReps!; - if (input.dims.length === 3 && params.kvSequenceLength !== 0) { - reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]); - } + createAttributeWithCacheKey({ ...attributes }); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] }); + +const maybeExpandAndTransposeToBNSH = ( + context: ComputeContext, + input: TensorView, + pastKV: TensorView | undefined, + params: AttentionParameters, + outputIndex: number, +) => { + let reshapedInput = input; + const numHeads = params.kvNumHeads!; + const nReps = params.nReps!; + if (input.dims.length === 3 && params.kvSequenceLength !== 0) { + reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]); + } - if (pastKV) { - reshapedInput = context.compute( - createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), - {inputs: [reshapedInput, pastKV], outputs: [params.isPastkvBSNH ? outputIndex : -1]})[0]; - } else { - reshapedInput = context.compute( - createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), - {inputs: [reshapedInput], outputs: [params.isPastkvBSNH ? outputIndex : -1]})[0]; - } - if (nReps !== 1) { - reshapedInput = context.compute( - createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), {inputs: [reshapedInput], outputs: [-1]})[0]; - reshapedInput = - reshapedInput.reshape([params.batchSize, params.totalSequenceLength, numHeads * nReps, params.headSize]); - } + if (pastKV) { + reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), { + inputs: [reshapedInput, pastKV], + outputs: [params.isPastkvBSNH ? outputIndex : -1], + })[0]; + } else { + reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), { + inputs: [reshapedInput], + outputs: [params.isPastkvBSNH ? 
outputIndex : -1], + })[0]; + } + if (nReps !== 1) { + reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), { + inputs: [reshapedInput], + outputs: [-1], + })[0]; + reshapedInput = reshapedInput.reshape([ + params.batchSize, + params.totalSequenceLength, + numHeads * nReps, + params.headSize, + ]); + } - return context.compute( - createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), - {inputs: [reshapedInput], outputs: [-1]})[0]; - }; + return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { + inputs: [reshapedInput], + outputs: [-1], + })[0]; +}; export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { const params = validateInputs(context.inputs, attributes); @@ -336,8 +365,15 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti } const Q = maybeTransposeToBNSHAndAddBias( - context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], undefined, - 0); + context, + params.batchSize, + params.numHeads, + params.sequenceLength, + params.headSize, + context.inputs[0], + undefined, + 0, + ); const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined; const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index c1d762e62aaa9..7b6140f3b1185 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -1,45 +1,62 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; -import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import { + createTensorShapeVariables, + fillVector, + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + sumVector, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; export interface InstanceNormAttributes { epsilon: number; - format: 'NHWC'|'NCHW'; + format: 'NHWC' | 'NCHW'; } -const createInstanceNormProgramInfo = - (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { - const xShape = inputs[0].dims; - const outputShape = xShape; - const axis = 2; - const normCount = ShapeUtil.sizeToDimension(xShape, axis); - const normSize = ShapeUtil.sizeFromDimension(xShape, axis); - const components = getMaxComponents(normSize); - const normPackedSize = normSize / components; - const inputShape = [xShape[0], xShape[1], normPackedSize]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; - const programUniforms: ProgramUniform[] = - [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; - programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); +const createInstanceNormProgramInfo = ( + inputs: readonly TensorView[], + attributes: InstanceNormAttributes, +): ProgramInfo => { + const xShape = inputs[0].dims; + const outputShape = xShape; + const axis = 2; + const normCount = ShapeUtil.sizeToDimension(xShape, axis); + const normSize = ShapeUtil.sizeFromDimension(xShape, axis); + const components = getMaxComponents(normSize); + const normPackedSize = normSize / components; + const inputShape = [xShape[0], xShape[1], normPackedSize]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: normSize }, + { type: DataType.uint32, data: normPackedSize }, + ]; + programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); - const variables = [x, scale, bias, output]; - const dataType = x.type.value; - const f32Type = components === 1 ? 
'f32' : `vec${components}`; - const workgroupSize = 64; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); + const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); + const variables = [x, scale, bias, output]; + const dataType = x.type.value; + const f32Type = components === 1 ? 'f32' : `vec${components}`; + const workgroupSize = 64; - const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}]; - return ` + const uniforms: UniformsArrayType = [ + { name: 'normSize', type: 'u32' }, + { name: 'normPackedSize', type: 'u32' }, + ]; + return ` var meanShared : f32; var squaredNormShared : f32; var workgroupShared : array<${f32Type}, ${workgroupSize}>; @@ -97,49 +114,56 @@ const createInstanceNormProgramInfo = let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ - f32Type}(channelShift)); + f32Type + }(channelShift)); ${output.set('batch', 'channel', 'h', 'value')}; } }`; - }; - return { - ...{name: 'InstanceNormalization'}, - // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies}, - getRunData: () => ({ - outputs: [ - {dims: outputShape, dataType: inputs[0].dataType}, - ], - dispatchGroup: {x: normCount}, - programUniforms - }), - getShaderSource, - }; - }; + }; + return { + ...{ name: 'InstanceNormalization' }, + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: { hint: `${attributes.epsilon};${components}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: normCount }, + programUniforms, + }), + getShaderSource, + }; +}; -const computeMean = - (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number, - epsilon: number) => { - const components = getMaxComponents(c); - const WG = 64; - // we will store channel scale and channel shift in [2, components] matrix - // or in vec2 when components == 1 - const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`; - const sumCastType = components === 1 ? 'f32' : `vec${components}f`; - const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`; - const unitsOfWork = n * c / components; - const wgSize = Math.ceil(h / WG); +const computeMean = ( + context: ComputeContext, + input: TensorView, + scale: TensorView, + bias: TensorView, + n: number, + h: number, + c: number, + epsilon: number, +) => { + const components = getMaxComponents(c); + const WG = 64; + // we will store channel scale and channel shift in [2, components] matrix + // or in vec2 when components == 1 + const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const sumCastType = components === 1 ? 
'f32' : `vec${components}f`; + const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`; + const unitsOfWork = (n * c) / components; + const wgSize = Math.ceil(h / WG); - const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; - const meanProgramUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h}, - {type: DataType.uint32, data: Math.floor(c / components)}, - {type: DataType.uint32, data: Math.floor(h * c / components)} - ]; + const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + const meanProgramUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: wgSize }, + { type: DataType.uint32, data: h }, + { type: DataType.uint32, data: Math.floor(c / components) }, + { type: DataType.uint32, data: Math.floor((h * c) / components) }, + ]; - const getMeanShaderSource = (shaderHelper: ShaderHelper) => { - const inputHelper = inputVariable('input', input.dataType, input.dims, components); - return ` + const getMeanShaderSource = (shaderHelper: ShaderHelper) => { + const inputHelper = inputVariable('input', input.dataType, input.dims, components); + return ` ${shaderHelper.declareVariables(inputHelper)} @group(0) @binding(1) var output : array<${outputType}>; struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32}; @@ -164,33 +188,33 @@ const computeMean = } output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; }`; - }; + }; - const meanValues = context.compute( - { - name: 'InstanceNormComputeMean', - shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies}, - getRunData: () => ({ - outputs: [ - {dims: [n, c, WG, 2], dataType: DataType.float}, - ], - dispatchGroup: {x: n * c / components}, - programUniforms: meanProgramUniforms - }), - getShaderSource: getMeanShaderSource, - }, - {inputs: [input], outputs: [-1]})[0]; + const meanValues = context.compute( + { + name: 'InstanceNormComputeMean', + shaderCache: { hint: `${components}`, inputDependencies: meanInputDependencies }, + getRunData: () => ({ + outputs: [{ dims: [n, c, WG, 2], dataType: DataType.float }], + dispatchGroup: { x: (n * c) / components }, + programUniforms: meanProgramUniforms, + }), + getShaderSource: getMeanShaderSource, + }, + { inputs: [input], outputs: [-1] }, + )[0]; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h}, - {type: DataType.uint32, data: Math.floor(c / components)}, - {type: DataType.uint32, data: Math.floor(WG * c / components)} - ]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; - const getShaderSource = (shaderHelper: ShaderHelper) => { - const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); - const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); - return ` + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: unitsOfWork }, + { type: DataType.uint32, data: h }, + { type: DataType.uint32, data: Math.floor(c / components) }, + { type: DataType.uint32, data: Math.floor((WG * c) / components) }, + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); + const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + return ` @group(0) @binding(0) var 
input : array<${outputType}>; @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; @@ -219,47 +243,51 @@ const computeMean = output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; }`; - }; - return context.compute( - { - name: 'InstanceNormComputeChannelScaleShift', - // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: {hint: `${components};${epsilon}`, inputDependencies}, - getRunData: () => ({ - outputs: [ - {dims: [n, c, 2], dataType: DataType.float}, - ], - dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource, - }, - {inputs: [meanValues, scale, bias], outputs: [-1]})[0]; - }; + }; + return context.compute( + { + name: 'InstanceNormComputeChannelScaleShift', + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: [n, c, 2], dataType: DataType.float }], + dispatchGroup: { x: Math.ceil(unitsOfWork / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [meanValues, scale, bias], outputs: [-1] }, + )[0]; +}; -const createInstanceNormNHWCProgramInfo = - (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => { - const xShape = inputs[0].dims; - const outputShape = xShape; - const N = xShape[0]; - const C = xShape[xShape.length - 1]; - const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; - const components = getMaxComponents(C); - const outputSize = ShapeUtil.size(outputShape) / components; - const programUniforms: ProgramUniform[] = - [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; - // first compute mean - const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; - const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; +const createInstanceNormNHWCProgramInfo = ( + context: ComputeContext, + inputs: readonly TensorView[], + attributes: InstanceNormAttributes, +) => { + const xShape = inputs[0].dims; + const outputShape = xShape; + const N = xShape[0]; + const C = xShape[xShape.length - 1]; + const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; + const components = getMaxComponents(C); + const outputSize = ShapeUtil.size(outputShape) / components; + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: H }, + { type: DataType.uint32, data: Math.floor(C / components) }, + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + // first compute mean + const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const scaleCastType = components === 1 ? 
dataType : `vec${components}<${dataType}>`; - const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); - const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); + const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); + const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); - return ` + return ` @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; @group(0) @binding(1) var scaleInput : array<${scaleType}>; @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; @@ -274,20 +302,21 @@ const createInstanceNormNHWCProgramInfo = let scale = scaleInput[scaleOffset]; output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); }`; - }; - context.compute( - { - name: 'InstanceNormalizationNHWC', - shaderCache: {hint: `${components}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource, - }, - {inputs: [inputs[0], channelScaleShift]}); - }; + }; + context.compute( + { + name: 'InstanceNormalizationNHWC', + shaderCache: { hint: `${components}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [inputs[0], channelScaleShift] }, + ); +}; export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => { if (attributes.format === 'NHWC') { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index b2a1bbe2bea49..292be26aee2dd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -1,12 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; - -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType,} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; + +import { + castToF32, + fillVector, + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + sumVector, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; interface LayerNormAttributes { simplified: boolean; @@ -20,70 +30,76 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const createLayerNormProgramInfo = - (inputs: readonly TensorView[], attributes: LayerNormAttributes, outputCount: number): ProgramInfo => { - const simplified = attributes.simplified; - - const xShape = inputs[0].dims; - const scale = inputs[1]; - const bias = !simplified && inputs[2]; - - const outputShape = xShape; - const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); - const normCount = ShapeUtil.sizeToDimension(xShape, axis); - const normSize = ShapeUtil.sizeFromDimension(xShape, axis); - - const scaleSize = ShapeUtil.size(scale.dims); - const biasSize = bias ? ShapeUtil.size(bias.dims) : 0; - if (scaleSize !== normSize || (bias && biasSize !== normSize)) { - throw new Error(`Size of X.shape()[axis:] == ${normSize}. +const createLayerNormProgramInfo = ( + inputs: readonly TensorView[], + attributes: LayerNormAttributes, + outputCount: number, +): ProgramInfo => { + const simplified = attributes.simplified; + + const xShape = inputs[0].dims; + const scale = inputs[1]; + const bias = !simplified && inputs[2]; + + const outputShape = xShape; + const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); + const normCount = ShapeUtil.sizeToDimension(xShape, axis); + const normSize = ShapeUtil.sizeFromDimension(xShape, axis); + + const scaleSize = ShapeUtil.size(scale.dims); + const biasSize = bias ? ShapeUtil.size(bias.dims) : 0; + if (scaleSize !== normSize || (bias && biasSize !== normSize)) { + throw new Error(`Size of X.shape()[axis:] == ${normSize}. Size of scale and bias (if provided) must match this. 
Got scale size of ${scaleSize} and bias size of ${biasSize}`); - } - - const meanInvStdDevDim: number[] = []; - for (let i = 0; i < xShape.length; ++i) { - if (i < axis) { - meanInvStdDevDim.push(xShape[i]); - } else { - meanInvStdDevDim.push(1); - } - } - const components = getMaxComponents(normSize); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize}, - {type: DataType.uint32, data: Math.floor(normSize / components)}, - {type: DataType.float, data: attributes.epsilon} - ]; - if (bias) { - inputDependencies.push('type'); - } - const hasMeanDataOutput = outputCount > 1; - const hasInvStdOutput = outputCount > 2; - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('scale', scale.dataType, scale.dims, components), - ]; - if (bias) { - variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanDataOutput) { - variables.push(outputVariable('mean_data_output', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdOutput) { - variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); - } - - const uniforms: UniformsArrayType = [ - {name: 'norm_count', type: 'u32'}, {name: 'norm_size', type: 'f32'}, - {name: 'norm_size_vectorized', type: 'u32'}, {name: 'epsilon', type: 'f32'} - ]; - return ` + } + + const meanInvStdDevDim: number[] = []; + for (let i = 0; i < xShape.length; ++i) { + if (i < axis) { + meanInvStdDevDim.push(xShape[i]); + } else { + meanInvStdDevDim.push(1); + } + } + const components = getMaxComponents(normSize); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: normCount }, + { type: DataType.float, data: normSize }, + { type: DataType.uint32, data: Math.floor(normSize / components) }, + { type: DataType.float, data: attributes.epsilon }, + ]; + if (bias) { + inputDependencies.push('type'); + } + const hasMeanDataOutput = outputCount > 1; + const hasInvStdOutput = outputCount > 2; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('scale', scale.dataType, scale.dims, components), + ]; + if (bias) { + variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanDataOutput) { + variables.push(outputVariable('mean_data_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + + const uniforms: UniformsArrayType = [ + { name: 'norm_count', type: 'u32' }, + { name: 'norm_size', type: 'f32' }, + { name: 'norm_size_vectorized', type: 'u32' }, + { name: 'epsilon', type: 'f32' }, + ]; + return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')}
@@ -98,7 +114,8 @@ const createLayerNormProgramInfo = } let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; let inv_std_dev = inverseSqrt(${sumVector('mean_square_vector', components)} / uniforms.norm_size ${ - simplified ? '' : '- mean * mean'} + uniforms.epsilon); + simplified ? '' : '- mean * mean' + } + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; @@ -111,23 +128,26 @@ const createLayerNormProgramInfo = ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; - }; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; - if (hasMeanDataOutput) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); - } - if (hasInvStdOutput) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); - } - - return { - name: 'LayerNormalization', - shaderCache: {hint: `${components};${outputCount};${simplified}`, inputDependencies}, - getRunData: () => - ({outputs, dispatchGroup: {x: Math.ceil(normCount / 64 /* workgroup size */)}, programUniforms}), - getShaderSource, - }; - }; + }; + const outputs = [{ dims: outputShape, dataType: inputs[0].dataType }]; + if (hasMeanDataOutput) { + outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float }); + } + if (hasInvStdOutput) { + outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float }); + } + + return { + name: 'LayerNormalization', + shaderCache: { hint: `${components};${outputCount};${simplified}`, inputDependencies }, + getRunData: () => ({ + outputs, + dispatchGroup: { x: Math.ceil(normCount / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; export const layerNorm = (context: ComputeContext, attributes: LayerNormAttributes): void => { validateInputs(context.inputs); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 1a92d861002fb..d2a6b2d352e25 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,113 +1,138 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License.
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {BroadcastUtil, ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { BroadcastUtil, ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; -import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; +import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu'; +import { + createTensorShapeVariables, + getBroadcastDims, + getMaxComponents, + IndicesHelper, + inputVariable, + internalVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; +import { + appendActivationUniforms, + appendActivationUniformsData, + getActivationSnippet, + InternalActivationAttributes, +} from './fuse-utils'; -export const createNaiveMatmulProgramInfo = - (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], - reshapedOutputShape?: readonly number[], - isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { - const aShape = inputs[0].dims; - const bShape = inputs[1].dims; +export const createNaiveMatmulProgramInfo = ( + inputs: readonly TensorView[], + activationAttributes: InternalActivationAttributes, + outputShape: readonly number[], + reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/, +): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; - const M = aShape[aShape.length - 2]; - const N = bShape[bShape.length - 1]; - const K = aShape[aShape.length - 1]; - const components = getMaxComponents(N); - const aComponents = getMaxComponents(K); - const outputNumber = getMaxComponents(M); - const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; - const hasBias = inputs.length > 2; - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchSize = ShapeUtil.size(outerDims); - const outputShapeInShader = [batchSize, M, N]; + const M = aShape[aShape.length - 2]; + const N = bShape[bShape.length - 1]; + const K = aShape[aShape.length - 1]; + const components = getMaxComponents(N); + const aComponents = getMaxComponents(K); + const outputNumber = getMaxComponents(M); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; + const hasBias = inputs.length > 2; + const outerDims = reshapedOutputShape ? 
reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchSize = ShapeUtil.size(outerDims); + const outputShapeInShader = [batchSize, M, N]; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, - {type: DataType.uint32, data: K} - ]; - appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); - if (hasBias) { - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - } - programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: M }, + { type: DataType.uint32, data: N }, + { type: DataType.uint32, data: K }, + ]; + appendActivationUniformsData(activationAttributes, programUniforms); + programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } + programUniforms.push(...createTensorShapeVariables(outputShapeInShader)); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); - const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); - const b = inputVariable('b', inputs[1].dataType, bShape.length, components); - const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const baseType = tensorTypeToWsglStorageType(output.type.tensor); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); - const inputVariables = [a, b]; - let processBias = ''; - if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); - processBias = `${ - isChannelsLast ? `value += bias[col / ${biasComponents}];` : - `value += ${output.type.value}(bias[row + i]);`}`; - } + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length); + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', inputs[1].dataType, bShape.length, components); + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); + const inputVariables = [a, b]; + let processBias = ''; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + processBias = `${ + isChannelsLast ? 
`value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);` + }`; + } - const outerDimsA = aShape.slice(0, -2); - const outerDimsB = bShape.slice(0, -2); - const broadCastADims = getBroadcastDims(outerDimsA, outerDims); - const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'K', type: 'u32'} - ]; - appendActivationUniforms(activationAttributes, uniforms); + const outerDimsA = aShape.slice(0, -2); + const outerDimsB = bShape.slice(0, -2); + const broadCastADims = getBroadcastDims(outerDimsA, outerDims); + const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'M', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'K', type: 'u32' }, + ]; + appendActivationUniforms(activationAttributes, uniforms); - const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { - const rank = variable.rank; - const name = variable.name; - if (rank === 2) { - return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`; - } - const batchRank = batchDims.rank; - let resStr = `var ${name}_indices: ${variable.type.indices};`; - for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`; - } - broadCastDims.forEach(i => { - resStr += `\n${name}_indices[${i}] = 0;`; - }); - resStr += `${name}_indices[${rank - 2}] = 0u; + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { + const rank = variable.rank; + const name = variable.name; + if (rank === 2) { + return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`; + } + const batchRank = batchDims.rank; + let resStr = `var ${name}_indices: ${variable.type.indices};`; + for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { + resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`; + } + broadCastDims.forEach((i) => { + resStr += `\n${name}_indices[${i}] = 0;`; + }); + resStr += `${name}_indices[${rank - 2}] = 0u; ${name}_indices[${rank - 1}] = 0u;`; - return resStr; - }; + return resStr; + }; - const calcResult = (): string => { - let calcStr = `var a_data: ${a.type.value};`; - for (let i = 0; i < aComponents; i++) { - calcStr += ` + const calcResult = (): string => { + let calcStr = `var a_data: ${a.type.value};`; + for (let i = 0; i < aComponents; i++) { + calcStr += ` let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`; - } - for (let i = 0; i < outputNumber; i++) { - calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; + } + for (let i = 0; i < outputNumber; i++) { + calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`; - for (let j = 0; j < aComponents; j++) { - calcStr += ` + for (let j = 0; j < aComponents; j++) { + calcStr += ` values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? 
'' : `[${j}]`}), b_data${j}, values[${ - i}]);\n`; - } - } - return calcStr; - }; + i + }]);\n`; + } + } + return calcStr; + }; - return ` - ${ - shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( - ...inputVariables, output)} + return ` + ${shaderHelper + .registerUniforms(uniforms) + .registerInternalVariables(batchDims) + .declareVariables(...inputVariables, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let col = (global_idx % (uniforms.N / ${components})) * ${components}; @@ -135,21 +160,21 @@ export const createNaiveMatmulProgramInfo = } } `; - }; - return { - name: 'MatMulNaive', - shaderCache: { - hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, - inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] - }, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource - }; - }; + }; + return { + name: 'MatMulNaive', + shaderCache: { + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, + inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'], + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -165,13 +190,13 @@ export const matMul = (context: ComputeContext): void => { validateInputs(context.inputs); const outputShape = BroadcastUtil.calcShape(context.inputs[0].dims, context.inputs[1].dims, true); if (!outputShape) { - throw new Error('Can\'t use matmul on the given tensors'); + throw new Error("Can't use matmul on the given tensors"); } const N = outputShape[outputShape.length - 1]; const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; if (N < 8 && K < 8) { - context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); + context.compute(createNaiveMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 8aabaeb22f4d4..121ac8baff04b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -1,13 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
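// An illustrative TypeScript sketch of the 4-bit dequantization that the MatMulNBits
// shader further down performs on each 32-bit word of packed weights: the low and high
// nibbles of every byte are extracted with a 0x0F mask and a 4-bit shift, then mapped
// through (q - zero_point) * scale. The helper name and scalar signature are hypothetical;
// the real kernel works on WGSL vectors with per-block scales and zero points.
const dequantizeWord = (word: number, scale: number, zeroPoint: number): number[] => {
  const values: number[] = [];
  for (let byte = 0; byte < 4; byte++) {
    const b = (word >>> (8 * byte)) & 0xff;
    const lowNibble = b & 0x0f;
    const highNibble = (b >>> 4) & 0x0f;
    // Eight quantized weights per u32, interleaved as (low, high) per byte, matching the
    // unpack4xU8(b_value & mask) / unpack4xU8((b_value >> 4) & mask) pairing in the shader.
    values.push((lowNibble - zeroPoint) * scale, (highNibble - zeroPoint) * scale);
  }
  return values;
};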
-import {DataType, getTensorElementSize} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; +import { DataType, getTensorElementSize } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import { + createTensorShapeVariables, + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; // TODO support quantization bits not equal to 4 export interface MatMulNBitsAttributes extends AttributeWithCacheKey { @@ -28,7 +36,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt throw new Error('The last dim of input shape does not match the k value'); } const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); - const blobSize = attributes.blockSize / 8 * attributes.bits; + const blobSize = (attributes.blockSize / 8) * attributes.bits; const b = inputs[1]; if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); @@ -42,84 +50,96 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt const zeroPoints = inputs[3]; const zeroPointsShape = zeroPoints.dims; const expectedZeroPointsSize = - attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + attributes.bits > 4 ? 
attributes.n * nBlocksPerCol : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { throw new Error('zeroPoints input size error.'); } } }; -export const createMatMulNBitsProgramInfo = - (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes, - maxComputeWorkgroupSizes: [number, number, number], maxComputeWorkgroupStorageSize: number): ProgramInfo => { - const inputShape = inputs[0].dims; - const aRank = inputShape.length; - const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); - const dimAOuter = inputShape[aRank - 2]; - const dimInner = attributes.k; - const dimBOuter = attributes.n; - const batchDims = inputShape.slice(0, aRank - 2); - const batchSize = ShapeUtil.size(batchDims); - const blobSize = attributes.blockSize / 8 * attributes.bits; - const blobSizeInWords = blobSize / 4; - const dataType = inputs[0].dataType; - const outputNumber = getMaxComponents(dimAOuter); - const aComponents = getMaxComponents(attributes.k); - const bComponents = getMaxComponents(blobSizeInWords); - const elementSize = getTensorElementSize(dataType)!; - const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize; - const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize); - const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0; - const components = (!useBlockwiseMatMulNBits || maxNumberOfComponents >= 4) ? getMaxComponents(dimBOuter) : - ((maxNumberOfComponents >= 2) && getMaxComponents(dimBOuter) >= 2) ? 2 : - 1; - const outputShape = batchDims.concat([dimAOuter, dimBOuter]); - const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; +export const createMatMulNBitsProgramInfo = ( + inputs: readonly TensorView[], + attributes: MatMulNBitsAttributes, + maxComputeWorkgroupSizes: [number, number, number], + maxComputeWorkgroupStorageSize: number, +): ProgramInfo => { + const inputShape = inputs[0].dims; + const aRank = inputShape.length; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const dimAOuter = inputShape[aRank - 2]; + const dimInner = attributes.k; + const dimBOuter = attributes.n; + const batchDims = inputShape.slice(0, aRank - 2); + const batchSize = ShapeUtil.size(batchDims); + const blobSize = (attributes.blockSize / 8) * attributes.bits; + const blobSizeInWords = blobSize / 4; + const dataType = inputs[0].dataType; + const outputNumber = getMaxComponents(dimAOuter); + const aComponents = getMaxComponents(attributes.k); + const bComponents = getMaxComponents(blobSizeInWords); + const elementSize = getTensorElementSize(dataType)!; + const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize; + const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize); + const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0; + const components = + !useBlockwiseMatMulNBits || maxNumberOfComponents >= 4 + ? getMaxComponents(dimBOuter) + : maxNumberOfComponents >= 2 && getMaxComponents(dimBOuter) >= 2 + ? 2 + : 1; + const outputShape = batchDims.concat([dimAOuter, dimBOuter]); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; - const programUniforms: ProgramUniform[] = useBlockwiseMatMulNBits ? 
- [] : - [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.blockSize}]; - const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents]; - const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); - bShape.splice(-1, 1, blobSizeInWords / bComponents); - programUniforms.push(...createTensorShapeVariables(inputShapeTemp)); - programUniforms.push(...createTensorShapeVariables(bShape)); - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - if (inputs.length === 4) { - programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); - } - const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const inputRank = inputShapeTemp.length; - const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents); - const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); - const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); - const inputVariables = [a, b, scales]; - const zeroPoints = - inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; - if (zeroPoints) { - inputVariables.push(zeroPoints); - } - const outputRank = outputShapeTemp.length; - const output = outputVariable('output', inputs[0].dataType, outputRank, components); - const uniforms: UniformsArrayType = [{name: 'output_size', type: 'u32'}, {name: 'block_size', type: 'u32'}]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const programUniforms: ProgramUniform[] = useBlockwiseMatMulNBits + ? [] + : [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: attributes.blockSize }, + ]; + const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents]; + const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); + bShape.splice(-1, 1, blobSizeInWords / bComponents); + programUniforms.push(...createTensorShapeVariables(inputShapeTemp)); + programUniforms.push(...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputRank = inputShapeTemp.length; + const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents); + const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? 
inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const outputRank = outputShapeTemp.length; + const output = outputVariable('output', inputs[0].dataType, outputRank, components); + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'block_size', type: 'u32' }, + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const qDqDataType = (() => { - switch (aComponents) { - case 1: - return `array<${dataType}, 8>`; - case 2: - return `mat4x2<${dataType}>`; - case 4: - return `mat2x4<${dataType}>`; - default: - throw new Error(`${aComponents}-component is not supported.`); - } - })(); + const qDqDataType = (() => { + switch (aComponents) { + case 1: + return `array<${dataType}, 8>`; + case 2: + return `mat4x2<${dataType}>`; + case 4: + return `mat2x4<${dataType}>`; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })(); - const processOneBlock = ` + const processOneBlock = ` for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { ${b.indicesSet('b_indices', '2', 'word')}; let b_data = ${b.getByIndices('b_indices')}; @@ -128,17 +148,20 @@ export const createMatMulNBitsProgramInfo = let b_mask: u32 = 0x0F0F0F0Fu; let b_value_lower: vec4 = unpack4xU8(b_value & b_mask); let b_value_upper: vec4 = unpack4xU8((b_value >> 4) & b_mask); - let b_quantized_values = ${qDqDataType}(${ - Array.from({length: 4}, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`) - .join(', ')}); + let b_quantized_values = ${qDqDataType}(${Array.from( + { length: 4 }, + (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`, + ).join(', ')}); let b_dequantized_values = ${(() => { - if (aComponents === 1) { - return `${qDqDataType}(${ - Array.from({length: 8}, (_, i) => `(b_quantized_values[${i}] - zero_point) * scale`).join(', ')});`; - } else { - return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale;`; - } - })()}; + if (aComponents === 1) { + return `${qDqDataType}(${Array.from( + { length: 8 }, + (_, i) => `(b_quantized_values[${i}] - zero_point) * scale`, + ).join(', ')});`; + } else { + return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale;`; + } + })()}; // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 for (var m: u32 = 0; m < ${useBlockwiseMatMulNBits ? dimAOuter : outputNumber}u; m++) { ${a.indicesSet('a_indices', inputRank - 2, useBlockwiseMatMulNBits ? 'm' : `row * ${outputNumber} + m`)}; @@ -150,33 +173,35 @@ export const createMatMulNBitsProgramInfo = input_offset++; } ${useBlockwiseMatMulNBits ? 'workgroup_shared[workgroup_shared_offset + m]' : 'output_values[m]'}${ - components > 1 ? '[c]' : ''} += ${ - Array - .from( - {length: 8 / aComponents}, - (_, i) => `${ - aComponents === 1 ? `a_data[${i}] * b_dequantized_values[${i}]` : - `dot(a_data[${i}], b_dequantized_values[${i}])`}`) - .join(' + ')}; + components > 1 ? '[c]' : '' + } += ${Array.from( + { length: 8 / aComponents }, + (_, i) => + `${ + aComponents === 1 + ? `a_data[${i}] * b_dequantized_values[${i}]` + : `dot(a_data[${i}], b_dequantized_values[${i}])` + }`, + ).join(' + ')}; } word_offset += ${8 / aComponents}; } }`; - const updateZeroPointIndex = zeroPoints ? ` + const updateZeroPointIndex = zeroPoints + ? 
` zero_point_offset += 4; if (zero_point_offset == 32) { zero_point_offset = 0; zero_point_index++; zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; - }` : - ''; + }` + : ''; - return useBlockwiseMatMulNBits ? ` + return useBlockwiseMatMulNBits + ? ` var workgroup_shared: array<${output.type.value}, ${dimAOuter * nBlocksPerCol}>; ${shaderHelper.declareVariables(...inputVariables, output)} - ${shaderHelper.mainStart([ - nBlocksPerCol, 1, 1 - ])} + ${shaderHelper.mainStart([nBlocksPerCol, 1, 1])} var a_indices: ${a.type.indices}; var block = local_id.x; var col = workgroup_id.y; @@ -186,15 +211,17 @@ export const createMatMulNBitsProgramInfo = for (var c: u32 = 0; c < ${components}; c++) { let col_times_components_plus_c = col * ${components} + c; ${ - zeroPoints ? ` + zeroPoints + ? ` var zero_point_bytes_per_col: u32 = (${nBlocksPerCol} + 1) / 2; var zero_point_byte_count: u32 = col_times_components_plus_c * zero_point_bytes_per_col + (block >> 0x1u); var zero_point_word_index: u32 = zero_point_byte_count >> 0x2u; var zero_point_byte_offset: u32 = zero_point_byte_count & 0x3u; var zero_point_nibble_offset: u32 = block & 0x1u; var zero_point_bits_offset: u32 = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); - var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;` : - ''} + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;` + : '' + } var b_indices: ${b.type.indices}; ${b.indicesSet('b_indices', '0', 'col_times_components_plus_c')}; // The scale and zero points are computed per block. @@ -227,8 +254,8 @@ export const createMatMulNBitsProgramInfo = output_offset += ${dimBOuter / components}; } } - }` : - ` + }` + : ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} @@ -241,12 +268,14 @@ export const createMatMulNBitsProgramInfo = // zero_point_offset is either 0 or 4. It is bit offset within one byte. // TODO support zero_point_offset for bits > 4 ${ - zeroPoints ? ` + zeroPoints + ? ` var zero_point_abs_offset = col * ${components} * ((${nBlocksPerCol} + 1) / 2); var zero_point_index: u32 = zero_point_abs_offset / 4; var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; - var zero_point_offset: u32 = (zero_point_abs_offset % 4) * 8;` : - ''} + var zero_point_offset: u32 = (zero_point_abs_offset % 4) * 8;` + : '' + } var scale_index = col * ${nBlocksPerCol * components}; var b_indices: ${b.type.indices}; for (var c: u32 = 0; c < ${components}; c++) { @@ -266,41 +295,45 @@ export const createMatMulNBitsProgramInfo = } // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. ${ - zeroPoints ? `if (zero_point_offset % 8 > 0) { + zeroPoints + ? `if (zero_point_offset % 8 > 0) { ${updateZeroPointIndex} - }` : - ''} + }` + : '' + } } for (var k: u32 = 0u; k < ${outputNumber}u; k++) { ${output.indicesSet('output_indices', outputRank - 2, `${outputNumber} * row + k`)}; ${output.setByIndices('output_indices', 'output_values[k]')} } }`; - }; - return { - name: useBlockwiseMatMulNBits ? 
'BlockwiseMatMulNBits' : 'MatMulNBits', - shaderCache: { - hint: `${attributes.cacheKey};${dimAOuter};${dataType};${inputs.length}`, - inputDependencies: Array(inputs.length).fill('rank') - }, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType}], - name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', - dispatchGroup: useBlockwiseMatMulNBits ? {x: 1, y: Math.ceil(dimBOuter / components), z: batchSize} : - {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource - }; - }; + }; + return { + name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', + shaderCache: { + hint: `${attributes.cacheKey};${dimAOuter};${dataType};${inputs.length}`, + inputDependencies: Array(inputs.length).fill('rank'), + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType }], + name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', + dispatchGroup: useBlockwiseMatMulNBits + ? { x: 1, y: Math.ceil(dimBOuter / components), z: batchSize } + : { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { validateInputs(context.inputs, attributes); const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize(); - context.compute(createMatMulNBitsProgramInfo( - context.inputs, attributes, maxComputeWorkgroupSizes, maxComputeWorkgroupStorageSize)); + context.compute( + createMatMulNBitsProgramInfo(context.inputs, attributes, maxComputeWorkgroupSizes, maxComputeWorkgroupStorageSize), + ); }; export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => - createAttributeWithCacheKey(attributes as Omit); + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 09fadea66fa1f..1e0902eb0ff56 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -1,18 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
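// An illustrative TypeScript sketch of the layout change that maybeTransposeToBNSHAndAddBias
// (later in this file) applies: a [batch, seq, hidden] input is reshaped to
// [batch, seq, numHeads, headSize] and transposed with perm [0, 2, 1, 3] into BNSH order.
// The helper below is hypothetical and operates on a flat Float32Array for clarity.
const toBNSH = (
  data: Float32Array,
  batch: number,
  seq: number,
  numHeads: number,
  headSize: number,
): Float32Array => {
  const out = new Float32Array(data.length);
  for (let b = 0; b < batch; b++) {
    for (let s = 0; s < seq; s++) {
      for (let n = 0; n < numHeads; n++) {
        for (let h = 0; h < headSize; h++) {
          const src = ((b * seq + s) * numHeads + n) * headSize + h; // BSNH offset
          const dst = ((b * numHeads + n) * seq + s) * headSize + h; // BNSH offset
          out[dst] = data[src];
        }
      }
    }
  }
  return out;
};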
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; - -import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; -import {inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; -import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, GpuDataType, ProgramUniform } from '../types'; + +import { + applyAttention, + AttentionAttrs, + AttentionMaskType, + AttentionParameters, + AttentionQkvFormat, +} from './attention'; +import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; +import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; const getInput = (inputs: readonly TensorView[], i: number) => - (inputs.length > i) && (inputs[i].dims.length > 0) && (ShapeUtil.size(inputs[i].dims)) > 0 ? inputs[i] : undefined; + inputs.length > i && inputs[i].dims.length > 0 && ShapeUtil.size(inputs[i].dims) > 0 ? inputs[i] : undefined; const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { const query = inputs[0]; @@ -65,8 +71,8 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr const dmmhaPacking = false; const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : - attributes.numHeads * query.dims[4]; + const hiddenSize = + query.dims.length === 3 ? (dmmhaPacking ? 
query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4]; let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; @@ -79,8 +85,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr if (pastKey.dims[0] !== batchSize || pastKey.dims[1] !== attributes.numHeads || pastKey.dims[3] !== headSize) { throw new Error('Input "past_key" shape (batch_size, num_heads, past_sequence_length, head_size)'); } - if (pastValue.dims[0] !== batchSize || pastValue.dims[1] !== attributes.numHeads || - pastValue.dims[3] !== headSize) { + if ( + pastValue.dims[0] !== batchSize || + pastValue.dims[1] !== attributes.numHeads || + pastValue.dims[3] !== headSize + ) { throw new Error('Input "past_value" shape (batch_size, num_heads, past_sequence_length, head_size)'); } if (pastKey.dims[2] !== pastValue.dims[2]) { @@ -122,7 +131,8 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr } qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; kvSequenceLength = key.dims[1]; - } else { // key_dims.size() == 4 (cross-attention with past_key) + } else { + // key_dims.size() == 4 (cross-attention with past_key) if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } @@ -130,7 +140,8 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr qkvFormat = AttentionQkvFormat.unknown; kvSequenceLength = key.dims[2]; } - } else { // packed QKV + } else { + // packed QKV if (query.dims.length !== 3 && query.dims.length !== 5) { throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); } @@ -208,9 +219,12 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr if (relativePositionBias.dims.length !== 4) { throw new Error('Input "relative_position_bias" is expected to have 4 dimensions'); } - if ((relativePositionBias.dims[0] !== batchSize && relativePositionBias.dims[0] !== 1) || - relativePositionBias.dims[1] !== attributes.numHeads || relativePositionBias.dims[2] !== sequenceLength || - relativePositionBias.dims[3] !== totalSequenceLength) { + if ( + (relativePositionBias.dims[0] !== batchSize && relativePositionBias.dims[0] !== 1) || + relativePositionBias.dims[1] !== attributes.numHeads || + relativePositionBias.dims[2] !== sequenceLength || + relativePositionBias.dims[3] !== totalSequenceLength + ) { throw new Error('Input "relative_position_bias" shape (batch_size, 1, sequence_length, kv_sequence_length)'); } } @@ -240,29 +254,38 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr }; export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({...attributes}); - -const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); - -const addBiasTranspose = - (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, - hiddenSize: number, biasOffset: number) => { - const outputShape = [batchSize, sequenceLength, hiddenSize]; - const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset}, - {type: DataType.uint32, data: hiddenSize} - ]; - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const output = outputVariable('qkv_with_bias', 
qkv.dataType, outputShape); - const qkvInput = inputVariable('qkv', qkv.dataType, outputShape); - const biasInput = inputVariable('bias', bias.dataType, outputShape); - - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'bias_offset', type: 'u32'}, {name: 'hidden_size', type: 'u32'} - ]; - return ` + createAttributeWithCacheKey({ ...attributes }); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] }); + +const addBiasTranspose = ( + context: ComputeContext, + qkv: TensorView, + bias: TensorView, + batchSize: number, + sequenceLength: number, + hiddenSize: number, + biasOffset: number, +) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: biasOffset }, + { type: DataType.uint32, data: hiddenSize }, + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); + const qkvInput = inputVariable('qkv', qkv.dataType, outputShape); + const biasInput = inputVariable('bias', bias.dataType, outputShape); + + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'bias_offset', type: 'u32' }, + { name: 'hidden_size', type: 'u32' }, + ]; + return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(qkvInput, biasInput, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} @@ -270,48 +293,65 @@ const addBiasTranspose = qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx]; }`; - }; - - return context.compute( - { - name: 'MultiHeadAttentionAddBias', - shaderCache: {inputDependencies: ['type', 'type']}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource, - }, - {inputs: [qkv, bias], outputs: [-1]})[0]; - }; - -export const maybeTransposeToBNSHAndAddBias = - (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number, - input: TensorView, bias?: TensorView, biasOffset?: number) => { - // const newDims = []; - - let reshapedInput = input; - if (!bias) { - if (input.dims.length === 3) { - reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); - } - return context.compute( - createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), - {inputs: [reshapedInput], outputs: [-1]})[0]; - } else { - if (sequenceLength === 1) { - throw new Error('AddBiasReshape is not implemented. 
Please export your model with packed QKV or KV'); - } else { - reshapedInput = - addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!); - reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); - return context.compute( - createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), - {inputs: [reshapedInput], outputs: [-1]})[0]; - } - } - }; + }; + + return context.compute( + { + name: 'MultiHeadAttentionAddBias', + shaderCache: { inputDependencies: ['type', 'type'] }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [qkv, bias], outputs: [-1] }, + )[0]; +}; + +export const maybeTransposeToBNSHAndAddBias = ( + context: ComputeContext, + batchSize: number, + numHeads: number, + sequenceLength: number, + headSize: number, + input: TensorView, + bias?: TensorView, + biasOffset?: number, +) => { + // const newDims = []; + + let reshapedInput = input; + if (!bias) { + if (input.dims.length === 3) { + reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); + } + return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { + inputs: [reshapedInput], + outputs: [-1], + })[0]; + } else { + if (sequenceLength === 1) { + throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV'); + } else { + reshapedInput = addBiasTranspose( + context, + input, + bias, + batchSize, + sequenceLength, + numHeads * headSize, + biasOffset!, + ); + reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { + inputs: [reshapedInput], + outputs: [-1], + })[0]; + } + } +}; export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { const params = validateInputs(context.inputs, attributes); @@ -335,24 +375,67 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio const kvBNSH = key && value && key.dims.length === 4 && value.dims.length === 4; const Q = maybeTransposeToBNSHAndAddBias( - context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, query, bias, 0); + context, + params.batchSize, + params.numHeads, + params.sequenceLength, + params.headSize, + query, + bias, + 0, + ); if (kvBNSH) { return applyAttention( - context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params, - attributes); + context, + Q, + key, + value, + keyPaddingMask, + undefined, + pastKey, + pastValue, + relativePositionBias, + params, + attributes, + ); } if (!key || !value) { throw new Error('key and value must be provided'); } const K = maybeTransposeToBNSHAndAddBias( - context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, key, bias, - params.hiddenSize); + context, + params.batchSize, + params.numHeads, + params.kvSequenceLength, + params.headSize, + key, + bias, + params.hiddenSize, + ); const V = maybeTransposeToBNSHAndAddBias( - context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, value, bias, - 2 * params.hiddenSize); + context, + params.batchSize, + params.numHeads, + params.kvSequenceLength, + params.vHeadSize, + value, + bias, 
+ 2 * params.hiddenSize, + ); applyAttention( - context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params, attributes); + context, + Q, + K, + V, + keyPaddingMask, + undefined, + pastKey, + pastValue, + relativePositionBias, + params, + attributes, + ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index d649d3d220ae1..4951bd0192baf 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -1,12 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; - -import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; + +import { + createTensorShapeVariables, + getElementAt, + IndicesHelper, + inputVariable, + outputVariable, + ShaderHelper, + UniformDataElementType, + UniformsArrayType, +} from './common'; interface PadAttributes { // 0-constant, 1-reflect, 2-edge, 3-wrap @@ -152,10 +161,12 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads); const inputDims = inputs[0].dims; const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = - [{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: attributes.pads}]; + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.int32, data: attributes.pads }, + ]; if (attributes.mode === 0) { - programUniforms.push({type: inputs[0].dataType, data: attributes.value}); + programUniforms.push({ type: inputs[0].dataType, data: attributes.value }); } programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape)); @@ -166,10 +177,12 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const input = inputVariable('x', inputs[0].dataType, inputDims.length); const dataType = input.type.value; const padSnippet = getPadSnippet(output, inputDims.length, attributes); - const uniforms: UniformsArrayType = - [{name: 'output_size', type: 'u32'}, {name: 'pads', type: 'i32', length: attributes.pads.length}]; + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'pads', type: 'i32', length: attributes.pads.length }, + ]; if (attributes.mode === 0) { - uniforms.push({name: 'constant_value', type: dataType as UniformDataElementType}); + uniforms.push({ name: 'constant_value', type: dataType as UniformDataElementType }); } return ` @@ -187,11 +200,11 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr return { name: 'Pad', - shaderCache: {hint: `${attributes.mode}`, inputDependencies}, + shaderCache: { hint: `${attributes.mode}`, inputDependencies }, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 
64 /* workgroup size */)}, - programUniforms + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) }, + programUniforms, }), getShaderSource, }; @@ -200,7 +213,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => { if (inputs.length > 1) { const bigInt64Pads = inputs[1].getBigInt64Array(); - const value = (inputs.length >= 3 && inputs[2].data) ? inputs[2].getFloat32Array()[0] : 0.0; + const value = inputs.length >= 3 && inputs[2].data ? inputs[2].getFloat32Array()[0] : 0.0; const inputRank = inputs[0].dims.length; const updatePads = new Int32Array(2 * inputRank).fill(0); @@ -211,13 +224,13 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]); } } else { - bigInt64Pads.forEach((v, i) => updatePads[Number(i)] = (Number(v))); + bigInt64Pads.forEach((v, i) => (updatePads[Number(i)] = Number(v))); } const pads: number[] = []; - updatePads.forEach(v => pads.push(v)); + updatePads.forEach((v) => pads.push(v)); - return {mode: attributes.mode, value, pads}; + return { mode: attributes.mode, value, pads }; } else { return attributes; } @@ -226,5 +239,5 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes export const pad = (context: ComputeContext, attributes: PadAttributes): void => { validateInputs(context.inputs); const updatedAttributes = createPadAttributesFromInputs(context.inputs, attributes); - context.compute(createPadProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); + context.compute(createPadProgramInfo(context.inputs, updatedAttributes), { inputs: [0] }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 5521650e8ded4..8b2438e45d6b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -1,15 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
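// An illustrative TypeScript sketch of the floor-mode output size the pooling kernels in this
// file assume per spatial dimension (the TODO in this file notes ceil_mode is not yet
// supported). The helper and its parameter names are hypothetical; the actual shape
// inference lives in PoolConvUtil.
const poolOutputDim = (
  inDim: number,
  kernel: number,
  stride: number,
  padBegin: number,
  padEnd: number,
  dilation = 1,
): number => Math.floor((inDim + padBegin + padEnd - (dilation * (kernel - 1) + 1)) / stride) + 1;

// Example: a 2x2 window with stride 2 and no padding halves a 224-wide dimension:
// poolOutputDim(224, 2, 2, 0, 0) === 112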
-import {env} from 'onnxruntime-common'; - -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {PoolConvUtil, ShapeUtil} from '../../util'; -import {AttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; - -import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import { env } from 'onnxruntime-common'; + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { PoolConvUtil, ShapeUtil } from '../../util'; +import { AttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; + +import { + createTensorShapeVariables, + getElementAt, + IndicesHelper, + inputVariable, + outputVariable, + ShaderHelper, + UniformsArrayType, +} from './common'; // TODO: support: // - ceil_mode "test_maxpool_2d_ceil" @@ -23,12 +31,15 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const getAdjustedPoolAttributesAndOutputShape = ( - input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => { +const getAdjustedPoolAttributesAndOutputShape = ( + input: TensorView, + attributes: AttributeType, + isGlobalOperator: boolean, +): [AttributeType, number[]] => { const isChannelsLast = attributes.format === 'NHWC'; const inputShapeAsChannelFirst = input.dims.slice(); if (isChannelsLast) { - inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position. + inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position. 
} const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); const kernelShape = attributes.kernelShape.slice(); @@ -38,28 +49,41 @@ const getAdjustedPoolAttributesAndOutputShape = ( - outputShape: readonly number[], - attributes: AttributeType): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => { +const getUniformAndPadInfo = ( + outputShape: readonly number[], + attributes: AttributeType, +): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => { const isChannelsLast = attributes.format === 'NHWC'; const outputSize = ShapeUtil.size(outputShape); const kernelSize = ShapeUtil.size(attributes.kernelShape); - const programUniforms: ProgramUniform[] = - [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: kernelSize}]; - const uniforms: UniformsArrayType = [{name: 'outputSize', type: 'u32'}, {name: 'kernelSize', type: 'u32'}]; + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: kernelSize }, + ]; + const uniforms: UniformsArrayType = [ + { name: 'outputSize', type: 'u32' }, + { name: 'kernelSize', type: 'u32' }, + ]; if (attributes.kernelShape.length <= 2) { const kw = attributes.kernelShape[attributes.kernelShape.length - 1]; const sw = attributes.strides[attributes.strides.length - 1]; @@ -67,14 +91,17 @@ const getUniformAndPadInfo = sum + cur); return [programUniforms, uniforms, !!hasPads, false, false]; } }; -const generatePoolingCode = ( - shaderHelper: ShaderHelper, x: IndicesHelper, rank: number, outputShapeRank: number, attributes: AttributeType, - op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEndNotZero: boolean, - phStartEndNotZero: boolean): string => { +const generatePoolingCode = ( + shaderHelper: ShaderHelper, + x: IndicesHelper, + rank: number, + outputShapeRank: number, + attributes: AttributeType, + op1: string, + op2: string, + start: number, + uniforms: UniformsArrayType, + hasPads: boolean, + pwStartEndNotZero: boolean, + phStartEndNotZero: boolean, +): string => { const isChannelsLast = attributes.format === 'NHWC'; const dataType = x.type.value; const output = outputVariable('output', x.type.tensor, outputShapeRank); @@ -235,8 +281,11 @@ const generatePoolingCode = - (`${attributes.format};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`); + `${attributes.format};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`; const createAveragePoolShaderKeyFromAttributes = (attributes: AveragePoolAttributes): string => - (`${createShaderKeyFromAttributes(attributes)};${attributes.countIncludePad}`); + `${createShaderKeyFromAttributes(attributes)};${attributes.countIncludePad}`; const createMaxPoolShaderKeyFromAttributes = (attributes: MaxPoolAttributes): string => - (`${createShaderKeyFromAttributes(attributes)};${attributes.storageOrder};${attributes.dilations}`); + `${createShaderKeyFromAttributes(attributes)};${attributes.storageOrder};${attributes.dilations}`; const parsePoolCommonAttributes = (attributes: Record): PoolCommonAttributes => ({ format: attributes.format as FormatAttributes['format'], @@ -275,45 +324,68 @@ const parsePoolCommonAttributes = (attributes: Record): PoolCom ceilMode: attributes.ceil_mode as number, kernelShape: attributes.kernel_shape as [number, number], strides: attributes.strides as [number, number], - pads: attributes.pads as [number, number, number, number] + pads: attributes.pads as [number, number, 
number, number], }); export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey { readonly countIncludePad: boolean; } -const createAveragePoolProgramInfo = - (name: string, input: TensorView, isGlobalOperator: boolean, attributes: AveragePoolAttributes): ProgramInfo => { - const [adjustedAttributes, outputShape] = - getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator); - const x = inputVariable('x', input.dataType, input.dims.length); - const dataType = x.type.value; - - const op1 = 'value += x_val;'; - let op2 = ''; - if (adjustedAttributes.countIncludePad) { - op2 += `value /= ${dataType}(uniforms.kernelSize);`; - } else { - op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`; - } - const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = - getUniformAndPadInfo(outputShape, adjustedAttributes); - programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; - return { - name, - shaderCache: - {hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: input.dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource: shaderHelper => generatePoolingCode( - shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, 0.0, uniforms, - hasPads, pwStartEndNotZero, phStartEndNotZero), - }; - }; +const createAveragePoolProgramInfo = ( + name: string, + input: TensorView, + isGlobalOperator: boolean, + attributes: AveragePoolAttributes, +): ProgramInfo => { + const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape( + input, + attributes, + isGlobalOperator, + ); + const x = inputVariable('x', input.dataType, input.dims.length); + const dataType = x.type.value; + + const op1 = 'value += x_val;'; + let op2 = ''; + if (adjustedAttributes.countIncludePad) { + op2 += `value /= ${dataType}(uniforms.kernelSize);`; + } else { + op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`; + } + const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo( + outputShape, + adjustedAttributes, + ); + programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; + return { + name, + shaderCache: { + hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, + inputDependencies, + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: input.dataType }], + dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource: (shaderHelper) => + generatePoolingCode( + shaderHelper, + x, + input.dims.length, + outputShape.length, + adjustedAttributes, + op1, + op2, + 0.0, + uniforms, + hasPads, + pwStartEndNotZero, + phStartEndNotZero, + ), + }; +}; export const parseAveragePoolAttributes = (attributes: Record): AveragePoolAttributes => { const countIncludePad = (attributes.count_include_pad as number) === 0 ? 
false : true; @@ -323,8 +395,8 @@ export const parseAveragePoolAttributes = (attributes: Record): if (attr.ceilMode !== 0) { throw new Error('using ceil() in shape computation is not yet supported for AveragePool'); } - const averagePoolAttributes = {countIncludePad, ...attr, cacheKey: ''}; - return {...averagePoolAttributes, cacheKey: createAveragePoolShaderKeyFromAttributes(averagePoolAttributes)}; + const averagePoolAttributes = { countIncludePad, ...attr, cacheKey: '' }; + return { ...averagePoolAttributes, cacheKey: createAveragePoolShaderKeyFromAttributes(averagePoolAttributes) }; }; export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => { @@ -340,12 +412,12 @@ const globalPoolAttributes = { strides: [], pads: [], storageOrder: 0, - dilations: [] + dilations: [], }; export const parseGlobalAveragePoolAttributes = (attributes: Record): AveragePoolAttributes => { const format = attributes.format as FormatAttributes['format']; - return {format, ...globalPoolAttributes, cacheKey: format}; + return { format, ...globalPoolAttributes, cacheKey: format }; }; export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => { @@ -358,34 +430,56 @@ export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCa readonly dilations: number[]; } -const createMaxPoolProgramInfo = - (name: string, input: TensorView, isGlobalOperator: boolean, attributes: MaxPoolAttributes): ProgramInfo => { - const [adjustedAttributes, outputShape] = - getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator); - const op1 = ` +const createMaxPoolProgramInfo = ( + name: string, + input: TensorView, + isGlobalOperator: boolean, + attributes: MaxPoolAttributes, +): ProgramInfo => { + const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape( + input, + attributes, + isGlobalOperator, + ); + const op1 = ` value = max(x_val, value); `; - const op2 = ''; - const x = inputVariable('x', input.dataType, input.dims.length); - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; - const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = - getUniformAndPadInfo(outputShape, adjustedAttributes); - programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); - return { - name, - shaderCache: - {hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: input.dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}, - programUniforms - }), - getShaderSource: shaderHelper => generatePoolingCode( - shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, - (input.dataType === DataType.float16) ? 
-65504 : -1e5, uniforms, hasPads, pwStartEndNotZero, - phStartEndNotZero), - }; - }; + const op2 = ''; + const x = inputVariable('x', input.dataType, input.dims.length); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; + const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo( + outputShape, + adjustedAttributes, + ); + programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); + return { + name, + shaderCache: { + hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, + inputDependencies, + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: input.dataType }], + dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource: (shaderHelper) => + generatePoolingCode( + shaderHelper, + x, + input.dims.length, + outputShape.length, + adjustedAttributes, + op1, + op2, + input.dataType === DataType.float16 ? -65504 : -1e5, + uniforms, + hasPads, + pwStartEndNotZero, + phStartEndNotZero, + ), + }; +}; export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => { validateInputs(context.inputs); @@ -404,13 +498,13 @@ export const parseMaxPoolAttributes = (attributes: Record): Max if (attr.ceilMode !== 0) { throw new Error('using ceil() in shape computation is not yet supported for MaxPool'); } - const maxPoolAttributes = {storageOrder, dilations, ...attr, cacheKey: ''}; - return {...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes)}; + const maxPoolAttributes = { storageOrder, dilations, ...attr, cacheKey: '' }; + return { ...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes) }; }; export const parseGlobalMaxPoolAttributes = (attributes: Record): MaxPoolAttributes => { const format = attributes.format as FormatAttributes['format']; - return {format, ...globalPoolAttributes, cacheKey: format}; + return { format, ...globalPoolAttributes, cacheKey: format }; }; export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts b/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts index 0d7c7ab408b3a..52ecd07cb7f92 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts @@ -1,13 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
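// An illustrative TypeScript sketch of the computation the DequantizeLinear shader in this
// file implements: y = (x - zero_point) * scale, where the scale/zero-point index depends on
// the quantization granularity (per-tensor uses a single value, per-axis indexes along the
// axis, block quantization divides the index by block_size). Shown here for a flat int8
// input; the function and parameter names are hypothetical.
const dequantizeLinear1D = (
  x: Int8Array,
  scales: Float32Array,
  zeroPoints: Int8Array,
  blockSize: number,
): Float32Array => {
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) {
    // Per-tensor quantization has a single scale; otherwise pick the block this element falls in.
    const block = scales.length === 1 ? 0 : Math.floor(i / blockSize);
    const zeroPoint = zeroPoints.length === 1 ? zeroPoints[0] : zeroPoints[block];
    out[i] = (x[i] - zeroPoint) * scales[block];
  }
  return out;
};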
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import { + createTensorShapeVariables, + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + UniformsArrayType, +} from './common'; export interface DequantizeLinerAttributes extends AttributeWithCacheKey { axis: number; @@ -50,9 +57,9 @@ const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLin if (inputs[1].dims.length === 0 || (inputs[1].dims.length === 1 && inputs[1].dims[0] === 1)) { throw new Error('blockSize must be set only for block quantization.'); } - if (!inputs[1] - .dims.map((d, i) => i === attributes.axis || d === inputs[0].dims[i]) - .reduce((a, b) => a && b, true)) { + if ( + !inputs[1].dims.map((d, i) => i === attributes.axis || d === inputs[0].dims[i]).reduce((a, b) => a && b, true) + ) { throw new Error('For block qunatization, scale input shape to match the input shape except for the axis'); } // Scale input rank should be same as the input rank @@ -67,53 +74,62 @@ const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLin } }; -const createDequantizeLinearProgramInfo = - (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): ProgramInfo => { - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length); - const inputType = inputs[0].dataType; - const isSigned = inputType === DataType.int8; - const outputShape = inputs[0].dims; // output shape is same as the input shape - const dataType = inputs[1].dataType; // output type is same as the the scale input type - const outputSize = ShapeUtil.size(outputShape); - const isPacked = inputType === DataType.int8 || inputType === DataType.uint8; - const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims; - const scaleShape = inputs[1].dims; - const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined; - const zeroPointShape = zeroPointInput ? - (isPacked ? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)] : zeroPointInput.dims) : - undefined; - // Scales input is a scaler for per-tensor/per-layer quantization, 1-D tensor for per-axis quantization - // or tensor with same rank as input for blocked quantization. - const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1); - const perAxisQuantization = perLayerQuantization === false && scaleShape.length === 1; - // Left unnecessary commented-out assignment for documentation - // const blockQuantization = perLayerQuantization === false && perAxisQuantization === false; - const maxComponents = getMaxComponents(outputSize); - const useComponents = perLayerQuantization && (!isPacked || maxComponents === 4); - const components = useComponents ? maxComponents : 1; - const inputComponent = (useComponents && !isPacked) ? 
maxComponents : 1; - const input = inputVariable('input', isPacked ? DataType.uint32 : inputType, inputShape.length, inputComponent); - const scale = inputVariable('scale', dataType, scaleShape.length); - const zeroPoint = zeroPointInput ? - inputVariable('zero_point', isPacked ? DataType.uint32 : inputType, zeroPointShape!.length) : - undefined; - const output = outputVariable('output', dataType, outputShape.length, components); - const inputVariables = [input, scale]; - if (zeroPoint) { - inputVariables.push(zeroPoint); - } - const inputShapes = [inputShape, scaleShape]; - if (zeroPointInput) { - inputShapes.push(zeroPointShape!); - } - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize / components}, {type: DataType.uint32, data: axis}, - {type: DataType.uint32, data: attributes.blockSize}, ...createTensorShapeVariables(...inputShapes, outputShape) - ]; - const getShaderSource = (shaderHelper: ShaderHelper) => { - const uniforms: UniformsArrayType = - [{name: 'output_size', type: 'u32'}, {name: 'axis', type: 'u32'}, {name: 'block_size', type: 'u32'}]; - return ` +const createDequantizeLinearProgramInfo = ( + inputs: readonly TensorView[], + attributes: DequantizeLinerAttributes, +): ProgramInfo => { + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length); + const inputType = inputs[0].dataType; + const isSigned = inputType === DataType.int8; + const outputShape = inputs[0].dims; // output shape is same as the input shape + const dataType = inputs[1].dataType; // output type is same as the the scale input type + const outputSize = ShapeUtil.size(outputShape); + const isPacked = inputType === DataType.int8 || inputType === DataType.uint8; + const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims; + const scaleShape = inputs[1].dims; + const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined; + const zeroPointShape = zeroPointInput + ? isPacked + ? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)] + : zeroPointInput.dims + : undefined; + // Scales input is a scaler for per-tensor/per-layer quantization, 1-D tensor for per-axis quantization + // or tensor with same rank as input for blocked quantization. + const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1); + const perAxisQuantization = perLayerQuantization === false && scaleShape.length === 1; + // Left unnecessary commented-out assignment for documentation + // const blockQuantization = perLayerQuantization === false && perAxisQuantization === false; + const maxComponents = getMaxComponents(outputSize); + const useComponents = perLayerQuantization && (!isPacked || maxComponents === 4); + const components = useComponents ? maxComponents : 1; + const inputComponent = useComponents && !isPacked ? maxComponents : 1; + const input = inputVariable('input', isPacked ? DataType.uint32 : inputType, inputShape.length, inputComponent); + const scale = inputVariable('scale', dataType, scaleShape.length); + const zeroPoint = zeroPointInput + ? inputVariable('zero_point', isPacked ? 
DataType.uint32 : inputType, zeroPointShape!.length) + : undefined; + const output = outputVariable('output', dataType, outputShape.length, components); + const inputVariables = [input, scale]; + if (zeroPoint) { + inputVariables.push(zeroPoint); + } + const inputShapes = [inputShape, scaleShape]; + if (zeroPointInput) { + inputShapes.push(zeroPointShape!); + } + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize / components }, + { type: DataType.uint32, data: axis }, + { type: DataType.uint32, data: attributes.blockSize }, + ...createTensorShapeVariables(...inputShapes, outputShape), + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'axis', type: 'u32' }, + { name: 'block_size', type: 'u32' }, + ]; + return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} @@ -121,94 +137,96 @@ const createDequantizeLinearProgramInfo = // Set input x ${(() => { - if (isPacked) { - return ` + if (isPacked) { + return ` let input = ${input.getByOffset('global_idx / 4')}; let x_vec = ${isSigned ? 'unpack4xI8(input)' : 'unpack4xU8(input)'}; let x_value = ${components === 1 ? 'x_vec[global_idx % 4]' : 'x_vec'};`; - } else { - return `let x_value = ${input.getByOffset('global_idx')};`; - } - })()}; + } else { + return `let x_value = ${input.getByOffset('global_idx')};`; + } + })()}; // Set scale input ${(() => { - if (perLayerQuantization) { - // scale input is a scalar () - return `let scale_value= ${scale.getByOffset('0')}`; - } else if (perAxisQuantization) { - // scale input is a 1D tensor - return ` + if (perLayerQuantization) { + // scale input is a scalar () + return `let scale_value= ${scale.getByOffset('0')}`; + } else if (perAxisQuantization) { + // scale input is a 1D tensor + return ` let scale_index = ${output.indicesGet('output_indices', 'uniforms.axis')}; let scale_value= ${scale.getByOffset('scale_index')};`; - } else { - // Block quantization. Scale input rank is same as input/output rank. - return ` + } else { + // Block quantization. Scale input rank is same as input/output rank. + return ` var scale_indices: ${scale.type.indices} = output_indices; let index = ${scale.indicesGet('scale_indices', 'uniforms.axis')} / uniforms.block_size; ${scale.indicesSet('scale_indices', 'uniforms.axis', 'index')}; let scale_value= ${scale.getByIndices('scale_indices')};`; - } - })()}; + } + })()}; // Set zero-point input ${(() => { - if (zeroPoint) { - if (perLayerQuantization) { - // zero-point input is a scalar - if (isPacked) { - return ` + if (zeroPoint) { + if (perLayerQuantization) { + // zero-point input is a scalar + if (isPacked) { + return ` let zero_point_input = ${zeroPoint.getByOffset('0')}; let zero_point_vec = ${isSigned ? 
'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'}; let zero_point_value= zero_point_vec[0]`; - } else { - return `let zero_point_value = ${zeroPoint.getByOffset('0')}`; - } - } else if (perAxisQuantization) { - // zero-point input is a 1D tensor - if (isPacked) { - return ` + } else { + return `let zero_point_value = ${zeroPoint.getByOffset('0')}`; + } + } else if (perAxisQuantization) { + // zero-point input is a 1D tensor + if (isPacked) { + return ` let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')}; let zero_point_input = ${zeroPoint.getByOffset('zero_point_index / 4')}; let zero_point_vec = ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'}; let zero_point_value = zero_point_vec[zero_point_index % 4]`; - } else { - return ` + } else { + return ` let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')}; let zero_point_value = ${zeroPoint.getByOffset('zero_point_index')};`; - } - } else { - // BlockedQuantization. The zero-point input shape is same as the input shape except along axis. - if (isPacked) { - return ` + } + } else { + // BlockedQuantization. The zero-point input shape is same as the input shape except along axis. + if (isPacked) { + return ` let zero_point_offset = ${scale.indicesToOffset('scale_indices')}; let zero_point_input = ${zeroPoint.getByOffset('zero_point_offset / 4')}; let zero_point_vec = ${isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'}; let zero_point_value = zero_point_vec[zero_point_offset % 4];`; - } else { - return `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`; + } else { + return `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`; + } } + } else { + return `let zero_point_value = ${isPacked ? (isSigned ? 'i32' : 'u32') : input.type.value}(0);`; } - } else { - return `let zero_point_value = ${isPacked ? (isSigned ? 'i32' : 'u32') : input.type.value}(0);`; - } - })()}; + })()}; // Compute and write output ${output.setByOffset('global_idx', `${output.type.value}(x_value - zero_point_value) * scale_value`)}; }`; - }; - return { - name: 'DequantizeLinear', - shaderCache: - {hint: attributes.cacheKey, inputDependencies: zeroPoint ? ['rank', 'rank', 'rank'] : ['rank', 'rank']}, - getShaderSource, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType}], - dispatchGroup: {x: Math.ceil(outputSize / components / 64), y: 1, z: 1}, - programUniforms - }) - }; - }; + }; + return { + name: 'DequantizeLinear', + shaderCache: { + hint: attributes.cacheKey, + inputDependencies: zeroPoint ? 
['rank', 'rank', 'rank'] : ['rank', 'rank'], + }, + getShaderSource, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType }], + dispatchGroup: { x: Math.ceil(outputSize / components / 64), y: 1, z: 1 }, + programUniforms, + }), + }; +}; export const dequantizeLinear = (context: ComputeContext, attributes: DequantizeLinerAttributes): void => { validateInputs(context.inputs, attributes); @@ -216,4 +234,4 @@ export const dequantizeLinear = (context: ComputeContext, attributes: Dequantize }; export const parseDequantizeLinearAttributes = (attributes: Record): DequantizeLinerAttributes => - createAttributeWithCacheKey({axis: attributes.axis as number, blockSize: attributes.blockSize as number}); + createAttributeWithCacheKey({ axis: attributes.axis as number, blockSize: attributes.blockSize as number }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts index a21f48ef9ded9..ff7aa8aece9c1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/range.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -1,12 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env} from 'onnxruntime-common'; +import { env } from 'onnxruntime-common'; -import {DataType} from '../../../wasm-common'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import {createTensorShapeVariables, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common'; +import { + createTensorShapeVariables, + outputVariable, + ShaderHelper, + UniformDataElementType, + UniformsArrayType, +} from './common'; const validateInputsContent = (start: number, limit: number, delta: number): void => { const sameStartLimit = start === limit; @@ -14,7 +20,7 @@ const validateInputsContent = (start: number, limit: number, delta: number): voi const decreasingRangePositiveStep = start > limit && delta > 0; if (sameStartLimit || increasingRangeNegativeStep || decreasingRangePositiveStep) { - throw new Error('Range these inputs\' contents are invalid.'); + throw new Error("Range these inputs' contents are invalid."); } }; @@ -23,16 +29,19 @@ const createRangeProgramInfo = (start: number, limit: number, delta: number, dat const outputShape: number[] = [numElements]; const outputSize = numElements; const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: dataType, data: start}, {type: dataType, data: delta}, - ...createTensorShapeVariables(outputShape) + { type: DataType.uint32, data: outputSize }, + { type: dataType, data: start }, + { type: dataType, data: delta }, + ...createTensorShapeVariables(outputShape), ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', dataType, outputShape.length); const wgslType = output.type.value; const uniforms: UniformsArrayType = [ - {name: 'outputSize', type: 'u32'}, {name: 'start', type: wgslType as UniformDataElementType}, - {name: 'delta', type: wgslType as UniformDataElementType} + { name: 'outputSize', type: 'u32' }, + { name: 'start', type: wgslType as UniformDataElementType }, + { name: 'delta', type: wgslType as UniformDataElementType }, ]; return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(output)} @@ -44,13 +53,13 @@ const createRangeProgramInfo = (start: number, limit: number, delta: number, dat return { name: 'Range', - 
shaderCache: {hint: `${dataType}`}, + shaderCache: { hint: `${dataType}` }, getShaderSource, getRunData: () => ({ - outputs: [{dims: outputShape, dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms - }) + outputs: [{ dims: outputShape, dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), }; }; @@ -71,5 +80,5 @@ export const range = (context: ComputeContext): void => { validateInputsContent(start, limit, delta); } - context.compute(createRangeProgramInfo(start, limit, delta, context.inputs[0].dataType), {inputs: []}); + context.compute(createRangeProgramInfo(start, limit, delta, context.inputs[0].dataType), { inputs: [] }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index 210b3ee7e2fca..bf64b04dde1e8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -1,16 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; -import {createReduceAttributesFromInputs, ReduceAttributes} from './reduce'; -import {createTransposeProgramInfo} from './transpose'; +import { inputVariable, outputVariable, ShaderHelper } from './common'; +import { createReduceAttributesFromInputs, ReduceAttributes } from './reduce'; +import { createTransposeProgramInfo } from './transpose'; -const reduceOps: {[key: string]: string} = { +const reduceOps: { [key: string]: string } = { max: 'select(bestValue, candidate, candidate > bestValue)', min: 'select(bestValue, candidate, candidate < bestValue)', mean: 'bestValue + candidate', @@ -20,10 +20,10 @@ const reduceOps: {[key: string]: string} = { logSumExp: 'bestValue + exp(candidate)', l1: 'bestValue + abs(candidate)', l2: 'bestValue + candidate * candidate', - logSum: 'bestValue + candidate' + logSum: 'bestValue + candidate', }; -const reduceSharedOps: {[key: string]: string} = { +const reduceSharedOps: { [key: string]: string } = { max: 'select(bestValue, candidate, candidate > bestValue)', min: 'select(bestValue, candidate, candidate < bestValue)', mean: 'bestValue + candidate', @@ -33,10 +33,10 @@ const reduceSharedOps: {[key: string]: string} = { logSumExp: 'bestValue + candidate', l1: 'bestValue + candidate', l2: 'bestValue + candidate', - logSum: 'bestValue + candidate' + logSum: 'bestValue + candidate', }; -const reduceInitValues: {[key: string]: string} = { +const reduceInitValues: { [key: string]: string } = { max: '_A[offset]', min: '_A[offset]', mean: '0', @@ -46,10 +46,10 @@ const reduceInitValues: {[key: string]: string} = { logSumExp: '0', l1: '0', l2: '0', - logSum: '0' + logSum: '0', }; -const reduceOutputValues: {[key: string]: string} = { +const reduceOutputValues: { [key: string]: string } = { max: 'bestValue', min: 'bestValue', sum: 'bestValue', @@ -58,7 +58,7 @@ const reduceOutputValues: {[key: string]: string} = { logSumExp: 'log(bestValue)', l1: 'bestValue', l2: 
'sqrt(bestValue)', - logSum: 'log(bestValue)' + logSum: 'log(bestValue)', }; const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => { @@ -77,7 +77,7 @@ const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly numb outputShape.push(shape[dim]); } } - const reduceShape = axes.map(dim => shape[dim]); + const reduceShape = axes.map((dim) => shape[dim]); return [outputShape, reduceShape]; }; @@ -112,29 +112,35 @@ const getAxesPermutation = (axes: number[], rank: number): number[] => { res.push(i); } } - axes.forEach(axis => res.push(axis)); + axes.forEach((axis) => res.push(axis)); } return res; }; -export const createReduceSharedProgramInfo = - (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceType: string, - outputDataType: DataType, outputShape: number[], reduceShape: number[]): ProgramInfo => { - const inputShape = inputs[0].dims; +export const createReduceSharedProgramInfo = ( + name: string, + shaderCache: ProgramShaderCacheInfo, + inputs: readonly TensorView[], + reduceType: string, + outputDataType: DataType, + outputShape: number[], + reduceShape: number[], +): ProgramInfo => { + const inputShape = inputs[0].dims; - const outputSize = ShapeUtil.size(outputShape); - const reduceSize = ShapeUtil.size(reduceShape); + const outputSize = ShapeUtil.size(outputShape); + const reduceSize = ShapeUtil.size(reduceShape); - const input = inputVariable('_A', inputs[0].dataType, inputShape); - const output = outputVariable('output', outputDataType, outputShape); + const input = inputVariable('_A', inputs[0].dataType, inputShape); + const output = outputVariable('output', outputDataType, outputShape); - const workgroupSize = 32; + const workgroupSize = 32; - const sharedMemorySnippet = ` + const sharedMemorySnippet = ` var aBestValues : array; `; - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)} ${sharedMemorySnippet} fn DIV_CEIL(a : u32, b : u32) -> u32 { @@ -168,61 +174,75 @@ export const createReduceSharedProgramInfo = } if (local_idx == 0u) { - ${ - output.setByOffset( - 'outputIndex', - `${ - reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` : - `${output.type.storage}(${reduceOutputValues[reduceType]})`}`)}; + ${output.setByOffset( + 'outputIndex', + `${ + reduceType === 'mean' + ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` + : `${output.type.storage}(${reduceOutputValues[reduceType]})` + }`, + )}; } }`; - // One work group is responsible for only one element of output. - return { - name, - shaderCache, - getShaderSource, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: outputSize}, - programUniforms: [{type: DataType.uint32, data: reduceSize}] - }), - }; - }; - -const reduceCommon = - (context: ComputeContext, name: string, attributes: ReduceAttributes, - reduceType: 'sum'|'sumSquare'|'prod'|'min'|'max'|'mean'|'logSumExp'|'l1'|'l2'|'logSum'): void => { - const updatedAttributes: ReduceAttributes = - context.inputs.length === 1 ? 
attributes : createReduceAttributesFromInputs(context.inputs, attributes); - - let updatedAxes = updatedAttributes.axes; - if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { - updatedAxes = context.inputs[0].dims.map((_dim, i) => i); - } - const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); - - let axes = normalizeAxes; - let input = context.inputs[0]; - const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length); - if (permutedAxes.length > 0) { - input = context.compute( - createTransposeProgramInfo(context.inputs[0], permutedAxes), {inputs: [0], outputs: [-1]})[0]; - axes = getInnerMostAxes(axes.length, input.dims.length); - } + // One work group is responsible for only one element of output. + return { + name, + shaderCache, + getShaderSource, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: outputDataType }], + dispatchGroup: { x: outputSize }, + programUniforms: [{ type: DataType.uint32, data: reduceSize }], + }), + }; +}; - const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes); - let finalOutputShape = outputShape; - if (updatedAttributes.keepDims) { - finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes); - } +const reduceCommon = ( + context: ComputeContext, + name: string, + attributes: ReduceAttributes, + reduceType: 'sum' | 'sumSquare' | 'prod' | 'min' | 'max' | 'mean' | 'logSumExp' | 'l1' | 'l2' | 'logSum', +): void => { + const updatedAttributes: ReduceAttributes = + context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes); + + let updatedAxes = updatedAttributes.axes; + if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { + updatedAxes = context.inputs[0].dims.map((_dim, i) => i); + } + const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); + + let axes = normalizeAxes; + let input = context.inputs[0]; + const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length); + if (permutedAxes.length > 0) { + input = context.compute(createTransposeProgramInfo(context.inputs[0], permutedAxes), { + inputs: [0], + outputs: [-1], + })[0]; + axes = getInnerMostAxes(axes.length, input.dims.length); + } + + const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes); + let finalOutputShape = outputShape; + if (updatedAttributes.keepDims) { + finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes); + } - context.compute( - createReduceSharedProgramInfo( - name, {hint: updatedAttributes.cacheKey, inputDependencies: ['type']}, [input], reduceType, - context.inputs[0].dataType, finalOutputShape, reduceShape), - {inputs: [input]}); - }; + context.compute( + createReduceSharedProgramInfo( + name, + { hint: updatedAttributes.cacheKey, inputDependencies: ['type'] }, + [input], + reduceType, + context.inputs[0].dataType, + finalOutputShape, + reduceShape, + ), + { inputs: [input] }, + ); +}; export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => { reduceCommon(context, 'ReduceMeanShared', attributes, 'mean'); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index e8205ba6fd928..85be1aef30861 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -1,14 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
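// Illustrative sketch (not part of this patch): how the reduce-shared.ts path above arranges
// axes before reducing. When the axes to reduce are not already the innermost dimensions,
// reduceCommon builds a permutation that lists the kept axes first and the reduced axes last,
// transposes the input with it, and then reduces over the trailing axes. The helper below is a
// hypothetical standalone recap of that idea for readers of this diff, not the library code.
const permutationToInnermost = (rank: number, reduceAxes: readonly number[]): number[] => {
  const reduced = new Set(reduceAxes);
  const kept: number[] = [];
  for (let i = 0; i < rank; i++) {
    if (!reduced.has(i)) {
      kept.push(i);
    }
  }
  // Kept axes first, reduced axes last, e.g. rank 4 with reduceAxes [1] -> [0, 2, 3, 1].
  return [...kept, ...reduceAxes];
};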
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; - -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; -import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types'; + +import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common'; +import { + reduceL1Shared, + reduceL2Shared, + reduceLogSumExpShared, + reduceLogSumShared, + reduceMaxShared, + reduceMeanShared, + reduceMinShared, + reduceProdShared, + reduceSumShared, + reduceSumSquareShared, +} from './reduce-shared'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length === 0 || inputs.length > 2) { @@ -26,56 +37,65 @@ export interface ReduceAttributes extends AttributeWithCacheKey { axes: number[]; } -export type ReduceOp = - (input: IndicesHelper, output: IndicesHelper, - axes: readonly number[]) => [string, string, string, string, ...string[]]; +export type ReduceOp = ( + input: IndicesHelper, + output: IndicesHelper, + axes: readonly number[], +) => [string, string, string, string, ...string[]]; const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByIndices('input_indices')};`, '']; -export const createReduceProgramInfo = - (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceOp: ReduceOp, - axesInput: number[], outputDataType: DataType, keepDims = false, noopWithEmptyAxes = false): ProgramInfo => { - const outputShape: number[] = []; - const inputShape = inputs[0].dims; - const inputRank = inputShape.length; - const axes = ShapeUtil.normalizeAxes(axesInput, inputRank); - const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0; - inputShape.forEach((d, i) => { - if (reduceOnAllAxes || axes.indexOf(i) >= 0) { - if (keepDims) { - outputShape.push(1); - } // else { // skip this axis} - } else { - outputShape.push(d); +export const createReduceProgramInfo = ( + name: string, + shaderCache: ProgramShaderCacheInfo, + inputs: readonly TensorView[], + reduceOp: ReduceOp, + axesInput: number[], + outputDataType: DataType, + keepDims = false, + noopWithEmptyAxes = false, +): ProgramInfo => { + const outputShape: number[] = []; + const inputShape = inputs[0].dims; + const inputRank = inputShape.length; + const axes = ShapeUtil.normalizeAxes(axesInput, inputRank); + const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0; + inputShape.forEach((d, i) => { + if (reduceOnAllAxes || axes.indexOf(i) >= 0) { + if (keepDims) { + outputShape.push(1); + } // else { // skip this axis} + } else { + outputShape.push(d); + } + }); + const outputRank = outputShape.length; + const outputSize = ShapeUtil.size(outputShape); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const idxCopy: string[] = []; // copy 
output indexes to input indexes + + const input = inputVariable('_A', inputs[0].dataType, inputRank); + const output = outputVariable('output', outputDataType, outputRank); + const ops = reduceOp(input, output, axes); + let reduceOps = ops[2]; + + for (let k = 0, l = 0; k < inputRank; k++) { + // if this axis is reduced + if (reduceOnAllAxes || axes.indexOf(k) >= 0) { + if (keepDims) { + l++; } - }); - const outputRank = outputShape.length; - const outputSize = ShapeUtil.size(outputShape); - const getShaderSource = (shaderHelper: ShaderHelper) => { - const idxCopy: string[] = []; // copy output indexes to input indexes - - const input = inputVariable('_A', inputs[0].dataType, inputRank); - const output = outputVariable('output', outputDataType, outputRank); - const ops = reduceOp(input, output, axes); - let reduceOps = ops[2]; - - for (let k = 0, l = 0; k < inputRank; k++) { - // if this axis is reduced - if (reduceOnAllAxes || axes.indexOf(k) >= 0) { - if (keepDims) { - l++; - } - // loop over the d-th axis - reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputShape[k]}; j${k}++) { + // loop over the d-th axis + reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputShape[k]}; j${k}++) { ${ops[2].includes('last_index') ? `let last_index = j${k};` : ''} ${input.indicesSet('input_indices', k, `j${k}`)} ${reduceOps} }`; - } else { - idxCopy.push(`${input.indicesSet('input_indices', k, output.indicesGet('output_indices', l))};`); - l++; - } - } - return ` + } else { + idxCopy.push(`${input.indicesSet('input_indices', k, output.indicesGet('output_indices', l))};`); + l++; + } + } + return ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} @@ -91,86 +111,103 @@ export const createReduceProgramInfo = ${ops[3]} ${ops.length === 4 ? output.setByOffset('global_idx', 'value') : ops.slice(4).join('\n')} }`; - }; - - return { - name, - shaderCache, - getShaderSource, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: - [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)] - }), - }; - }; - -export const createReduceAttributesFromInputs = - (inputs: readonly TensorView[], attributes: ReduceAttributes): ReduceAttributes => { - const axes: number[] = []; - if (inputs[1].dims[0] > 0) { - inputs[1].getBigInt64Array().forEach(v => axes.push(Number(v))); - } - return createAttributeWithCacheKey( - {axes, keepDims: attributes.keepDims, noopWithEmptyAxes: attributes.noopWithEmptyAxes}); - }; - -const runReduceProgram = - (context: ComputeContext, name: string, attributes: ReduceAttributes, reduceOp: ReduceOp): void => { - const inputs = context.inputs; - const updatedAttributes: ReduceAttributes = - inputs.length === 1 ? attributes : createReduceAttributesFromInputs(inputs, attributes); - - context.compute( - createReduceProgramInfo( - name, {hint: updatedAttributes.cacheKey, inputDependencies: ['rank']}, [inputs[0]], - updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? 
noOp : reduceOp, - updatedAttributes.axes, inputs[0].dataType, updatedAttributes.keepDims, - updatedAttributes.noopWithEmptyAxes), - {inputs: [0]}); - }; + }; + + return { + name, + shaderCache, + getShaderSource, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: outputDataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputShape, outputShape), + ], + }), + }; +}; + +export const createReduceAttributesFromInputs = ( + inputs: readonly TensorView[], + attributes: ReduceAttributes, +): ReduceAttributes => { + const axes: number[] = []; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach((v) => axes.push(Number(v))); + } + return createAttributeWithCacheKey({ + axes, + keepDims: attributes.keepDims, + noopWithEmptyAxes: attributes.noopWithEmptyAxes, + }); +}; + +const runReduceProgram = ( + context: ComputeContext, + name: string, + attributes: ReduceAttributes, + reduceOp: ReduceOp, +): void => { + const inputs = context.inputs; + const updatedAttributes: ReduceAttributes = + inputs.length === 1 ? attributes : createReduceAttributesFromInputs(inputs, attributes); + + context.compute( + createReduceProgramInfo( + name, + { hint: updatedAttributes.cacheKey, inputDependencies: ['rank'] }, + [inputs[0]], + updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp, + updatedAttributes.axes, + inputs[0].dataType, + updatedAttributes.keepDims, + updatedAttributes.noopWithEmptyAxes, + ), + { inputs: [0] }, + ); +}; const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); - const reduceOp: ReduceOp = (input, output) => - [`var value = ${output.type.storage}(0);`, - '', - `value += ${input.getByIndices('input_indices')};`, - 'value = log(value);', + const reduceOp: ReduceOp = (input, output) => [ + `var value = ${output.type.storage}(0);`, + '', + `value += ${input.getByIndices('input_indices')};`, + 'value = log(value);', ]; runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp); }; const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); - const reduceOp: ReduceOp = (input, output) => - [`var value = ${output.type.storage}(0);`, - '', - `value += abs(${input.getByIndices('input_indices')});`, - '', + const reduceOp: ReduceOp = (input, output) => [ + `var value = ${output.type.storage}(0);`, + '', + `value += abs(${input.getByIndices('input_indices')});`, + '', ]; runReduceProgram(context, 'ReduceL1', attributes, reduceOp); }; const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); - const reduceOp: ReduceOp = (input, output) => - [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, - '', - `t = ${input.getByIndices('input_indices')}; value += (t * t);`, - 'value = sqrt(value);', + const reduceOp: ReduceOp = (input, output) => [ + `var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, + '', + `t = ${input.getByIndices('input_indices')}; value += (t * t);`, + 'value = sqrt(value);', ]; runReduceProgram(context, 'ReduceL2', attributes, reduceOp); }; const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); - const reduceOp: ReduceOp = (input, output) => - [`var value = 
${output.type.storage}(0);`, - '', - `value += exp(${input.getByIndices('input_indices')});`, - 'value = log(value);', + const reduceOp: ReduceOp = (input, output) => [ + `var value = ${output.type.storage}(0);`, + '', + `value += exp(${input.getByIndices('input_indices')});`, + 'value = log(value);', ]; runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp); }; @@ -222,7 +259,7 @@ const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): const idxZero = []; for (let k = 0; k < input.rank; k++) { if (axes.indexOf(k) >= 0 || axes.length === 0) { - idxZero.push(`input_indices[${k}] = 0;`); // first element + idxZero.push(`input_indices[${k}] = 0;`); // first element } } @@ -238,58 +275,61 @@ const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); - const reduceOp: ReduceOp = (input, output) => - [`var value = ${output.type.storage}(1);`, - '', - `value *= ${input.getByIndices('input_indices')};`, - '', + const reduceOp: ReduceOp = (input, output) => [ + `var value = ${output.type.storage}(1);`, + '', + `value *= ${input.getByIndices('input_indices')};`, + '', ]; runReduceProgram(context, 'ReduceProd', attributes, reduceOp); }; const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); - const reduceOp: ReduceOp = (input, output) => - [`var value = ${output.type.storage}(0);`, - '', - `value += ${input.getByIndices('input_indices')};`, - '', + const reduceOp: ReduceOp = (input, output) => [ + `var value = ${output.type.storage}(0);`, + '', + `value += ${input.getByIndices('input_indices')};`, + '', ]; runReduceProgram(context, 'ReduceSum', attributes, reduceOp); }; const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); - const reduceOp: ReduceOp = (input, output) => - [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, - '', - `t = ${input.getByIndices('input_indices')}; value += t * t;`, - '', + const reduceOp: ReduceOp = (input, output) => [ + `var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, + '', + `t = ${input.getByIndices('input_indices')}; value += t * t;`, + '', ]; runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp); }; -const useNaiveReduceMethod = - (shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => { - if (axes.length === 0) { - return noopWithEmptyAxes; - } +const useNaiveReduceMethod = ( + shape: readonly number[], + axes: readonly number[], + noopWithEmptyAxes: boolean, +): boolean => { + if (axes.length === 0) { + return noopWithEmptyAxes; + } - let outputSize = 1; - let reduceSize = 1; - for (let dim = 0; dim < axes.length; dim++) { - if (axes.indexOf(dim) === -1) { - outputSize *= shape[dim]; - } else { - reduceSize *= shape[dim]; - } - } + let outputSize = 1; + let reduceSize = 1; + for (let dim = 0; dim < axes.length; dim++) { + if (axes.indexOf(dim) === -1) { + outputSize *= shape[dim]; + } else { + reduceSize *= shape[dim]; + } + } - // The condition data is very rough, although considering the count of Execution Unit (EU), the potential - // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments - // on some machines. 
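// Illustrative sketch (not part of this patch): the selection heuristic described in the
// comment above. outputSize is the number of kept elements and reduceSize the number of
// reduced elements; the naive per-thread reduction appears to be preferred only when each
// reduction is short (reduceSize < 32) and there are many independent outputs
// (outputSize > 1024), otherwise the shared-memory kernel is used. This is a hypothetical
// recomputation over the full shape, shown only to make the threshold concrete; the actual
// decision lives in useNaiveReduceMethod below.
const preferNaiveReduce = (shape: readonly number[], axes: readonly number[]): boolean => {
  let outputSize = 1;
  let reduceSize = 1;
  for (let dim = 0; dim < shape.length; dim++) {
    if (axes.indexOf(dim) === -1) {
      outputSize *= shape[dim];
    } else {
      reduceSize *= shape[dim];
    }
  }
  return reduceSize < 32 && outputSize > 1024;
};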
- return reduceSize < 32 && outputSize > 1024; - }; + // The condition data is very rough, although considering the count of Execution Unit (EU), the potential + // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments + // on some machines. + return reduceSize < 32 && outputSize > 1024; +}; export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 2c6b537de1f00..3cd7540ca0b7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -1,23 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; - -import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; - -type CoordinateTransformMode = 'half_pixel'|'asymmetric'|'pytorch_half_pixel'|'tf_half_pixel_for_nn'|'align_corners'| - 'tf_crop_and_resize'|'half_pixel_symmetric'; - -type KeepAspectRatioPolicy = 'stretch'|'not_smaller'|'not_larger'; - -type Mode = 'nearest'|'linear'|'cubic'; - -type NearestMode = 'round_prefer_floor'|'round_prefer_ceil'|'floor'|'ceil'|'simple'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo } from '../types'; + +import { + createTensorShapeVariables, + getElementAt, + IndicesHelper, + inputVariable, + outputVariable, + ShaderHelper, +} from './common'; + +type CoordinateTransformMode = + | 'half_pixel' + | 'asymmetric' + | 'pytorch_half_pixel' + | 'tf_half_pixel_for_nn' + | 'align_corners' + | 'tf_crop_and_resize' + | 'half_pixel_symmetric'; + +type KeepAspectRatioPolicy = 'stretch' | 'not_smaller' | 'not_larger'; + +type Mode = 'nearest' | 'linear' | 'cubic'; + +type NearestMode = 'round_prefer_floor' | 'round_prefer_ceil' | 'floor' | 'ceil' | 'simple'; export interface ResizeAttributes extends AttributeWithCacheKey { antialias: number; @@ -32,22 +44,38 @@ export interface ResizeAttributes extends AttributeWithCacheKey { } const validateScales = (scales: number[], attributes: ResizeAttributes): void => { - scales.every((value) => value > 0 || (() => { - throw new Error('Resize requires scales input values to be positive'); - })); + scales.every( + (value) => + value > 0 || + (() => { + throw new Error('Resize requires scales input values to be positive'); + }), + ); // Check scales dims based on mode: LINEAR, CUBIC if (scales.length > 0) { if (attributes.mode === 'linear') { - if (!(scales.length === 2 || scales.length === 3 || (scales.length === 4 && scales[0] === 1 && scales[1] === 1) || - (scales.length === 4 && scales[0] === 1 && scales[3] === 1) || - (scales.length === 5 && scales[0] === 1 && scales[1] === 1))) { + if ( + !( + scales.length === 2 || + scales.length === 3 || + (scales.length === 4 && scales[0] === 1 && scales[1] === 1) || + (scales.length === 4 
&& scales[0] === 1 && scales[3] === 1) || + (scales.length === 5 && scales[0] === 1 && scales[1] === 1) + ) + ) { throw new Error( - `For linear mode, Resize requires scales to be 2D, 3D, 4D with either two outermost or one innermost and - one outermost scale values equal to 1, or 5D with two outermost scale values equal to 1`); + `For linear mode, Resize requires scales to be 2D, 3D, 4D with either two outermost or one innermost and + one outermost scale values equal to 1, or 5D with two outermost scale values equal to 1`, + ); } } else if (attributes.mode === 'cubic') { - if (!(scales.length === 2 || (scales.length === 4 && scales[0] === 1 && scales[1] === 1) || - (scales.length === 4 && scales[0] === 1 && scales[3] === 1))) { + if ( + !( + scales.length === 2 || + (scales.length === 4 && scales[0] === 1 && scales[1] === 1) || + (scales.length === 4 && scales[0] === 1 && scales[3] === 1) + ) + ) { throw new Error('Resize requires scales input size to be 2 or 4 for cubic mode'); } } @@ -55,77 +83,90 @@ const validateScales = (scales: number[], attributes: ResizeAttributes): void => }; const updateScales = (scales: readonly number[], axes: readonly number[], rank: number): number[] => { - axes.every((value) => value >= 0 && value < rank || (() => { - throw new Error('Resize requires axes input values to be positive and less than rank'); - })); + axes.every( + (value) => + (value >= 0 && value < rank) || + (() => { + throw new Error('Resize requires axes input values to be positive and less than rank'); + }), + ); const newScales = new Array(rank).fill(1.0); - axes.forEach((value, index) => newScales[value] = scales[index]); + axes.forEach((value, index) => (newScales[value] = scales[index])); return newScales; }; -const validateInputs = - (inputs: readonly TensorView[], attributes: ResizeAttributes, opsetVersion: number, scales: number[], - sizes: number[], roi: number[]): void => { - const [roiInputIndex, scalesInputIndex, sizesInputIndex] = - (opsetVersion > 10) ? [1, 2, 3] : [-1, (inputs.length > 1) ? 1 : -1, -1]; - const rank = inputs[0].dims.length; - if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) { - inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value)); - } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') { - throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize'); - } +const validateInputs = ( + inputs: readonly TensorView[], + attributes: ResizeAttributes, + opsetVersion: number, + scales: number[], + sizes: number[], + roi: number[], +): void => { + const [roiInputIndex, scalesInputIndex, sizesInputIndex] = + opsetVersion > 10 ? [1, 2, 3] : [-1, inputs.length > 1 ? 
1 : -1, -1]; + const rank = inputs[0].dims.length; + if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) { + inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value)); + } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') { + throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize'); + } - if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) { - inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value)); - if (scales.length !== 0 && - (scales.length !== rank && (opsetVersion >= 18 && scales.length !== attributes.axes.length))) { - throw new Error( - 'Resize requires scales input size to be same as input rank or axes size for opset 18 and up'); - } - validateScales(scales, attributes); - if (attributes.axes.length > 0) { - updateScales(scales, attributes.axes, rank).forEach((value, index) => scales[index] = value); - } - } - if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) { - inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value))); - if (sizes.length !== rank || (opsetVersion >= 18 && sizes.length === attributes.axes.length)) { - throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up'); - } - } + if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) { + inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value)); + if ( + scales.length !== 0 && + scales.length !== rank && + opsetVersion >= 18 && + scales.length !== attributes.axes.length + ) { + throw new Error('Resize requires scales input size to be same as input rank or axes size for opset 18 and up'); + } + validateScales(scales, attributes); + if (attributes.axes.length > 0) { + updateScales(scales, attributes.axes, rank).forEach((value, index) => (scales[index] = value)); + } + } + if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) { + inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value))); + if (sizes.length !== rank || (opsetVersion >= 18 && sizes.length === attributes.axes.length)) { + throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up'); + } + } - if (attributes.axes.length > 0) { - if (scales.length !== attributes.axes.length) { - throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified'); - } - if (sizes.length !== attributes.axes.length) { - throw new Error( - 'Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified'); - } - } - if (typeof scales !== 'undefined' && typeof sizes !== 'undefined' && scales.length > 0 && sizes.length > rank) { - throw new Error('Resize requires only of scales or sizes to be specified'); - } - }; + if (attributes.axes.length > 0) { + if (scales.length !== attributes.axes.length) { + throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified'); + } + if (sizes.length !== attributes.axes.length) { + throw new Error('Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified'); + } + } + if (typeof scales !== 'undefined' && typeof sizes !== 'undefined' && scales.length > 0 && sizes.length > rank) { + throw new Error('Resize requires only of scales or sizes to be 
specified'); + } +}; -const getOriginalCoordinateFromResizedCoordinate = - (coordinateTransferMode: CoordinateTransformMode, dType: string): string => - `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32, +const getOriginalCoordinateFromResizedCoordinate = ( + coordinateTransferMode: CoordinateTransformMode, + dType: string, +): string => + `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32, lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${dType} { ` + - (() => { - switch (coordinateTransferMode) { - case 'asymmetric': - return `return ${dType}(xResized) / ${dType}(xScale);`; - case 'pytorch_half_pixel': - return `if (lengthResized > 1) { + (() => { + switch (coordinateTransferMode) { + case 'asymmetric': + return `return ${dType}(xResized) / ${dType}(xScale);`; + case 'pytorch_half_pixel': + return `if (lengthResized > 1) { return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5; } else { return 0.0; }`; - case 'tf_half_pixel_for_nn': - return `return (${dType}(xResized) + 0.5) / ${dType}(xScale);`; - case 'align_corners': - return `if (lengthResized == 1) { + case 'tf_half_pixel_for_nn': + return `return (${dType}(xResized) + 0.5) / ${dType}(xScale);`; + case 'align_corners': + return `if (lengthResized == 1) { return 0.0; } else { // The whole part and the fractional part are calculated separately due to inaccuracy of floating @@ -136,61 +177,62 @@ const getOriginalCoordinateFromResizedCoordinate = ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1); return whole + fract; }`; - case 'tf_crop_and_resize': - return `if (lengthResized > 1) { + case 'tf_crop_and_resize': + return `if (lengthResized > 1) { return ${dType}(roiStart) * ${dType}(lengthOriginal - 1) + (${dType}(xResized) * ${dType}(roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) / ${dType}(lengthResized - 1); } else { return 0.5 * ${dType}(roiStart + roiEnd) * ${dType}(lengthOriginal - 1); }`; - case 'half_pixel_symmetric': - return `const outputWidth = ${dType}xScale * ${dType}(lengthResized); + case 'half_pixel_symmetric': + return `const outputWidth = ${dType}xScale * ${dType}(lengthResized); const adjustment = ${dType}(lengthResized) / outputWidth; const center = ${dType}(lengthOriginal) / 2; const offset = center * (1 - adjustment); return offset + ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`; - case 'half_pixel': - return `return ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`; - default: - throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); - } - })() + - '}'; + case 'half_pixel': + return `return ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`; + default: + throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); + } + })() + + '}'; const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string => - `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => { - switch (nearestMode) { - case 'round_prefer_ceil': - return 'if (fract(xOriginal) == 0.5) { \ + `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + + (() => { + switch (nearestMode) { + case 'round_prefer_ceil': + return 'if (fract(xOriginal) == 0.5) { \ return ceil(xOriginal); \ } else { \ return round(xOriginal); \ }'; - case 'floor': - return 'return floor(xOriginal);'; - case 'ceil': - return 'return 
ceil(xOriginal);'; - case 'round_prefer_floor': - return 'if (fract(xOriginal) == 0.5) { \ + case 'floor': + return 'return floor(xOriginal);'; + case 'ceil': + return 'return ceil(xOriginal);'; + case 'round_prefer_floor': + return 'if (fract(xOriginal) == 0.5) { \ return floor(xOriginal); \ } else { \ return round(xOriginal); \ }'; - case 'simple': - default: - if (opsetVersion < 11) { - return 'if (isDownSample) \ + case 'simple': + default: + if (opsetVersion < 11) { + return 'if (isDownSample) \ { \ return ceil(xOriginal); \ } else { \ return xOriginal; \ }'; - } - throw new Error(`Nearest mode ${nearestMode} is not supported`); - } - })() + - '}'; + } + throw new Error(`Nearest mode ${nearestMode} is not supported`); + } + })() + + '}'; const updateRoI = (roi: readonly number[], axes: readonly number[], rank: number): number[] => { const roiTmp = new Array(rank).fill(0).concat(new Array(rank).fill(1)); @@ -205,39 +247,44 @@ const updateRoI = (roi: readonly number[], axes: readonly number[], rank: number return roiLocal; }; -const initOutputShape = - (inputShape: readonly number[], scales: readonly number[], sizes: readonly number[], axes: readonly number[]): - number[] => { - let outputShape: number[] = []; - if (sizes.length > 0) { - if (axes.length > 0) { - inputShape.forEach((v) => outputShape.push(v)); - if (Math.max(...axes) > inputShape.length) { - throw new Error('axes is out of bound'); - } - axes.forEach((v, i) => outputShape[v] = sizes[i]); - } else { - sizes.forEach((v) => outputShape.push(v)); - } - } else { - if (scales.length === 0) { - throw new Error('Resize requires either scales or sizes.'); - } else { - outputShape = inputShape.map((value, index) => Math.round(value * scales[index])); - } - } - return outputShape; - }; +const initOutputShape = ( + inputShape: readonly number[], + scales: readonly number[], + sizes: readonly number[], + axes: readonly number[], +): number[] => { + let outputShape: number[] = []; + if (sizes.length > 0) { + if (axes.length > 0) { + inputShape.forEach((v) => outputShape.push(v)); + if (Math.max(...axes) > inputShape.length) { + throw new Error('axes is out of bound'); + } + axes.forEach((v, i) => (outputShape[v] = sizes[i])); + } else { + sizes.forEach((v) => outputShape.push(v)); + } + } else { + if (scales.length === 0) { + throw new Error('Resize requires either scales or sizes.'); + } else { + outputShape = inputShape.map((value, index) => Math.round(value * scales[index])); + } + } + return outputShape; +}; const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes) => { const scaleInPolicy = (() => { switch (attributes.keepAspectRatioPolicy) { case 'not_larger': - return attributes.axes.length > 0 ? Math.min(...attributes.axes.map(i => scales[i]), Number.MAX_VALUE) : - Math.min(...scales, Number.MAX_VALUE); + return attributes.axes.length > 0 + ? Math.min(...attributes.axes.map((i) => scales[i]), Number.MAX_VALUE) + : Math.min(...scales, Number.MAX_VALUE); case 'not_smaller': - return attributes.axes.length > 0 ? Math.max(...attributes.axes.map(i => scales[i]), Number.MIN_VALUE) : - Math.max(...scales, Number.MIN_VALUE); + return attributes.axes.length > 0 + ? 
Math.max(...attributes.axes.map((i) => scales[i]), Number.MIN_VALUE) + : Math.max(...scales, Number.MIN_VALUE); default: throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`); } @@ -245,20 +292,25 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr scales.fill(1.0, 0, scales.length); const adjustedOutputShape = inputShape.slice(); if (attributes.axes.length > 0) { - attributes.axes.forEach((v) => scales[v] = scaleInPolicy); - attributes.axes.forEach((v) => adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])); + attributes.axes.forEach((v) => (scales[v] = scaleInPolicy)); + attributes.axes.forEach((v) => (adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v]))); } else { scales.fill(scaleInPolicy, 0, scales.length); - adjustedOutputShape.forEach((v, i) => adjustedOutputShape[i] = Math.round(v * scales[i])); + adjustedOutputShape.forEach((v, i) => (adjustedOutputShape[i] = Math.round(v * scales[i]))); } return adjustedOutputShape; }; -const calculateOriginalIndicesFromOutputIndices = - (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scalesLength: number, - roiLength: number): string => ` +const calculateOriginalIndicesFromOutputIndices = ( + output: IndicesHelper, + inputShape: readonly number[], + outputShape: readonly number[], + scalesLength: number, + roiLength: number, +): string => ` fn calculateOriginalIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> array<${ - output.type.value}, ${outputShape.length}> { + output.type.value + }, ${outputShape.length}> { var original_indices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var output_index = ${output.indicesGet('output_indices', 'i')}; @@ -277,9 +329,15 @@ const calculateOriginalIndicesFromOutputIndices = return original_indices; }`; -const calculateInputIndicesFromOutputIndices = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - scalesLength: number, roiLength: number, useExtrapolation: boolean): string => ` +const calculateInputIndicesFromOutputIndices = ( + input: IndicesHelper, + output: IndicesHelper, + inputShape: readonly number[], + outputShape: readonly number[], + scalesLength: number, + roiLength: number, + useExtrapolation: boolean, +): string => ` fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; for (var i:u32 = 0; i < ${outputShape.length}; i++) { @@ -322,22 +380,31 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): return true; }`; -const setChannelAndBatchIndices = - (input: IndicesHelper, channelIdx: number, batchIdx: number, spacialDims: number): string => - input.rank > spacialDims ? ` +const setChannelAndBatchIndices = ( + input: IndicesHelper, + channelIdx: number, + batchIdx: number, + spacialDims: number, +): string => + input.rank > spacialDims + ? ` ${input.indicesSet('input_indices', channelIdx, 'channel')}; ${input.indicesSet('input_indices', batchIdx, 'batch')}; -` : - ''; - -const bilinearInterpolation = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean, - extrapolationValue: number): string => { - const isNchw = true; - const [batchIdx, heightIdx, widthIdx, channelIdx] = - inputShape.length === 2 ? [-1, 0, 1, -1] : (isNchw ? 
[0, 2, 3, 1] : [0, 1, 2, 3]); - const dType = input.type.value; - return ` +` + : ''; + +const bilinearInterpolation = ( + input: IndicesHelper, + output: IndicesHelper, + inputShape: readonly number[], + useExtrapolation: boolean, + extrapolationValue: number, +): string => { + const isNchw = true; + const [batchIdx, heightIdx, widthIdx, channelIdx] = + inputShape.length === 2 ? [-1, 0, 1, -1] : isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3]; + const dType = input.type.value; + return ` fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { var input_indices: ${input.type.indices}; ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)}; @@ -351,11 +418,12 @@ const bilinearInterpolation = var row:${dType} = originalIndices[${heightIdx}]; var col:${dType} = originalIndices[${widthIdx}]; ${ - useExtrapolation ? - `if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${inputShape[widthIdx]} - 1)) { + useExtrapolation + ? `if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; - }` : - ''}; + }` + : '' + }; row = max(0, min(row, ${inputShape[heightIdx]} - 1)); col = max(0, min(col, ${inputShape[widthIdx]} - 1)); var row1: u32 = u32(row); @@ -382,21 +450,30 @@ const bilinearInterpolation = } return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); }`; - }; - -const bicubicInterpolation = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean, - extrapolationValue: number, excludeOutside: boolean): string => { - const is2D = inputShape.length === 2; - const isNchw = true; - const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2]; - const dType = input.type.value; - const createCubicInterpolationFunction = (idx: number): string => { - const direction = idx === heightIdx ? 'row' : 'col'; - return ` +}; + +const bicubicInterpolation = ( + input: IndicesHelper, + output: IndicesHelper, + inputShape: readonly number[], + outputShape: readonly number[], + scales: readonly number[], + roi: readonly number[], + cubicCoeffA: number, + useExtrapolation: boolean, + extrapolationValue: number, + excludeOutside: boolean, +): string => { + const is2D = inputShape.length === 2; + const isNchw = true; + const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2]; + const dType = input.type.value; + const createCubicInterpolationFunction = (idx: number): string => { + const direction = idx === heightIdx ? 
'row' : 'col'; + return ` fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${ - output.type.indices}) -> ${dType} { + output.type.indices + }) -> ${dType} { var output_index = ${output.indicesGet('output_indices', idx)}; var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]}, ${outputShape[idx]}, ${inputShape[idx]}, ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); @@ -411,27 +488,29 @@ const bicubicInterpolation = var ${direction}: ${dType} = originalIdx + ${dType}(i); if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) { ${(() => { - if (excludeOutside) { - return `coefs[i + 1] = 0.0; + if (excludeOutside) { + return `coefs[i + 1] = 0.0; continue;`; - } else if (useExtrapolation) { - return `return ${extrapolationValue};`; - } else { - return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`; - } - })()}; + } else if (useExtrapolation) { + return `return ${extrapolationValue};`; + } else { + return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`; + } + })()}; } var input_indices_copy: ${input.type.indices} = input_indices; ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)}; data[i + 1] = ${ - idx === heightIdx ? input.getByIndices('input_indices_copy') : - 'rowCubicInterpolation(input_indices_copy, output_indices)'}; + idx === heightIdx + ? input.getByIndices('input_indices_copy') + : 'rowCubicInterpolation(input_indices_copy, output_indices)' + }; } return cubicInterpolation1D(data, coefs); }`; - }; + }; - return ` + return ` ${createCubicInterpolationFunction(heightIdx)}; ${createCubicInterpolationFunction(widthIdx)}; fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> { @@ -441,11 +520,13 @@ const bicubicInterpolation = var twoMinusAbsS: ${dType} = 2.0 - absS; var onePlusAbsS: ${dType} = 1.0 + absS; coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${ - cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA}; + cubicCoeffA + }) * onePlusAbsS - 4 * ${cubicCoeffA}; coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1; coeffs[2] = ((${cubicCoeffA} + 2) * oneMinusAbsS - (${cubicCoeffA} + 3)) * oneMinusAbsS * oneMinusAbsS + 1; coeffs[3] = ((${cubicCoeffA} * twoMinusAbsS - 5 * ${cubicCoeffA}) * twoMinusAbsS + 8 * ${ - cubicCoeffA}) * twoMinusAbsS - 4 * ${cubicCoeffA}; + cubicCoeffA + }) * twoMinusAbsS - 4 * ${cubicCoeffA}; return coeffs; } @@ -459,16 +540,20 @@ const bicubicInterpolation = return colCubicInterpolation(input_indices, output_indices); } `; - }; - -const trilinearInterpolation = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean, - extrapolationValue: number): string => { - const isNchw = true; - const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] = - inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]); - const dType = input.type.value; - return ` +}; + +const trilinearInterpolation = ( + input: IndicesHelper, + output: IndicesHelper, + inputShape: readonly number[], + useExtrapolation: boolean, + extrapolationValue: number, +): string => { + const isNchw = true; + const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] = + inputShape.length === 3 ? [-1, 0, 1, 2, -1] : isNchw ? 
[0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]; + const dType = input.type.value; + return ` fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} { var input_indices: ${input.type.indices}; ${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)}; @@ -484,11 +569,14 @@ const trilinearInterpolation = var height:${dType} = originalIndices[${heightIdx}]; var width:${dType} = originalIndices[${widthIdx}]; ${ - useExtrapolation ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${ - inputShape[heightIdx]} - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1)) { + useExtrapolation + ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${ + inputShape[heightIdx] + } - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; - }` : - ''}; + }` + : '' + }; depth = max(0, min(depth, ${inputShape[depthIdx]} - 1)); height = max(0, min(height, ${inputShape[heightIdx]} - 1)); @@ -531,31 +619,39 @@ const trilinearInterpolation = return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 + x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1); }`; - }; - -const createResizeProgramInfo = - (inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, scalesInput: readonly number[], - sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => { - const inputShape = inputTensor.dims; - const roi = updateRoI(roiInput, attributes.axes, inputShape.length); - - let outputShape = initOutputShape(inputShape, scalesInput, sizes, attributes.axes); - let scales = scalesInput.slice(); - if (scalesInput.length === 0) { - scales = inputShape.map((value, index) => value === 0 ? 1.0 : outputShape[index] / value); - if (attributes.keepAspectRatioPolicy !== 'stretch') { - outputShape = adjustOutputShape(inputShape, scales, attributes); - } - } - const output = outputVariable('output', inputTensor.dataType, outputShape.length); - const input = inputVariable('input', inputTensor.dataType, inputShape.length); - const outputSize = ShapeUtil.size(outputShape); - const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); - const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; - const extrapolationValue = attributes.extrapolationValue; - const dataType = input.type.value; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${noScale ? '' : ` +}; + +const createResizeProgramInfo = ( + inputTensor: TensorView, + attributes: ResizeAttributes, + opsetVersion: number, + scalesInput: readonly number[], + sizes: readonly number[], + roiInput: readonly number[], +): ProgramInfo => { + const inputShape = inputTensor.dims; + const roi = updateRoI(roiInput, attributes.axes, inputShape.length); + + let outputShape = initOutputShape(inputShape, scalesInput, sizes, attributes.axes); + let scales = scalesInput.slice(); + if (scalesInput.length === 0) { + scales = inputShape.map((value, index) => (value === 0 ? 
1.0 : outputShape[index] / value)); + if (attributes.keepAspectRatioPolicy !== 'stretch') { + outputShape = adjustOutputShape(inputShape, scales, attributes); + } + } + const output = outputVariable('output', inputTensor.dataType, outputShape.length); + const input = inputVariable('input', inputTensor.dataType, inputShape.length); + const outputSize = ShapeUtil.size(outputShape); + const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); + const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; + const extrapolationValue = attributes.extrapolationValue; + const dataType = input.type.value; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${ + noScale + ? '' + : ` ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)}; ${(() => { switch (attributes.mode) { @@ -563,31 +659,45 @@ const createResizeProgramInfo = return ` ${checkInputIndices(input, inputShape)}; ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)}; - ${ - calculateInputIndicesFromOutputIndices( - input, output, inputShape, outputShape, scales.length, roi.length, useExtrapolation)}; + ${calculateInputIndicesFromOutputIndices( + input, + output, + inputShape, + outputShape, + scales.length, + roi.length, + useExtrapolation, + )}; `; case 'linear': return ` ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)}; ${(() => { - if (inputShape.length === 2 || inputShape.length === 4) { - return `${bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`; - } else if (inputShape.length === 3 || inputShape.length === 5) { - return `${trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`; - } else { - throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.'); - } - })()}; + if (inputShape.length === 2 || inputShape.length === 4) { + return `${bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`; + } else if (inputShape.length === 3 || inputShape.length === 5) { + return `${trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`; + } else { + throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.'); + } + })()}; `; case 'cubic': return ` ${(() => { if (inputShape.length === 2 || inputShape.length === 4) { - return `${ - bicubicInterpolation( - input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation, - attributes.extrapolationValue, attributes.excludeOutside)}`; + return `${bicubicInterpolation( + input, + output, + inputShape, + outputShape, + scales, + roi, + attributes.cubicCoeffA, + useExtrapolation, + attributes.extrapolationValue, + attributes.excludeOutside, + )}`; } else { throw Error('Cubic mode only supports input dims 2 and 4 are supported in linear mode.'); } @@ -597,57 +707,65 @@ const createResizeProgramInfo = throw Error('Invalid resize mode'); } })()}; - `} - ${ - shaderHelper.registerUniform('output_size', 'u32') - .registerUniform('scales', 'f32', scales.length) - .registerUniform('roi', 'f32', roi.length) - .declareVariables(input, output)} + ` + } + ${shaderHelper + .registerUniform('output_size', 'u32') + .registerUniform('scales', 'f32', scales.length) + .registerUniform('roi', 'f32', roi.length) + .declareVariables(input, output)} ${shaderHelper.mainStart()} 
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - ${noScale ? 'output[global_idx] = input[global_idx];' : ` + ${ + noScale + ? 'output[global_idx] = input[global_idx];' + : ` let output_indices = ${output.offsetToIndices('global_idx')}; var input_indices: ${input.type.indices}; ${(() => { - switch (attributes.mode) { - case 'nearest': - return `input_indices = calculateInputIndicesFromOutputIndices(output_indices); + switch (attributes.mode) { + case 'nearest': + return `input_indices = calculateInputIndicesFromOutputIndices(output_indices); if (checkInputIndices(input_indices)) { output[global_idx] = ${input.getByIndices('input_indices')}; } else { output[global_idx] = ${attributes.extrapolationValue}; }`; - case 'linear': - return `output[global_idx] = ${ - (inputShape.length === 2 || inputShape.length === 4) ? 'bilinearInterpolation' : - 'trilinearInterpolation'}(output_indices);`; - case 'cubic': - return 'output[global_idx] = bicubicInterpolation(output_indices);'; - default: - throw Error(`Unsupported resize mode: ${attributes.mode}`); + case 'linear': + return `output[global_idx] = ${ + inputShape.length === 2 || inputShape.length === 4 ? 'bilinearInterpolation' : 'trilinearInterpolation' + }(output_indices);`; + case 'cubic': + return 'output[global_idx] = bicubicInterpolation(output_indices);'; + default: + throw Error(`Unsupported resize mode: ${attributes.mode}`); + } + })()}; +` } - })()}; -`} }`; - return { - name: 'Resize', - shaderCache: { - hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`, - inputDependencies: ['rank'] - }, - getShaderSource, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputTensor.dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales}, - {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape) - ] - }) - }; - }; + return { + name: 'Resize', + shaderCache: { + hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ + sizes.length > 0 ? sizes : '' + }|${roi.length > 0 ? 
roi : ''}|${noScale}|${inputShape}`, + inputDependencies: ['rank'], + }, + getShaderSource, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.float, data: scales }, + { type: DataType.float, data: roi }, + ...createTensorShapeVariables(inputShape, outputShape), + ], + }), + }; +}; const getOpsetVersionFromCustomDataBuffer = (context: ComputeContext): number => { const customDataBuffer = context.customDataBuffer; @@ -669,17 +787,18 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v throw Error('Only default value (0) for Antialias attribute is supported'); } validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi); - context.compute( - createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), {inputs: [0]}); + context.compute(createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), { + inputs: [0], + }); }; export const parseResizeAttributes = (attributes: Record): ResizeAttributes => { const antialias = attributes.antialias as number; const axes = attributes.axes as number[]; const coordinateTransformMode: CoordinateTransformMode = - attributes.coordinateTransformMode as CoordinateTransformMode; + attributes.coordinateTransformMode as CoordinateTransformMode; const cubicCoeffA = attributes.cubicCoeffA as number; - const excludeOutside = attributes.excludeOutside as number !== 0; + const excludeOutside = (attributes.excludeOutside as number) !== 0; const extrapolationValue = attributes.extrapolationValue as number; const keepAspectRatioPolicy: KeepAspectRatioPolicy = attributes.keepAspectRatioPolicy as KeepAspectRatioPolicy; const mode: Mode = attributes.mode as Mode; @@ -694,6 +813,6 @@ export const parseResizeAttributes = (attributes: Record): Resi extrapolationValue, keepAspectRatioPolicy, mode, - nearestMode + nearestMode, }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts index a58087072e4c7..8eb7a10ac91fa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/rotary-embedding.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
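
The resize.ts hunks above are purely formatting, but the bilinear path is easiest to follow with a scalar reference. Below is a minimal CPU-side sketch of the same sampling for a single-channel 2D input, assuming the 'half_pixel' coordinate transform; every name is illustrative and nothing in this snippet is part of the patch.

// Minimal bilinear resize sketch (half_pixel transform assumed), mirroring the
// row/col clamping and 4-tap blend that the WGSL above performs per output element.
const bilinearResize2D = (
  src: Float32Array,
  srcH: number,
  srcW: number,
  dstH: number,
  dstW: number,
): Float32Array => {
  const dst = new Float32Array(dstH * dstW);
  const scaleH = dstH / srcH;
  const scaleW = dstW / srcW;
  const clamp = (v: number, max: number) => Math.max(0, Math.min(v, max));
  for (let y = 0; y < dstH; y++) {
    // half_pixel: map the output coordinate back onto the input grid
    const row = clamp((y + 0.5) / scaleH - 0.5, srcH - 1);
    const row1 = Math.floor(row);
    const row2 = Math.min(row1 + 1, srcH - 1);
    const dy = row - row1; // weight toward row2
    for (let x = 0; x < dstW; x++) {
      const col = clamp((x + 0.5) / scaleW - 0.5, srcW - 1);
      const col1 = Math.floor(col);
      const col2 = Math.min(col1 + 1, srcW - 1);
      const dx = col - col1; // weight toward col2
      const x11 = src[row1 * srcW + col1];
      const x12 = src[row1 * srcW + col2];
      const x21 = src[row2 * srcW + col1];
      const x22 = src[row2 * srcW + col2];
      dst[y * dstW + x] =
        x11 * (1 - dx) * (1 - dy) + x12 * dx * (1 - dy) + x21 * (1 - dx) * dy + x22 * dx * dy;
    }
  }
  return dst;
};
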
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, WORKGROUP_SIZE} from './common'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, WORKGROUP_SIZE } from './common'; export interface RotaryEmbeddingAttributes { readonly interleaved: boolean; @@ -18,13 +18,16 @@ export interface RotaryEmbeddingAttributes { const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): void => { const [input, positionIds, cosCache, sinCache] = inputs; - const {numHeads, rotaryEmbeddingDim} = attributes; + const { numHeads, rotaryEmbeddingDim } = attributes; if (input.dims.length !== 3 && input.dims.length !== 4) { throw new Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${input.dims.length}`); } - if (!ShapeUtil.areEqual(positionIds.dims, []) && !ShapeUtil.areEqual(positionIds.dims, [1]) && - positionIds.dims.length !== 2) { + if ( + !ShapeUtil.areEqual(positionIds.dims, []) && + !ShapeUtil.areEqual(positionIds.dims, [1]) && + positionIds.dims.length !== 2 + ) { throw new Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${positionIds.dims.length}`); } if (cosCache.dims.length !== 2) { @@ -34,7 +37,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi throw new Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${sinCache.dims.length}`); } if (!ShapeUtil.areEqual(cosCache.dims, sinCache.dims)) { - throw new Error('Inputs \'cos_cache\' and \'sin_cache\' are expected to have the same shape'); + throw new Error("Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape"); } if (rotaryEmbeddingDim > 0 && numHeads === 0) { @@ -60,8 +63,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi } if (headSize / 2 !== cosCache.dims[1] && rotaryEmbeddingDim / 2 !== cosCache.dims[1]) { - throw new Error(`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${ - cosCache.dims[1]}`); + throw new Error( + `Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${ + cosCache.dims[1] + }`, + ); } if (sequenceLength > maxSequenceLength) { @@ -69,56 +75,64 @@ const validateInputs = (inputs: readonly TensorView[], attributes: RotaryEmbeddi } }; -const createRotaryEmbeddingProgramInfo = - (inputs: readonly TensorView[], attributes: RotaryEmbeddingAttributes): ProgramInfo => { - const {interleaved, numHeads, rotaryEmbeddingDim, scale} = attributes; - const batchSize = inputs[0].dims[0]; - const batchStride = ShapeUtil.sizeFromDimension(inputs[0].dims, 1); - const sequenceLength = inputs[0].dims[inputs[0].dims.length - 2]; - const hiddenSize = batchStride / sequenceLength; - const halfRotaryEmbeddingDim = inputs[2].dims[1]; - const headSize = rotaryEmbeddingDim === 0 ? 
halfRotaryEmbeddingDim * 2 : hiddenSize / numHeads; - - // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape - // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] - // to unfold the global index in shader. - const globalShape = - new Array(batchSize, sequenceLength, hiddenSize / headSize, headSize - halfRotaryEmbeddingDim); - const globalStrides = ShapeUtil.computeStrides(globalShape); - - const programUniforms: ProgramUniform[] = [ - {type: DataType.float, data: scale}, - {type: DataType.uint32, data: globalShape}, - {type: DataType.uint32, data: globalStrides}, - - // strides for addressing the input/output tensor, in permutated order to align with the unfolded global index, - // i.e. BSNH - ...(inputs[0].dims.length === 3 ? - new Array({type: DataType.uint32, data: [batchStride, hiddenSize, headSize, 1]}) : - []), - ...(inputs[0].dims.length === 4 ? - new Array( - {type: DataType.uint32, data: [batchStride, headSize, sequenceLength * headSize, 1]}) : - []), - - ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, inputs[2].dims, inputs[3].dims, inputs[0].dims), - ]; - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length); - const positionIds = inputVariable('position_ids', inputs[1].dataType, inputs[1].dims.length); - const cosCache = inputVariable('cos_cache', inputs[2].dataType, inputs[2].dims.length); - const sinCache = inputVariable('sin_cache', inputs[3].dataType, inputs[3].dims.length); - const output = outputVariable('output', inputs[0].dataType, inputs[0].dims.length); - - shaderHelper.registerUniforms([ - {name: 'scale', type: 'f32'}, - {name: 'global_shape', type: 'u32', length: globalShape.length}, - {name: 'global_strides', type: 'u32', length: globalStrides.length}, - {name: 'input_output_strides', type: 'u32', length: globalStrides.length}, - ]); - - return ` +const createRotaryEmbeddingProgramInfo = ( + inputs: readonly TensorView[], + attributes: RotaryEmbeddingAttributes, +): ProgramInfo => { + const { interleaved, numHeads, rotaryEmbeddingDim, scale } = attributes; + const batchSize = inputs[0].dims[0]; + const batchStride = ShapeUtil.sizeFromDimension(inputs[0].dims, 1); + const sequenceLength = inputs[0].dims[inputs[0].dims.length - 2]; + const hiddenSize = batchStride / sequenceLength; + const halfRotaryEmbeddingDim = inputs[2].dims[1]; + const headSize = rotaryEmbeddingDim === 0 ? halfRotaryEmbeddingDim * 2 : hiddenSize / numHeads; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. + const globalShape = new Array( + batchSize, + sequenceLength, + hiddenSize / headSize, + headSize - halfRotaryEmbeddingDim, + ); + const globalStrides = ShapeUtil.computeStrides(globalShape); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.float, data: scale }, + { type: DataType.uint32, data: globalShape }, + { type: DataType.uint32, data: globalStrides }, + + // strides for addressing the input/output tensor, in permutated order to align with the unfolded global index, + // i.e. BSNH + ...(inputs[0].dims.length === 3 + ? new Array({ type: DataType.uint32, data: [batchStride, hiddenSize, headSize, 1] }) + : []), + ...(inputs[0].dims.length === 4 + ? 
new Array({ + type: DataType.uint32, + data: [batchStride, headSize, sequenceLength * headSize, 1], + }) + : []), + + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, inputs[2].dims, inputs[3].dims, inputs[0].dims), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length); + const positionIds = inputVariable('position_ids', inputs[1].dataType, inputs[1].dims.length); + const cosCache = inputVariable('cos_cache', inputs[2].dataType, inputs[2].dims.length); + const sinCache = inputVariable('sin_cache', inputs[3].dataType, inputs[3].dims.length); + const output = outputVariable('output', inputs[0].dataType, inputs[0].dims.length); + + shaderHelper.registerUniforms([ + { name: 'scale', type: 'f32' }, + { name: 'global_shape', type: 'u32', length: globalShape.length }, + { name: 'global_strides', type: 'u32', length: globalStrides.length }, + { name: 'input_output_strides', type: 'u32', length: globalStrides.length }, + ]); + + return ` ${shaderHelper.declareVariables(input, positionIds, cosCache, sinCache, output)} ${shaderHelper.mainStart(WORKGROUP_SIZE)} @@ -145,24 +159,24 @@ const createRotaryEmbeddingProgramInfo = ${output.setByOffset('k', input.getByOffset('k'))} } }`; - }; - - return { - name: 'RotaryEmbedding', - shaderCache: { - hint: createAttributeWithCacheKey({ - interleaved, - }).cacheKey, - inputDependencies: ['rank', 'rank', 'rank', 'rank'], - }, - getShaderSource, - getRunData: () => ({ - outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(ShapeUtil.size(globalShape) / WORKGROUP_SIZE)}, - programUniforms, - }), - }; - }; + }; + + return { + name: 'RotaryEmbedding', + shaderCache: { + hint: createAttributeWithCacheKey({ + interleaved, + }).cacheKey, + inputDependencies: ['rank', 'rank', 'rank', 'rank'], + }, + getShaderSource, + getRunData: () => ({ + outputs: [{ dims: inputs[0].dims, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(ShapeUtil.size(globalShape) / WORKGROUP_SIZE) }, + programUniforms, + }), + }; +}; export const rotaryEmbedding = (context: ComputeContext, attributes: RotaryEmbeddingAttributes): void => { validateInputs(context.inputs, attributes); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index ae7306eaf20e6..5a3b31e011069 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -1,12 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; - -import {castToF32, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; + +import { + castToF32, + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + sumVector, + tensorTypeToWsglStorageType, + UniformsArrayType, +} from './common'; export interface SkipLayerNormAttributes { simplified: boolean; @@ -69,71 +78,72 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const createSkipLayerNormProgramInfo = - (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean): - ProgramInfo => { - const simplified = attributes.simplified; - - const inputShape = inputs[0].dims; - const inputSize = ShapeUtil.size(inputShape); - const outputShape = inputShape; - const outputSize = inputSize; - const hiddenSize = inputShape.slice(-1)[0]; - const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; - const hasBetaInput = !simplified && inputs.length > 3; - const hasBiasInput = inputs.length > 4; - const hasMeanOutput = isTraining && outputCount > 1; - const hasInvStdDevOutput = isTraining && outputCount > 2; - const hasInputSkipBiasSumOutput = outputCount > 3; - const workgroupSize = 64; - - const components = getMaxComponents(hiddenSize); - - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, - {type: DataType.uint32, data: components}, - {type: DataType.uint32, data: hiddenSize}, - {type: DataType.float, data: attributes.epsilon}, - ]; - const getShaderSource = (shaderHelper: ShaderHelper) => { - const uniformsArray: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, - {name: 'components', type: 'u32'}, - {name: 'hidden_size', type: 'u32'}, - {name: 'epsilon', type: 'f32'}, - ]; - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanOutput) { - variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const vecDataType = tensorTypeToWsglStorageType(DataType.float, components); - return ` +const createSkipLayerNormProgramInfo = ( + inputs: readonly TensorView[], + attributes: SkipLayerNormAttributes, + outputCount: number, + isTraining: boolean, +): ProgramInfo => { + const simplified = 
attributes.simplified; + + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const outputShape = inputShape; + const outputSize = inputSize; + const hiddenSize = inputShape.slice(-1)[0]; + const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; + const hasBetaInput = !simplified && inputs.length > 3; + const hasBiasInput = inputs.length > 4; + const hasMeanOutput = isTraining && outputCount > 1; + const hasInvStdDevOutput = isTraining && outputCount > 2; + const hasInputSkipBiasSumOutput = outputCount > 3; + const workgroupSize = 64; + + const components = getMaxComponents(hiddenSize); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: components }, + { type: DataType.uint32, data: hiddenSize }, + { type: DataType.float, data: attributes.epsilon }, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniformsArray: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'components', type: 'u32' }, + { name: 'hidden_size', type: 'u32' }, + { name: 'epsilon', type: 'f32' }, + ]; + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const vecDataType = tensorTypeToWsglStorageType(DataType.float, components); + return ` ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} var sum_shared : array<${vecDataType}, ${workgroupSize}>; var sum_squared_shared : array<${vecDataType}, ${workgroupSize}>; - ${shaderHelper.mainStart([ - workgroupSize, 1, 1 - ])} + ${shaderHelper.mainStart([workgroupSize, 1, 1])} let ix = local_id.x; let iy = global_id.x / ${workgroupSize}; @@ -171,7 +181,8 @@ const createSkipLayerNormProgramInfo = let square_sum = sum_squared_shared[0]; let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); let inv_std_dev = inverseSqrt(${sumVector('square_sum', components)} / f32(uniforms.hidden_size) ${ - simplified ? '' : '- mean * mean'} + uniforms.epsilon); + simplified ? '' : '- mean * mean' + } + uniforms.epsilon); ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''} @@ -181,33 +192,33 @@ const createSkipLayerNormProgramInfo = ${hasBetaInput ? 
'+ beta[offset1d + i]' : ''}; } }`; - }; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; - if (outputCount > 1) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); - } - if (outputCount > 2) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); - } - if (outputCount > 3) { - outputs.push({dims: inputShape, dataType: inputs[0].dataType}); - } - return { - name: 'SkipLayerNormalization', - shaderCache: { - hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, - inputDependencies: inputs.map((_input, _index) => 'type') - }, - getShaderSource, - getRunData: () => ({ - outputs, - dispatchGroup: { - x: Math.ceil(outputSize / hiddenSize), - }, - programUniforms - }), - }; - }; + }; + const outputs = [{ dims: outputShape, dataType: inputs[0].dataType }]; + if (outputCount > 1) { + outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float }); + } + if (outputCount > 2) { + outputs.push({ dims: meanInvStdDevDim, dataType: DataType.float }); + } + if (outputCount > 3) { + outputs.push({ dims: inputShape, dataType: inputs[0].dataType }); + } + return { + name: 'SkipLayerNormalization', + shaderCache: { + hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, + inputDependencies: inputs.map((_input, _index) => 'type'), + }, + getShaderSource, + getRunData: () => ({ + outputs, + dispatchGroup: { + x: Math.ceil(outputSize / hiddenSize), + }, + programUniforms, + }), + }; +}; export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNormAttributes): void => { // TODO: initialize isTraining from ComputeContext @@ -225,6 +236,7 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm if (context.outputCount > 3) { outputs.push(3); } - context.compute( - createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); + context.compute(createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), { + outputs, + }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index a5e71f30e5966..5a837fd1e0bfa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -1,13 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
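
The skip-layer-norm.ts hunks above only re-wrap the shader builder; the underlying math is unchanged. For reference, this is a scalar sketch of what the shader evaluates for one row of hiddenSize elements, assuming the non-simplified variant with no bias input; names are illustrative and the snippet is not part of the patch.

// One row of SkipLayerNormalization: sum = x + skip, then layer-norm with gamma/beta.
// The 'simplified' variant in the shader drops the "- mean * mean" term and the mean subtraction.
const skipLayerNormRow = (
  x: Float32Array,
  skip: Float32Array,
  gamma: Float32Array,
  beta: Float32Array,
  epsilon: number,
): Float32Array => {
  const n = x.length;
  const sum = new Float32Array(n);
  let mean = 0;
  let meanSq = 0;
  for (let i = 0; i < n; i++) {
    sum[i] = x[i] + skip[i]; // an optional bias input would be added here as well
    mean += sum[i];
    meanSq += sum[i] * sum[i];
  }
  mean /= n;
  meanSq /= n;
  // inverse standard deviation over the hidden dimension
  const invStdDev = 1 / Math.sqrt(meanSq - mean * mean + epsilon);
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    out[i] = (sum[i] - mean) * invStdDev * gamma[i] + beta[i];
  }
  return out;
};
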
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; - -import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform, TensorInfo } from '../types'; + +import { + createTensorShapeVariables, + getElementAt, + IndicesHelper, + inputVariable, + outputVariable, + ShaderHelper, + UniformsArrayType, +} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -37,9 +45,9 @@ const readInput = (inputs: readonly TensorView[], idx: number): number[] => { const input: number[] = []; if (inputs.length > idx) { if (inputs[idx].dataType === DataType.int64) { - inputs[idx].getBigInt64Array().forEach(v => input.push(Number(v))); + inputs[idx].getBigInt64Array().forEach((v) => input.push(Number(v))); } else if (inputs[idx].dataType === DataType.int32) { - inputs[idx].getInt32Array().forEach(v => input.push(Number(v))); + inputs[idx].getInt32Array().forEach((v) => input.push(Number(v))); } else { throw new Error(`Input ${idx} must be an array of int32 or int64`); } @@ -47,38 +55,47 @@ const readInput = (inputs: readonly TensorView[], idx: number): number[] => { return input; }; -const createSliceAttributesFromInputs = - (inputs: readonly TensorView[], attributes: SliceAttributes): SliceAttributes => { - if (inputs.length > 1) { - const starts: number[] = readInput(inputs, 1); - const ends: number[] = readInput(inputs, 2); - let axes: number[] = readInput(inputs, 3); - if (axes.length === 0) { - axes = [...Array(inputs[0].dims.length).keys()]; - } - return createAttributeWithCacheKey({starts, ends, axes}); - } else { - return attributes; - } - }; - -const fixStartEndValues = - (value: number, index: number, inputShape: readonly number[], axes: readonly number[], steps: readonly number[]): - number => { - let newValue = value; - if (value < 0) { - newValue += inputShape[axes[index]]; - } - if (steps[index] < 0) { - return Math.max(0, Math.min(newValue, inputShape[axes[index]] - 1)); - } else { - return Math.max(0, Math.min(newValue, inputShape[axes[index]])); - } - }; +const createSliceAttributesFromInputs = ( + inputs: readonly TensorView[], + attributes: SliceAttributes, +): SliceAttributes => { + if (inputs.length > 1) { + const starts: number[] = readInput(inputs, 1); + const ends: number[] = readInput(inputs, 2); + let axes: number[] = readInput(inputs, 3); + if (axes.length === 0) { + axes = [...Array(inputs[0].dims.length).keys()]; + } + return createAttributeWithCacheKey({ starts, ends, axes }); + } else { + return attributes; + } +}; + +const fixStartEndValues = ( + value: number, + index: number, + inputShape: readonly number[], + axes: readonly number[], + steps: readonly number[], +): number => { + let newValue = value; + if (value < 0) { + newValue += inputShape[axes[index]]; + } + if (steps[index] < 0) { + return Math.max(0, Math.min(newValue, inputShape[axes[index]] - 1)); + } else { + return Math.max(0, 
Math.min(newValue, inputShape[axes[index]])); + } +}; -const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[]): string => - `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { +const calculateInputIndicesImpl = ( + input: IndicesHelper, + output: IndicesHelper, + inputShape: readonly number[], +): string => + `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { @@ -101,12 +118,18 @@ const calculateInputIndicesImpl = const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); - const axes = (attributes.axes.length > 0) ? ShapeUtil.normalizeAxes(attributes.axes, inputShape.length) : - [...Array(inputShape.length).keys()]; + const axes = + attributes.axes.length > 0 + ? ShapeUtil.normalizeAxes(attributes.axes, inputShape.length) + : [...Array(inputShape.length).keys()]; let steps = readInput(inputs, 4); - steps.forEach((step) => step !== 0 || (() => { - throw new Error('step cannot be 0'); - })); + steps.forEach( + (step) => + step !== 0 || + (() => { + throw new Error('step cannot be 0'); + }), + ); if (steps.length === 0) { steps = Array(axes.length).fill(1); } @@ -127,7 +150,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice } } } - const signs = steps.map(step => Math.sign(step)); + const signs = steps.map((step) => Math.sign(step)); // Convert negative steps to positive steps and reverse starts and ends steps.forEach((step, i, array) => { if (step < 0) { @@ -144,20 +167,24 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice axes.forEach((axis, _) => { outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); - const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; + const outputTensorInfo: TensorInfo = { dims: outputShape, dataType: inputs[0].dataType }; const output = outputVariable('output', inputs[0].dataType, outputShape.length); const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length); const outputSize = ShapeUtil.size(outputShape); const uniforms: UniformsArrayType = [ - {name: 'outputSize', type: 'u32'}, {name: 'starts', type: 'u32', length: starts.length}, - {name: 'signs', type: 'i32', length: signs.length}, {name: 'steps', type: 'u32', length: steps.length} + { name: 'outputSize', type: 'u32' }, + { name: 'starts', type: 'u32', length: starts.length }, + { name: 'signs', type: 'i32', length: signs.length }, + { name: 'steps', type: 'u32', length: steps.length }, ]; const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, - {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, - ...createTensorShapeVariables(inputs[0].dims, outputShape) + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: starts }, + { type: DataType.int32, data: signs }, + { type: DataType.uint32, data: steps }, + ...createTensorShapeVariables(inputs[0].dims, outputShape), ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` @@ -171,20 +198,20 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice }`; return { name: 'Slice', - shaderCache: {hint: 
`${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank']}, + shaderCache: { hint: `${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ outputs: [outputTensorInfo], - dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, - programUniforms - }) + dispatchGroup: { x: Math.ceil(inputSize / 64 /* workgroup size */) }, + programUniforms, + }), }; }; export const slice = (context: ComputeContext, attributes: SliceAttributes): void => { validateInputs(context.inputs, attributes); const updatedAttributes = createSliceAttributesFromInputs(context.inputs, attributes); - context.compute(createSliceProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); + context.compute(createSliceProgramInfo(context.inputs, updatedAttributes), { inputs: [0] }); // if (ShapeUtil.size(program.outputs[0].dims) > 0) { // context.compute(programInfoLoader, {inputs: [0]}); // } else { @@ -197,5 +224,5 @@ export const parseSliceAttributes = (attributes: Record): Slice const starts = attributes.starts as number[]; const ends = attributes.ends as number[]; const axes = attributes.axes as number[]; - return createAttributeWithCacheKey({starts, ends, axes}); + return createAttributeWithCacheKey({ starts, ends, axes }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index b0e3ddd149656..c4e5a94f225da 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -5,13 +5,20 @@ // performance limitations when the reduced axis is long. Need to add // a optimized codepath for this. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; - -import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo } from '../types'; + +import { + getMaxComponents, + inputVariable, + outputVariable, + ShaderHelper, + sumVector, + tensorTypeToWsglStorageType, +} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { @@ -55,9 +62,10 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut const output = outputVariable('result', input.dataType, input.dims, components); const valueType = x.type.value; // 6.2.4 in wgsl spec - const threadMaxDecl = tensorTypeToWsglStorageType(input.dataType) === 'f32' ? - `var threadMax = ${valueType}(-3.402823e+38f);` : - `var threadMax = ${valueType}(-65504.0h);`; + const threadMaxDecl = + tensorTypeToWsglStorageType(input.dataType) === 'f32' + ? 
`var threadMax = ${valueType}(-3.402823e+38f);` + : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (shaderHelper: ShaderHelper) => ` var rowMaxShared : ${valueType}; var rowSumShared : ${valueType}; @@ -133,11 +141,11 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut }`; return { name: 'Softmax', - shaderCache: {hint: `${components}`, inputDependencies: ['type']}, + shaderCache: { hint: `${components}`, inputDependencies: ['type'] }, getRunData: () => ({ - outputs: [{dims: shape, dataType: input.dataType}], - dispatchGroup: {x: rows}, - programUniforms: [{type: DataType.int32, data: packedCols}] + outputs: [{ dims: shape, dataType: input.dataType }], + dispatchGroup: { x: rows }, + programUniforms: [{ type: DataType.int32, data: packedCols }], }), getShaderSource, }; @@ -149,4 +157,4 @@ export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): }; export const parseSoftmaxAttributes = (attributes: Record): SoftmaxAttributes => - createAttributeWithCacheKey({axis: attributes.axis as number}); + createAttributeWithCacheKey({ axis: attributes.axis as number }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index a09ac78b17006..3f8131be1c358 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -1,13 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform, TensorInfo } from '../types'; -import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import { + createTensorShapeVariables, + getElementAt, + IndicesHelper, + inputVariable, + outputVariable, + ShaderHelper, +} from './common'; export interface SplitAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -21,16 +28,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const createSplitAttributesFromInputs = - (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => { - const splitSizes: number[] = []; - let numOutputs: number = attributes.numOutputs; - if (inputs[1].dims[0] > 0) { - inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); - numOutputs = splitSizes.length; - } - return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes}); - }; +const createSplitAttributesFromInputs = ( + inputs: readonly TensorView[], + attributes: SplitAttributes, +): SplitAttributes => { + const splitSizes: number[] = []; + let numOutputs: number = attributes.numOutputs; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach((v) => splitSizes.push(Number(v))); + numOutputs = splitSizes.length; + } + return createAttributeWithCacheKey({ numOutputs, axis: attributes.axis, splitSizes }); +}; const calculateOutputIndexImpl = (numberOfTensors: number): string 
=> ` fn calculateOutputIndex(index: u32) -> u32 { @@ -73,7 +82,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; - const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}]; + const programUniforms: ProgramUniform[] = [{ type: DataType.uint32, data: inputSize }]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; sizeInSplitAxis[i] = previousSum; @@ -81,15 +90,17 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); - outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); + outputsTensorInfo.push({ dims: outputShapes[i], dataType: inputs[0].dataType }); } programUniforms.push( - {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes)); + { type: DataType.uint32, data: sizeInSplitAxis }, + ...createTensorShapeVariables(inputShape, ...outputShapes), + ); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${ - shaderHelper.registerUniform('input_size', 'u32') - .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length) - .declareVariables(input, ...outputs)} + ${shaderHelper + .registerUniform('input_size', 'u32') + .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length) + .declareVariables(input, ...outputs)} ${calculateOutputIndexImpl(sizeInSplitAxis.length)} ${writeBufferDataImpl(outputs)} @@ -107,29 +118,29 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split }`; return { name: 'Split', - shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']}, + shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] }, getShaderSource, getRunData: () => ({ outputs: outputsTensorInfo, - dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, - programUniforms - }) + dispatchGroup: { x: Math.ceil(inputSize / 64 /* workgroup size */) }, + programUniforms, + }), }; }; export const split = (context: ComputeContext, attributes: SplitAttributes): void => { validateInputs(context.inputs); const updatedAttributes = - context.inputs.length === 1 ? attributes : createSplitAttributesFromInputs(context.inputs, attributes); - context.compute(createSplitProgramInfo(context.inputs, updatedAttributes), {inputs: [0]}); + context.inputs.length === 1 ? attributes : createSplitAttributesFromInputs(context.inputs, attributes); + context.compute(createSplitProgramInfo(context.inputs, updatedAttributes), { inputs: [0] }); }; export const parseSplitAttributes = (attributes: Record): SplitAttributes => { const axis = attributes.axis as number; const splitSizes: number[] = attributes.splitSizes as number[]; - const numOutputs = attributes.numOutputs as number < 0 ? splitSizes.length : attributes.numOutputs as number; + const numOutputs = (attributes.numOutputs as number) < 0 ? 
splitSizes.length : (attributes.numOutputs as number); if (numOutputs !== splitSizes.length) { throw new Error('numOutputs and splitSizes lengh must be equal'); } - return createAttributeWithCacheKey({axis, numOutputs, splitSizes}); + return createAttributeWithCacheKey({ axis, numOutputs, splitSizes }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 5a8ecc0c63d86..328324ff5e167 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -1,24 +1,27 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo } from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common'; const getRepeats = (repeatsTensorView: TensorView): readonly number[] => - Array.from(repeatsTensorView.getBigInt64Array(), Number); - + Array.from(repeatsTensorView.getBigInt64Array(), Number); const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { throw new Error('Tile requires 2 inputs.'); } - if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16 && - inputs[0].dataType !== DataType.int32 && inputs[0].dataType !== DataType.uint32) { + if ( + inputs[0].dataType !== DataType.float && + inputs[0].dataType !== DataType.float16 && + inputs[0].dataType !== DataType.int32 && + inputs[0].dataType !== DataType.uint32 + ) { throw new Error('Tile only support float, float16, int32, and uint32 data types'); } @@ -75,12 +78,14 @@ export const createTileProgramInfo = (inputs: readonly TensorView[], shape?: num return { name: 'Tile', - shaderCache: {hint: `${repeats}`, inputDependencies: ['rank']}, + shaderCache: { hint: `${repeats}`, inputDependencies: ['rank'] }, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: - [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputs[0].dims, outputShape), + ], }), getShaderSource, }; @@ -88,5 +93,5 @@ export const createTileProgramInfo = (inputs: readonly TensorView[], shape?: num export const tile = (context: ComputeContext): void => { validateInputs(context.inputs); - context.compute(createTileProgramInfo(context.inputs), {inputs: [0]}); + context.compute(createTileProgramInfo(context.inputs), { inputs: [0] }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 8496173b1e8f8..4c1131477cd0f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo } from '../types'; -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -20,10 +20,10 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const getAdjustedPerm = (inputRank: number, perm: number[]): number[] => - (perm && perm.length !== inputRank) ? [...(new Array(inputRank).keys())].reverse() : perm; + perm && perm.length !== inputRank ? [...new Array(inputRank).keys()].reverse() : perm; const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => - ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm)); + ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm)); const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, output: IndicesHelper): string => { const reverseFunc = []; @@ -82,14 +82,16 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu } return { name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, + shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] }, getRunData: (inputs) => { const outputSize = ShapeUtil.size(outputShape); return { - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: - [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputs[0].dims, outputShape), + ], }; }, getShaderSource, @@ -102,4 +104,4 @@ export const transpose = (context: ComputeContext, attributes: TransposeAttribut }; export const parseTransposeAttributes = (attributes: Record): TransposeAttributes => - createAttributeWithCacheKey({perm: attributes.perm as number[]}); + createAttributeWithCacheKey({ perm: attributes.perm as number[] }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 12ba2a10cdf9f..1fc2732f245a8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -1,34 +1,39 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
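
The transpose.ts changes above are formatting-only; for readers following the index math, here is a minimal CPU-side sketch of the permutation the shader applies, assuming row-major strides. Names are illustrative and the snippet is not part of the patch.

// Reference transpose: outputShape[i] = shape[perm[i]], and the element at output
// multi-index o is read from the input multi-index p with p[perm[i]] = o[i].
const transposeRef = (
  data: Float32Array,
  shape: number[],
  perm: number[],
): { data: Float32Array; shape: number[] } => {
  const rank = shape.length;
  const outShape = perm.map((p) => shape[p]);
  const computeStrides = (s: number[]) => {
    const st = new Array<number>(s.length).fill(1);
    for (let i = s.length - 2; i >= 0; i--) {
      st[i] = st[i + 1] * s[i + 1];
    }
    return st;
  };
  const inStrides = computeStrides(shape);
  const outStrides = computeStrides(outShape);
  const size = shape.reduce((a, b) => a * b, 1);
  const out = new Float32Array(size);
  for (let outOffset = 0; outOffset < size; outOffset++) {
    // decompose the output offset into a multi-index and accumulate the input offset
    let rem = outOffset;
    let inOffset = 0;
    for (let i = 0; i < rank; i++) {
      const idx = Math.floor(rem / outStrides[i]);
      rem %= outStrides[i];
      inOffset += idx * inStrides[perm[i]];
    }
    out[outOffset] = data[inOffset];
  }
  return { data: out, shape: outShape };
};
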
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { MAX_CLIP, MIN_CLIP, ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo } from '../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType} from './common'; +import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType } from './common'; type BuiltinFunctionName = string; type ElementwiseCustomExpression = (expression: string) => string; -type ElementwiseFunctionCall = BuiltinFunctionName|ElementwiseCustomExpression; - -const createElementwiseProgramShader = - (shaderHelper: ShaderHelper, datasize: number, inputDataType: number, outputDataType: number, - funcCall: ElementwiseFunctionCall, additionalImplementation?: string): string => { - const vecSize = Math.ceil(datasize / 4); - - let expression = ''; - if (typeof funcCall === 'string') { - expression = `${funcCall}(a)`; - } else { - expression = funcCall('a'); - } +type ElementwiseFunctionCall = BuiltinFunctionName | ElementwiseCustomExpression; + +const createElementwiseProgramShader = ( + shaderHelper: ShaderHelper, + datasize: number, + inputDataType: number, + outputDataType: number, + funcCall: ElementwiseFunctionCall, + additionalImplementation?: string, +): string => { + const vecSize = Math.ceil(datasize / 4); + + let expression = ''; + if (typeof funcCall === 'string') { + expression = `${funcCall}(a)`; + } else { + expression = funcCall('a'); + } - const input = inputVariable('inputData', inputDataType, [vecSize], 4); - const output = outputVariable('outputData', outputDataType, [vecSize], 4); + const input = inputVariable('inputData', inputDataType, [vecSize], 4); + const output = outputVariable('outputData', outputDataType, [vecSize], 4); - return ` + return ` ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} ${additionalImplementation ?? 
''} @@ -39,24 +44,33 @@ const createElementwiseProgramShader = let a = ${input.getByOffset('global_idx')}; ${output.setByOffset('global_idx', expression)} }`; - }; - -const createElementwiseProgramInfo = - (input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, - cacheKey?: string, outputDataType: number = input.dataType): ProgramInfo => ({ - name, - shaderCache: {hint: cacheKey, inputDependencies: ['type']}, - getShaderSource: shaderHelper => createElementwiseProgramShader( - shaderHelper, ShapeUtil.size(input.dims), input.dataType, outputDataType, funcCall, additionalImplementation), - getRunData: (inputTensors) => ({ - outputs: [{dims: input.dims, dataType: outputDataType}], - dispatchGroup: - {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, - programUniforms: [ - {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, - ], - }) - }); +}; + +const createElementwiseProgramInfo = ( + input: TensorView, + name: string, + funcCall: ElementwiseFunctionCall, + additionalImplementation?: string, + cacheKey?: string, + outputDataType: number = input.dataType, +): ProgramInfo => ({ + name, + shaderCache: { hint: cacheKey, inputDependencies: ['type'] }, + getShaderSource: (shaderHelper) => + createElementwiseProgramShader( + shaderHelper, + ShapeUtil.size(input.dims), + input.dataType, + outputDataType, + funcCall, + additionalImplementation, + ), + getRunData: (inputTensors) => ({ + outputs: [{ dims: input.dims, dataType: outputDataType }], + dispatchGroup: { x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */) }, + programUniforms: [{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }], + }), +}); export const abs = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs')); @@ -91,8 +105,7 @@ export interface CastAttributes extends AttributeWithCacheKey { } export const parseCastAttributes = (attributes: Record): CastAttributes => - createAttributeWithCacheKey(attributes as {to: number}); - + createAttributeWithCacheKey(attributes as { to: number }); export const cast = (context: ComputeContext, attributes: CastAttributes): void => { let func: ElementwiseFunctionCall; @@ -116,7 +129,8 @@ export const cast = (context: ComputeContext, attributes: CastAttributes): void throw new RangeError(`not supported type (specified in attribute 'to' from 'Cast' operator): ${attributes.to}`); } context.compute( - createElementwiseProgramInfo(context.inputs[0], 'Cast', func, undefined, attributes.cacheKey, attributes.to)); + createElementwiseProgramInfo(context.inputs[0], 'Cast', func, undefined, attributes.cacheKey, attributes.to), + ); }; export interface ClipAttributes extends AttributeWithCacheKey { @@ -125,22 +139,27 @@ export interface ClipAttributes extends AttributeWithCacheKey { } const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2 && inputs[1].data !== 0) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3 && inputs[2].data !== 0) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; - return createAttributeWithCacheKey({min, max}); + const min = inputs.length >= 2 && inputs[1].data !== 0 ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = inputs.length >= 3 && inputs[2].data !== 0 ? 
inputs[2].getFloat32Array()[0] : MAX_CLIP; + return createAttributeWithCacheKey({ min, max }); }; export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute( - createElementwiseProgramInfo( - context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` + createElementwiseProgramInfo( + context.inputs[0], + 'Clip', + (a) => `clamp(${a}, clip_min_, clip_max_)`, + ` const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min})); const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max})); `, - attributes.cacheKey), - {inputs: [0]}); + attributes.cacheKey, + ), + { inputs: [0] }, + ); }; export const ceil = (context: ComputeContext): void => { @@ -160,12 +179,16 @@ export interface AlphaAttributes extends AttributeWithCacheKey { } export const parseAlphaAttributes = (attributes: Record): AlphaAttributes => - createAttributeWithCacheKey(attributes as {alpha: number}); + createAttributeWithCacheKey(attributes as { alpha: number }); export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Elu', a => `elu_vf32(${a})`, ` + context.compute( + createElementwiseProgramInfo( + context.inputs[0], + 'Elu', + (a) => `elu_vf32(${a})`, + ` const elu_alpha_ = ${dataType}(${attributes.alpha}); fn elu_f32(a: ${dataType}) -> ${dataType} { @@ -175,7 +198,9 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> { return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w)); }`, - attributes.cacheKey)); + attributes.cacheKey, + ), + ); }; export const erfImpl = (varType = 'f32') => ` @@ -194,7 +219,7 @@ fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> { export const erf = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType))); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', (a) => `erf_vf32(${a})`, erfImpl(dataType))); }; export const exp = (context: ComputeContext): void => { @@ -207,37 +232,54 @@ export const floor = (context: ComputeContext): void => { export const gelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType))); + context.compute( + createElementwiseProgramInfo( + context.inputs[0], + 'Gelu', + (a) => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, + erfImpl(dataType), + ), + ); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`, - `const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, attributes.cacheKey)); + context.compute( + createElementwiseProgramInfo( + context.inputs[0], + 
'LeakyRelu', + (a) => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`, + `const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, + attributes.cacheKey, + ), + ); }; export const not = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Not', a => `!${a}`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Not', (a) => `!${a}`)); }; export const neg = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Neg', a => `-${a}`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Neg', (a) => `-${a}`)); }; export const reciprocal = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Reciprocal', a => `1.0/${a}`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Reciprocal', (a) => `1.0/${a}`)); }; export const relu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Relu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`)); + context.compute( + createElementwiseProgramInfo( + context.inputs[0], + 'Relu', + (a) => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`, + ), + ); }; export const sigmoid = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', (a) => `(1.0 / (1.0 + exp(-${a})))`)); }; export interface HardSigmoidAttributes extends AttributeWithCacheKey { @@ -246,18 +288,27 @@ export interface HardSigmoidAttributes extends AttributeWithCacheKey { } export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => - createAttributeWithCacheKey(attributes as { + createAttributeWithCacheKey( + attributes as { alpha: number; beta: number; - }); + }, + ); export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'HardSigmoid', - a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ - attributes.beta})))`, - undefined, attributes.cacheKey)); + context.compute( + createElementwiseProgramInfo( + context.inputs[0], + 'HardSigmoid', + (a) => + `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta + })))`, + undefined, + attributes.cacheKey, + ), + ); }; export const sin = (context: ComputeContext): void => { @@ -294,20 +345,33 @@ fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> { `; export const fastGeluExpression = (x: string) => - `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`; + `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`; export const fastGelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined, - context.inputs[0].dataType)); + context.compute( + createElementwiseProgramInfo( + context.inputs[0], 
+ 'FastGelu', + fastGeluExpression, + fastGeluImpl(dataType), + undefined, + context.inputs[0].dataType, + ), + ); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'ThresholdedRelu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`, - `const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, attributes.cacheKey)); + context.compute( + createElementwiseProgramInfo( + context.inputs[0], + 'ThresholdedRelu', + (a) => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`, + `const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, + attributes.cacheKey, + ), + ); return 0; }; @@ -338,7 +402,14 @@ export const quickGeluExpression = (x: string) => `quick_gelu_impl(${x})`; export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => { const dType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'QuickGelu', quickGeluExpression, quickGeluImpl(dType, attributes.alpha), attributes.cacheKey, - context.inputs[0].dataType)); + context.compute( + createElementwiseProgramInfo( + context.inputs[0], + 'QuickGelu', + quickGeluExpression, + quickGeluImpl(dType, attributes.alpha), + attributes.cacheKey, + context.inputs[0].dataType, + ), + ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index a6375847fc42f..30ea6d011b7d0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -1,34 +1,39 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor-view'; -import {BroadcastUtil, ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { BroadcastUtil, ShapeUtil } from '../../util'; +import { ComputeContext, ProgramInfo } from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common'; -const createWhereOpProgramShader = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, - typeOutput: number) => { - const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4); - const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4); - const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4); - const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4); +const createWhereOpProgramShader = ( + shaderHelper: ShaderHelper, + inputs: readonly TensorView[], + dimsOutput: readonly number[], + isBroadcast: boolean, + typeOutput: number, +) => { + const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4); + const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4); + const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4); + const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4); - let assignment: string; - const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; - if (!isBroadcast) { - assignment = output.setByOffset( - 'global_idx', - expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); - } else { - const singleAssignment = (resStr: string, x: number, typeCast = '') => { - const expressionA = `a_data[index_a${x}][component_a${x}]`; - const expressionB = `b_data[index_b${x}][component_b${x}]`; - // eslint-disable-next-line no-bitwise - const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; - return ` + let assignment: string; + const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; + if (!isBroadcast) { + assignment = output.setByOffset( + 'global_idx', + expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')), + ); + } else { + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `a_data[index_a${x}][component_a${x}]`; + const expressionB = `b_data[index_b${x}][component_b${x}]`; + // eslint-disable-next-line no-bitwise + const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; + return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)}; @@ -41,32 +46,32 @@ const createWhereOpProgramShader = let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; - }; - if (typeOutput === DataType.bool) { - assignment = ` + }; + if (typeOutput === DataType.bool) { + assignment = ` var data = vec4(0); ${singleAssignment('data', 0, 'u32')} ${singleAssignment('data', 1, 
'u32')} ${singleAssignment('data', 2, 'u32')} ${singleAssignment('data', 3, 'u32')} output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; - } else { - assignment = ` + } else { + assignment = ` ${singleAssignment('output_data[global_idx]', 0)} ${singleAssignment('output_data[global_idx]', 1)} ${singleAssignment('output_data[global_idx]', 2)} ${singleAssignment('output_data[global_idx]', 3)} `; - } - } + } + } - return ` + return ` ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} ${assignment} }`; - }; +}; const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => { const dimsA = inputs[1].dims; @@ -82,7 +87,7 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => if (isBroadcast) { const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false); if (!calculatedShape) { - throw new Error('Can\'t perform where op on the given tensors'); + throw new Error("Can't perform where op on the given tensors"); } outputShape = calculatedShape; outputSize = ShapeUtil.size(outputShape); @@ -92,14 +97,16 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => return { name: 'Where', - shaderCache: {inputDependencies: ['rank', 'rank', 'rank']}, + shaderCache: { inputDependencies: ['rank', 'rank', 'rank'] }, getShaderSource: (shaderHelper) => - createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), + createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), getRunData: () => ({ - outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, - programUniforms: - [{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)], + outputs: [{ dims: outputShape, dataType: outputDataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */) }, + programUniforms: [ + { type: DataType.uint32, data: vecSize }, + ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape), + ], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index ccbcbe48505d6..c5b8f579c3aae 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; +import { TRACE_FUNC_BEGIN, TRACE_FUNC_END } from 'onnxruntime-common'; -import {WebGpuBackend} from '../backend-webgpu'; -import {LOG_DEBUG} from '../log'; +import { WebGpuBackend } from '../backend-webgpu'; +import { LOG_DEBUG } from '../log'; -import {createShaderHelper} from './ops/common'; -import {Artifact, GpuData, ProgramInfo} from './types'; +import { createShaderHelper } from './ops/common'; +import { Artifact, GpuData, ProgramInfo } from './types'; /** * ProgramManager is the main class behind running computations @@ -19,44 +19,52 @@ import {Artifact, GpuData, ProgramInfo} from './types'; * corresponding Location's in the binary program */ export class ProgramManager { - repo: Map; // this should be per-session object + repo: Map; // this should be per-session object attributesBound: boolean; constructor(private backend: WebGpuBackend) { this.repo = new Map(); this.attributesBound = false; } - getArtifact(key: unknown): Artifact|undefined { + getArtifact(key: unknown): Artifact | undefined { return this.repo.get(key); } setArtifact(key: unknown, artifact: Artifact): void { this.repo.set(key, artifact); } - run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number], - uniformBufferBinding: GPUBindingResource|undefined): void { + run( + buildArtifact: Artifact, + inputs: GpuData[], + outputs: GpuData[], + dispatchGroup: [number, number, number], + uniformBufferBinding: GPUBindingResource | undefined, + ): void { TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); const entries = []; for (const input of inputs) { - entries.push({binding: entries.length, resource: {buffer: input.buffer}}); + entries.push({ binding: entries.length, resource: { buffer: input.buffer } }); } for (const output of outputs) { - entries.push({binding: entries.length, resource: {buffer: output.buffer}}); + entries.push({ binding: entries.length, resource: { buffer: output.buffer } }); } if (uniformBufferBinding) { - entries.push({binding: entries.length, resource: uniformBufferBinding}); + entries.push({ binding: entries.length, resource: uniformBufferBinding }); } - const bindGroup = device.createBindGroup( - {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); + const bindGroup = device.createBindGroup({ + layout: buildArtifact.computePipeline.getBindGroupLayout(0), + entries, + label: buildArtifact.programInfo.name, + }); if (this.backend.sessionStatus === 'capturing') { const commandInfo = { kernelId: this.backend.currentKernelId!, computePipeline: buildArtifact.computePipeline, bindGroup, - dispatchGroup + dispatchGroup, }; const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); sessionCommandList!.push(commandInfo); @@ -68,8 +76,10 @@ export class ProgramManager { this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; - if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber || - this.backend.queryType === 'at-passes') { + if ( + this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber || + this.backend.queryType === 'at-passes' + ) { this.backend.endComputePass(); } if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) 
{ @@ -90,21 +100,25 @@ export class ProgramManager { const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits); const userCode = programInfo.getShaderSource(shaderHelper); const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; - const shaderModule = device.createShaderModule({code, label: programInfo.name}); + const shaderModule = device.createShaderModule({ code, label: programInfo.name }); LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`); - const computePipeline = device.createComputePipeline( - {compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name}); + const computePipeline = device.createComputePipeline({ + compute: { module: shaderModule, entryPoint: 'main' }, + layout: 'auto', + label: programInfo.name, + }); TRACE_FUNC_END(programInfo.name); - return {programInfo, computePipeline, uniformVariablesInfo: shaderHelper.variablesInfo}; + return { programInfo, computePipeline, uniformVariablesInfo: shaderHelper.variablesInfo }; } - normalizeDispatchGroupSize(dispatchGroup: ReturnType['dispatchGroup']): - [number, number, number] { + normalizeDispatchGroupSize( + dispatchGroup: ReturnType['dispatchGroup'], + ): [number, number, number] { const x = typeof dispatchGroup === 'number' ? dispatchGroup : dispatchGroup.x; - const y = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.y || 1); - const z = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.z || 1); + const y = typeof dispatchGroup === 'number' ? 1 : dispatchGroup.y || 1; + const z = typeof dispatchGroup === 'number' ? 1 : dispatchGroup.z || 1; const limitPerDimension = this.backend.device.limits.maxComputeWorkgroupsPerDimension; if (x <= limitPerDimension && y <= limitPerDimension && z <= limitPerDimension) { return [x, y, z]; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 2a584fc0a2218..776263b143be3 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -1,22 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {DataType} from '../../wasm-common'; -import {TensorView} from '../tensor-view'; +import { DataType } from '../../wasm-common'; +import { TensorView } from '../tensor-view'; -import {ShaderHelper} from './ops/common'; +import { ShaderHelper } from './ops/common'; -export type SessionState = 'default'|'capturing'|'replaying'; +export type SessionState = 'default' | 'capturing' | 'replaying'; export enum GpuDataType { default = 0, upload = 1, - profile = 2 + profile = 2, } export type GpuDataId = number; export type GpuArchitecture = 'ampere'; -export type GpuVendor = 'amd'|'intel'|'nvidia'; +export type GpuVendor = 'amd' | 'intel' | 'nvidia'; export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean; isVendor: (vendor: GpuVendor) => boolean; @@ -35,7 +35,7 @@ export interface TensorInfo { export interface ProgramUniform { type: DataType; - data: number|readonly number[]; + data: number | readonly number[]; } export type ProgramUniformVariableInfo = [type: DataType, length: number]; @@ -49,7 +49,7 @@ export type ProgramUniformVariableInfo = [type: DataType, length: number]; * - 'dims': the shader/uniform depends on data type and the dims of this input * - 'data': the shader/uniform depends on data type, the dims and the data of this input */ -export type ProgramInputTensorInfoDependency = 'none'|'type'|'rank'|'dims'|'data'; +export type ProgramInputTensorInfoDependency = 'none' | 'type' | 'rank' | 'dims' | 'data'; /** * Represent information about a program's cache for shader. @@ -88,7 +88,6 @@ export interface ProgramUniformCacheInfo { inputDependencies?: ProgramInputTensorInfoDependency[]; } - /** * A set of data that represent a shader program */ @@ -119,7 +118,7 @@ export interface ProgramInfo { */ getRunData: (inputs: readonly TensorView[]) => { outputs: readonly TensorInfo[]; - dispatchGroup: {x: number; y?: number; z?: number}; + dispatchGroup: { x: number; y?: number; z?: number }; programUniforms?: readonly ProgramUniform[]; }; } @@ -127,7 +126,7 @@ export interface ProgramInfo { export interface Artifact { programInfo: ProgramInfo; computePipeline: GPUComputePipeline; - uniformVariablesInfo: readonly ProgramUniformVariableInfo[]|undefined; + uniformVariablesInfo: readonly ProgramUniformVariableInfo[] | undefined; } export interface ComputeContextInputsOutputsMapping { @@ -138,7 +137,7 @@ export interface ComputeContextInputsOutputsMapping { * * if inputs is not specified, the mapping will be the kernel's inputs in order. */ - readonly inputs?: ReadonlyArray; + readonly inputs?: ReadonlyArray; /** * specify the mapping to the program's outputs. the value must be a number. 
* - if it's a non-negative number, it's the index of the kernel's output @@ -174,7 +173,7 @@ export interface ComputeContext { /** * a custom data object that can be used to store any data that is needed by the kernel */ - readonly kernelCustomData: {[key: string]: unknown}; + readonly kernelCustomData: { [key: string]: unknown }; /** * a buffer that can be used to access custom data created each time the kernel is executed @@ -192,4 +191,4 @@ export interface ComputeContext { getMaxComputeWorkgroupStoragesize(): number; } -export type TimestampQuery = 'none'|'inside-passes'|'at-passes'; +export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes'; diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 02246c9ee4767..8f3acdd582445 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -1,13 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import type {Env, InferenceSession, Tensor} from 'onnxruntime-common'; +import type { Env, InferenceSession, Tensor } from 'onnxruntime-common'; /** * Among all the tensor locations, only 'cpu' is serializable. */ -export type SerializableTensorMetadata = - [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; +export type SerializableTensorMetadata = [ + dataType: Tensor.Type, + dims: readonly number[], + data: Tensor.DataType, + location: 'cpu', +]; export type GpuBufferMetadata = { gpuBuffer: Tensor.GpuBufferType; @@ -19,8 +23,8 @@ export type GpuBufferMetadata = { * Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable. */ export type UnserializableTensorMetadata = - [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| - [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; + | [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer'] + | [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; /** * Tensor metadata is a tuple of [dataType, dims, data, location], where @@ -32,7 +36,7 @@ export type UnserializableTensorMetadata = * - gpu-buffer: GpuBufferMetadata * - location: tensor data location */ -export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; +export type TensorMetadata = SerializableTensorMetadata | UnserializableTensorMetadata; export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; @@ -44,38 +48,41 @@ interface MessageError { interface MessageInitWasm extends MessageError { type: 'init-wasm'; - in ?: Env; + in?: Env; out?: never; } interface MessageInitEp extends MessageError { type: 'init-ep'; - in ?: {env: Env; epName: string}; + in?: { env: Env; epName: string }; out?: never; } interface MessageCopyFromExternalBuffer extends MessageError { type: 'copy-from'; - in ?: {buffer: Uint8Array}; + in?: { buffer: Uint8Array }; out?: SerializableInternalBuffer; } interface MessageCreateSession extends MessageError { type: 'create'; - in ?: {model: SerializableInternalBuffer|Uint8Array; options?: InferenceSession.SessionOptions}; + in?: { model: SerializableInternalBuffer | Uint8Array; options?: InferenceSession.SessionOptions }; out?: SerializableSessionMetadata; } interface MessageReleaseSession extends MessageError { type: 'release'; - in ?: number; + in?: number; out?: never; } interface MessageRun extends 
MessageError { type: 'run'; - in ?: { - sessionId: number; inputIndices: number[]; inputs: SerializableTensorMetadata[]; outputIndices: number[]; + in?: { + sessionId: number; + inputIndices: number[]; + inputs: SerializableTensorMetadata[]; + outputIndices: number[]; options: InferenceSession.RunOptions; }; out?: SerializableTensorMetadata[]; @@ -83,9 +90,15 @@ interface MessageRun extends MessageError { interface MesssageEndProfiling extends MessageError { type: 'end-profiling'; - in ?: number; + in?: number; out?: never; } -export type OrtWasmMessage = MessageInitWasm|MessageInitEp|MessageCopyFromExternalBuffer|MessageCreateSession| - MessageReleaseSession|MessageRun|MesssageEndProfiling; +export type OrtWasmMessage = + | MessageInitWasm + | MessageInitEp + | MessageCopyFromExternalBuffer + | MessageCreateSession + | MessageReleaseSession + | MessageRun + | MesssageEndProfiling; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index ccd75ad16d3c0..163bac4eb676d 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -64,8 +64,8 @@ // declare global { type HTMLImageElement = unknown; - type HTMLScriptElement = {src?: string}; - const document: undefined|{currentScript?: HTMLScriptElement}; + type HTMLScriptElement = { src?: string }; + const document: undefined | { currentScript?: HTMLScriptElement }; } /** @@ -83,10 +83,19 @@ declare global { * This file will be always compiling into ESM format. */ -import type {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages.js'; -import {createSession, copyFromExternalBuffer, endProfiling, extractTransferableBuffers, initEp, initRuntime, releaseSession, run} from '../wasm-core-impl.js'; -import {initializeWebAssembly} from '../wasm-factory.js'; -import {scriptSrc} from '../wasm-utils-import.js'; +import type { OrtWasmMessage, SerializableTensorMetadata } from '../proxy-messages.js'; +import { + createSession, + copyFromExternalBuffer, + endProfiling, + extractTransferableBuffers, + initEp, + initRuntime, + releaseSession, + run, +} from '../wasm-core-impl.js'; +import { initializeWebAssembly } from '../wasm-factory.js'; +import { scriptSrc } from '../wasm-utils-import.js'; const WORKER_NAME = 'ort-wasm-proxy-worker'; const isProxyWorker = globalThis.self?.name === WORKER_NAME; @@ -94,90 +103,92 @@ const isProxyWorker = globalThis.self?.name === WORKER_NAME; if (isProxyWorker) { // Worker thread self.onmessage = (ev: MessageEvent): void => { - const {type, in : message} = ev.data; + const { type, in: message } = ev.data; try { switch (type) { case 'init-wasm': - initializeWebAssembly(message!.wasm) - .then( - () => { - initRuntime(message!).then( - () => { - postMessage({type}); - }, - err => { - postMessage({type, err}); - }); - }, - err => { - postMessage({type, err}); - }); + initializeWebAssembly(message!.wasm).then( + () => { + initRuntime(message!).then( + () => { + postMessage({ type }); + }, + (err) => { + postMessage({ type, err }); + }, + ); + }, + (err) => { + postMessage({ type, err }); + }, + ); break; case 'init-ep': { - const {epName, env} = message!; - initEp(env, epName) - .then( - () => { - postMessage({type}); - }, - err => { - postMessage({type, err}); - }); + const { epName, env } = message!; + initEp(env, epName).then( + () => { + postMessage({ type }); + }, + (err) => { + postMessage({ type, err }); + }, + ); break; } case 'copy-from': { - const {buffer} = message!; + const { buffer } = message!; const bufferData = 
copyFromExternalBuffer(buffer); - postMessage({type, out: bufferData} as OrtWasmMessage); + postMessage({ type, out: bufferData } as OrtWasmMessage); break; } case 'create': { - const {model, options} = message!; - createSession(model, options) - .then( - sessionMetadata => { - postMessage({type, out: sessionMetadata} as OrtWasmMessage); - }, - err => { - postMessage({type, err}); - }); + const { model, options } = message!; + createSession(model, options).then( + (sessionMetadata) => { + postMessage({ type, out: sessionMetadata } as OrtWasmMessage); + }, + (err) => { + postMessage({ type, err }); + }, + ); break; } case 'release': releaseSession(message!); - postMessage({type}); + postMessage({ type }); break; case 'run': { - const {sessionId, inputIndices, inputs, outputIndices, options} = message!; - run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options) - .then( - outputs => { - if (outputs.some(o => o[3] !== 'cpu')) { - postMessage({type, err: 'Proxy does not support non-cpu tensor location.'}); - } else { - postMessage( - {type, out: outputs} as OrtWasmMessage, - extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[])); - } - }, - err => { - postMessage({type, err}); - }); + const { sessionId, inputIndices, inputs, outputIndices, options } = message!; + run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options).then( + (outputs) => { + if (outputs.some((o) => o[3] !== 'cpu')) { + postMessage({ type, err: 'Proxy does not support non-cpu tensor location.' }); + } else { + postMessage( + { type, out: outputs } as OrtWasmMessage, + extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[]), + ); + } + }, + (err) => { + postMessage({ type, err }); + }, + ); break; } case 'end-profiling': endProfiling(message!); - postMessage({type}); + postMessage({ type }); break; default: } } catch (err) { - postMessage({type, err} as OrtWasmMessage); + postMessage({ type, err } as OrtWasmMessage); } }; } -export default isProxyWorker ? - null : - (urlOverride?: string) => - new Worker(urlOverride ?? scriptSrc!, {type: BUILD_DEFS.IS_ESM ? 'module' : 'classic', name: WORKER_NAME}); +export default isProxyWorker + ? null + : (urlOverride?: string) => + new Worker(urlOverride ?? scriptSrc!, { type: BUILD_DEFS.IS_ESM ? 'module' : 'classic', name: WORKER_NAME }); diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 2dd8bfb0b6531..ada06cada8584 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -1,19 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {env, InferenceSession} from 'onnxruntime-common'; - -import {OrtWasmMessage, SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; +import { env, InferenceSession } from 'onnxruntime-common'; + +import { + OrtWasmMessage, + SerializableInternalBuffer, + SerializableSessionMetadata, + SerializableTensorMetadata, + TensorMetadata, +} from './proxy-messages'; import * as core from './wasm-core-impl'; -import {initializeWebAssembly} from './wasm-factory'; -import {importProxyWorker} from './wasm-utils-import'; +import { initializeWebAssembly } from './wasm-factory'; +import { importProxyWorker } from './wasm-utils-import'; const isProxy = (): boolean => !!env.wasm.proxy && typeof document !== 'undefined'; -let proxyWorker: Worker|undefined; +let proxyWorker: Worker | undefined; let initializing = false; let initialized = false; let aborted = false; -let temporaryObjectUrl: string|undefined; +let temporaryObjectUrl: string | undefined; type PromiseCallbacks = [resolve: (result: T) => void, reject: (reason: unknown) => void]; let initWasmCallbacks: PromiseCallbacks; @@ -68,16 +74,15 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { } }; - -export const initializeWebAssemblyAndOrtRuntime = async(): Promise => { +export const initializeWebAssemblyAndOrtRuntime = async (): Promise => { if (initialized) { return; } if (initializing) { - throw new Error('multiple calls to \'initWasm()\' detected.'); + throw new Error("multiple calls to 'initWasm()' detected."); } if (aborted) { - throw new Error('previous call to \'initWasm()\' failed.'); + throw new Error("previous call to 'initWasm()' failed."); } initializing = true; @@ -92,7 +97,7 @@ export const initializeWebAssemblyAndOrtRuntime = async(): Promise => { proxyWorker.onerror = (ev: ErrorEvent) => reject(ev); proxyWorker.onmessage = onProxyWorkerMessage; initWasmCallbacks = [resolve, reject]; - const message: OrtWasmMessage = {type: 'init-wasm', in : env}; + const message: OrtWasmMessage = { type: 'init-wasm', in: env }; proxyWorker.postMessage(message); temporaryObjectUrl = objectUrl; } catch (e) { @@ -100,7 +105,6 @@ export const initializeWebAssemblyAndOrtRuntime = async(): Promise => { } }, reject); }); - } else { try { await initializeWebAssembly(env.wasm); @@ -115,12 +119,12 @@ export const initializeWebAssemblyAndOrtRuntime = async(): Promise => { } }; -export const initializeOrtEp = async(epName: string): Promise => { +export const initializeOrtEp = async (epName: string): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('init-ep', [resolve, reject]); - const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env}}; + const message: OrtWasmMessage = { type: 'init-ep', in: { epName, env } }; proxyWorker!.postMessage(message); }); } else { @@ -128,12 +132,12 @@ export const initializeOrtEp = async(epName: string): Promise => { } }; -export const copyFromExternalBuffer = async(buffer: Uint8Array): Promise => { +export const copyFromExternalBuffer = async (buffer: Uint8Array): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('copy-from', [resolve, reject]); - const message: OrtWasmMessage = {type: 'copy-from', in : {buffer}}; + const message: OrtWasmMessage = { type: 'copy-from', in: { buffer } }; proxyWorker!.postMessage(message, [buffer.buffer]); }); } else { @@ 
-141,35 +145,36 @@ export const copyFromExternalBuffer = async(buffer: Uint8Array): Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - // check unsupported options - if (options?.preferredOutputLocation) { - throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); - } - ensureWorker(); - return new Promise((resolve, reject) => { - enqueueCallbacks('create', [resolve, reject]); - const message: OrtWasmMessage = {type: 'create', in : {model, options: {...options}}}; - const transferable: Transferable[] = []; - if (model instanceof Uint8Array) { - transferable.push(model.buffer); - } - proxyWorker!.postMessage(message, transferable); - }); - } else { - return core.createSession(model, options); - } - }; - -export const releaseSession = async(sessionId: number): Promise => { +export const createSession = async ( + model: SerializableInternalBuffer | Uint8Array, + options?: InferenceSession.SessionOptions, +): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } + ensureWorker(); + return new Promise((resolve, reject) => { + enqueueCallbacks('create', [resolve, reject]); + const message: OrtWasmMessage = { type: 'create', in: { model, options: { ...options } } }; + const transferable: Transferable[] = []; + if (model instanceof Uint8Array) { + transferable.push(model.buffer); + } + proxyWorker!.postMessage(message, transferable); + }); + } else { + return core.createSession(model, options); + } +}; + +export const releaseSession = async (sessionId: number): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('release', [resolve, reject]); - const message: OrtWasmMessage = {type: 'release', in : sessionId}; + const message: OrtWasmMessage = { type: 'release', in: sessionId }; proxyWorker!.postMessage(message); }); } else { @@ -177,24 +182,31 @@ export const releaseSession = async(sessionId: number): Promise => { } }; -export const run = async( - sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[], - outputs: Array, options: InferenceSession.RunOptions): Promise => { +export const run = async ( + sessionId: number, + inputIndices: number[], + inputs: TensorMetadata[], + outputIndices: number[], + outputs: Array, + options: InferenceSession.RunOptions, +): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { // check inputs location - if (inputs.some(t => t[3] !== 'cpu')) { + if (inputs.some((t) => t[3] !== 'cpu')) { throw new Error('input tensor on GPU is not supported for proxy.'); } // check outputs location - if (outputs.some(t => t)) { + if (outputs.some((t) => t)) { throw new Error('pre-allocated output tensor is not supported for proxy.'); } ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('run', [resolve, reject]); - const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. - const message: OrtWasmMessage = - {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; + const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. 
+ const message: OrtWasmMessage = { + type: 'run', + in: { sessionId, inputIndices, inputs: serializableInputs, outputIndices, options }, + }; proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs)); }); } else { @@ -202,12 +214,12 @@ export const run = async( } }; -export const endProfiling = async(sessionId: number): Promise => { +export const endProfiling = async (sessionId: number): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('end-profiling', [resolve, reject]); - const message: OrtWasmMessage = {type: 'end-profiling', in : sessionId}; + const message: OrtWasmMessage = { type: 'end-profiling', in: sessionId }; proxyWorker!.postMessage(message); }); } else { diff --git a/js/web/lib/wasm/run-options.ts b/js/web/lib/wasm/run-options.ts index 8fe230003413f..d15c8339b6824 100644 --- a/js/web/lib/wasm/run-options.ts +++ b/js/web/lib/wasm/run-options.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession} from 'onnxruntime-common'; +import { InferenceSession } from 'onnxruntime-common'; -import {getInstance} from './wasm-factory'; -import {allocWasmString, checkLastError, iterateExtraOptions} from './wasm-utils'; +import { getInstance } from './wasm-factory'; +import { allocWasmString, checkLastError, iterateExtraOptions } from './wasm-utils'; export const setRunOptions = (options: InferenceSession.RunOptions): [number, number[]] => { const wasm = getInstance(); @@ -15,15 +15,18 @@ export const setRunOptions = (options: InferenceSession.RunOptions): [number, nu try { if (options?.logSeverityLevel === undefined) { - runOptions.logSeverityLevel = 2; // Default to warning + runOptions.logSeverityLevel = 2; // Default to warning } else if ( - typeof options.logSeverityLevel !== 'number' || !Number.isInteger(options.logSeverityLevel) || - options.logSeverityLevel < 0 || options.logSeverityLevel > 4) { + typeof options.logSeverityLevel !== 'number' || + !Number.isInteger(options.logSeverityLevel) || + options.logSeverityLevel < 0 || + options.logSeverityLevel > 4 + ) { throw new Error(`log serverity level is not valid: ${options.logSeverityLevel}`); } if (options?.logVerbosityLevel === undefined) { - runOptions.logVerbosityLevel = 0; // Default to 0 + runOptions.logVerbosityLevel = 0; // Default to 0 } else if (typeof options.logVerbosityLevel !== 'number' || !Number.isInteger(options.logVerbosityLevel)) { throw new Error(`log verbosity level is not valid: ${options.logVerbosityLevel}`); } @@ -38,9 +41,13 @@ export const setRunOptions = (options: InferenceSession.RunOptions): [number, nu } runOptionsHandle = wasm._OrtCreateRunOptions( - runOptions.logSeverityLevel!, runOptions.logVerbosityLevel!, !!runOptions.terminate!, tagDataOffset); + runOptions.logSeverityLevel!, + runOptions.logVerbosityLevel!, + !!runOptions.terminate!, + tagDataOffset, + ); if (runOptionsHandle === 0) { - checkLastError('Can\'t create run options.'); + checkLastError("Can't create run options."); } if (options?.extra !== undefined) { @@ -59,7 +66,7 @@ export const setRunOptions = (options: InferenceSession.RunOptions): [number, nu if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); } - allocs.forEach(alloc => wasm._free(alloc)); + allocs.forEach((alloc) => wasm._free(alloc)); throw e; } }; diff --git a/js/web/lib/wasm/session-handler-inference.ts 
b/js/web/lib/wasm/session-handler-inference.ts index eb77a6b00f11f..eff3e91389c98 100644 --- a/js/web/lib/wasm/session-handler-inference.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -1,20 +1,27 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; - -import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; -import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper'; -import {isGpuBufferSupportedType} from './wasm-common'; -import {isNode} from './wasm-utils-env'; -import {loadFile} from './wasm-utils-load-file'; +import { + InferenceSession, + InferenceSessionHandler, + SessionHandler, + Tensor, + TRACE_FUNC_BEGIN, + TRACE_FUNC_END, +} from 'onnxruntime-common'; + +import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; +import { copyFromExternalBuffer, createSession, endProfiling, releaseSession, run } from './proxy-wrapper'; +import { isGpuBufferSupportedType } from './wasm-common'; +import { isNode } from './wasm-utils-env'; +import { loadFile } from './wasm-utils-load-file'; export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { case 'cpu': return [tensor.type, tensor.dims, tensor.data, 'cpu']; case 'gpu-buffer': - return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer']; + return [tensor.type, tensor.dims, { gpuBuffer: tensor.gpuBuffer }, 'gpu-buffer']; default: throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); } @@ -29,8 +36,8 @@ export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { if (!isGpuBufferSupportedType(dataType)) { throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`); } - const {gpuBuffer, download, dispose} = tensor[2]; - return Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose}); + const { gpuBuffer, download, dispose } = tensor[2]; + return Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims: tensor[1], download, dispose }); } default: throw new Error(`invalid data location: ${tensor[3]}`); @@ -48,7 +55,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan return copyFromExternalBuffer(await loadFile(path)); } - async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { + async loadModel(pathOrBuffer: string | Uint8Array, options?: InferenceSession.SessionOptions): Promise { TRACE_FUNC_BEGIN(); let model: Parameters[0]; @@ -73,12 +80,15 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan return releaseSession(this.sessionId); } - async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): - Promise { + async run( + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise { TRACE_FUNC_BEGIN(); const inputArray: Tensor[] = []; const inputIndices: number[] = []; - Object.entries(feeds).forEach(kvp => { + Object.entries(feeds).forEach((kvp) => { const name = kvp[0]; const tensor = kvp[1]; const index = this.inputNames.indexOf(name); @@ -89,9 +99,9 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan inputIndices.push(index); }); - const 
outputArray: Array = []; + const outputArray: Array = []; const outputIndices: number[] = []; - Object.entries(fetches).forEach(kvp => { + Object.entries(fetches).forEach((kvp) => { const name = kvp[0]; const tensor = kvp[1]; const index = this.outputNames.indexOf(name); @@ -102,10 +112,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan outputIndices.push(index); }); - const inputs = - inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); - const outputs = outputArray.map( - (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + const inputs = inputArray.map((t, i) => + encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`), + ); + const outputs = outputArray.map((t, i) => + t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null, + ); const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index e35759192fe3c..8bbfb9cf06668 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,12 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; - -import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; -import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; -import {copyFromExternalBuffer} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, lazyResetGrad, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl'; +import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common'; + +import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; +import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference'; +import { copyFromExternalBuffer } from './wasm-core-impl'; +import { + createCheckpointHandle, + createTrainingSessionHandle, + getContiguousParameters, + getModelInputOutputNames, + getParametersSize, + lazyResetGrad, + loadParametersBuffer, + releaseTrainingSessionAndCheckpoint, + runEvalStep, + runOptimizerStep, + runTrainStep, +} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { private sessionId: number; @@ -18,7 +30,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes evalInputNames: string[] = []; evalOutputNames: string[] = []; - async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise { let buffer: Uint8Array; if (typeof uriOrBuffer === 'string') { const response = await fetch(uriOrBuffer); @@ -31,9 +43,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async createTrainingSession( - checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, - evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, - options: 
InferenceSession.SessionOptions) { + checkpointStateUriOrBuffer: string | Uint8Array, + trainModelUriOrBuffer: string | Uint8Array, + evalModelUriOrBuffer: string | Uint8Array, + optimizerModelUriOrBuffer: string | Uint8Array, + options: InferenceSession.SessionOptions, + ) { const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); // 0 is supposed to be the nullptr @@ -48,8 +63,13 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } this.checkpointId = createCheckpointHandle(checkpointData); - this.sessionId = - createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + this.sessionId = createTrainingSessionHandle( + this.checkpointId, + trainModelData, + evalModelData, + optimizerModelData, + options, + ); [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); if (evalModelUriOrBuffer !== '') { [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); @@ -65,10 +85,13 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes * @returns a tuple of a list of values and a list of indices. */ convertMapIntoValuesArrayAndIndicesArray( - feeds: {[name: string]: T}, names: string[], mapFunc: (val: T, index: number) => U): [T[], number[], U[]] { + feeds: { [name: string]: T }, + names: string[], + mapFunc: (val: T, index: number) => U, + ): [T[], number[], U[]] { const values: T[] = []; const indices: number[] = []; - Object.entries(feeds).forEach(kvp => { + Object.entries(feeds).forEach((kvp) => { const name = kvp[0]; const tensor = kvp[1]; const index = names.indexOf(name); @@ -94,7 +117,10 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes * @returns a map of output names and OnnxValues. */ convertTensorMetadataToReturnType( - results: TensorMetadata[], outputArray: Array, outputIndices: number[]): SessionHandler.ReturnType { + results: TensorMetadata[], + outputArray: Array, + outputIndices: number[], + ): SessionHandler.ReturnType { const resultMap: SessionHandler.ReturnType = {}; for (let i = 0; i < results.length; i++) { resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); @@ -107,17 +133,22 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async runTrainStep( - feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions): Promise { + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise { const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, this.inputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); - - const [outputArray, outputIndices, outputs] = - this.convertMapIntoValuesArrayAndIndicesArray( - fetches, this.outputNames, - (t, i): TensorMetadata|null => - t ? 
encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + feeds, + this.inputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`), + ); + + const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< + Tensor | null, + TensorMetadata | null + >(fetches, this.outputNames, (t, i): TensorMetadata | null => + t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null, + ); const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); @@ -128,17 +159,22 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async runEvalStep( - feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions): Promise { + feeds: SessionHandler.FeedsType, + fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions, + ): Promise { const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, this.evalInputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`)); - - const [outputArray, outputIndices, outputs] = - this.convertMapIntoValuesArrayAndIndicesArray( - fetches, this.evalOutputNames, - (t, i): TensorMetadata|null => - t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null); + feeds, + this.evalInputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`), + ); + + const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< + Tensor | null, + TensorMetadata | null + >(fetches, this.evalOutputNames, (t, i): TensorMetadata | null => + t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null, + ); const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index f289fc20bba40..b2594267a595a 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {InferenceSession} from 'onnxruntime-common'; +import { InferenceSession } from 'onnxruntime-common'; -import {getInstance} from './wasm-factory'; -import {allocWasmString, checkLastError, iterateExtraOptions} from './wasm-utils'; +import { getInstance } from './wasm-factory'; +import { allocWasmString, checkLastError, iterateExtraOptions } from './wasm-utils'; -const getGraphOptimzationLevel = (graphOptimizationLevel: string|unknown): number => { +const getGraphOptimzationLevel = (graphOptimizationLevel: string | unknown): number => { switch (graphOptimizationLevel) { case 'disabled': return 0; @@ -21,7 +21,7 @@ const getGraphOptimzationLevel = (graphOptimizationLevel: string|unknown): numbe } }; -const getExecutionMode = (executionMode: 'sequential'|'parallel'): number => { +const getExecutionMode = (executionMode: 'sequential' | 'parallel'): number => { switch (executionMode) { case 'sequential': return 0; @@ -46,67 +46,68 @@ const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => } // if using JSEP with WebGPU, always disable memory pattern - if (options.executionProviders && - options.executionProviders.some(ep => (typeof ep === 'string' ? ep : ep.name) === 'webgpu')) { + if ( + options.executionProviders && + options.executionProviders.some((ep) => (typeof ep === 'string' ? ep : ep.name) === 'webgpu') + ) { options.enableMemPattern = false; } }; -const setExecutionProviders = - (sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[], - allocs: number[]): void => { - for (const ep of executionProviders) { - let epName = typeof ep === 'string' ? ep : ep.name; - - // check EP name - switch (epName) { - case 'webnn': - epName = 'WEBNN'; - if (typeof ep !== 'string') { - const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; - // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; - const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; - if (deviceType) { - const keyDataOffset = allocWasmString('deviceType', allocs); - const valueDataOffset = allocWasmString(deviceType, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); - } - } +const setExecutionProviders = ( + sessionOptionsHandle: number, + executionProviders: readonly InferenceSession.ExecutionProviderConfig[], + allocs: number[], +): void => { + for (const ep of executionProviders) { + let epName = typeof ep === 'string' ? 
ep : ep.name; + + // check EP name + switch (epName) { + case 'webnn': + epName = 'WEBNN'; + if (typeof ep !== 'string') { + const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; + // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; + const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; + if (deviceType) { + const keyDataOffset = allocWasmString('deviceType', allocs); + const valueDataOffset = allocWasmString(deviceType, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); } - break; - case 'webgpu': - epName = 'JS'; - if (typeof ep !== 'string') { - const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; - if (webgpuOptions?.preferredLayout) { - if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { - throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); - } - const keyDataOffset = allocWasmString('preferredLayout', allocs); - const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError( - `Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); - } - } + } + } + break; + case 'webgpu': + epName = 'JS'; + if (typeof ep !== 'string') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + if (webgpuOptions?.preferredLayout) { + if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { + throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + } + const keyDataOffset = allocWasmString('preferredLayout', allocs); + const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError(`Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); } - break; - case 'wasm': - case 'cpu': - continue; - default: - throw new Error(`not supported execution provider: ${epName}`); + } } + break; + case 'wasm': + case 'cpu': + continue; + default: + throw new Error(`not supported execution provider: ${epName}`); + } - const epNameDataOffset = allocWasmString(epName, allocs); - if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) { - checkLastError(`Can't append execution provider: ${epName}.`); - } - } - }; + const epNameDataOffset = allocWasmString(epName, allocs); + if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) { + checkLastError(`Can't append execution provider: ${epName}.`); + } + } +}; export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { const wasm = getInstance(); @@ -120,28 +121,37 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel ?? 'all'); const executionMode = getExecutionMode(sessionOptions.executionMode ?? 'sequential'); const logIdDataOffset = - typeof sessionOptions.logId === 'string' ? 
allocWasmString(sessionOptions.logId, allocs) : 0; + typeof sessionOptions.logId === 'string' ? allocWasmString(sessionOptions.logId, allocs) : 0; - const logSeverityLevel = sessionOptions.logSeverityLevel ?? 2; // Default to 2 - warning + const logSeverityLevel = sessionOptions.logSeverityLevel ?? 2; // Default to 2 - warning if (!Number.isInteger(logSeverityLevel) || logSeverityLevel < 0 || logSeverityLevel > 4) { throw new Error(`log serverity level is not valid: ${logSeverityLevel}`); } - const logVerbosityLevel = sessionOptions.logVerbosityLevel ?? 0; // Default to 0 - verbose + const logVerbosityLevel = sessionOptions.logVerbosityLevel ?? 0; // Default to 0 - verbose if (!Number.isInteger(logVerbosityLevel) || logVerbosityLevel < 0 || logVerbosityLevel > 4) { throw new Error(`log verbosity level is not valid: ${logVerbosityLevel}`); } - const optimizedModelFilePathOffset = typeof sessionOptions.optimizedModelFilePath === 'string' ? - allocWasmString(sessionOptions.optimizedModelFilePath, allocs) : - 0; + const optimizedModelFilePathOffset = + typeof sessionOptions.optimizedModelFilePath === 'string' + ? allocWasmString(sessionOptions.optimizedModelFilePath, allocs) + : 0; sessionOptionsHandle = wasm._OrtCreateSessionOptions( - graphOptimizationLevel, !!sessionOptions.enableCpuMemArena, !!sessionOptions.enableMemPattern, executionMode, - !!sessionOptions.enableProfiling, 0, logIdDataOffset, logSeverityLevel, logVerbosityLevel, - optimizedModelFilePathOffset); + graphOptimizationLevel, + !!sessionOptions.enableCpuMemArena, + !!sessionOptions.enableMemPattern, + executionMode, + !!sessionOptions.enableProfiling, + 0, + logIdDataOffset, + logSeverityLevel, + logVerbosityLevel, + optimizedModelFilePathOffset, + ); if (sessionOptionsHandle === 0) { - checkLastError('Can\'t create session options.'); + checkLastError("Can't create session options."); } if (sessionOptions.executionProviders) { @@ -156,7 +166,8 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { checkLastError( - `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`, + ); } } @@ -191,7 +202,7 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n if (sessionOptionsHandle !== 0) { wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } - allocs.forEach(alloc => wasm._free(alloc)); + allocs.forEach((alloc) => wasm._free(alloc)); throw e; } }; diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 54eaf5e0c43cc..1ef0630d04c8a 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from 'onnxruntime-common'; +import { Tensor } from 'onnxruntime-common'; // a dummy type declaration for Float16Array in case any polyfill is available. 
declare global { @@ -31,7 +31,7 @@ export const enum DataType { uint64 = 13, complex64 = 14, complex128 = 15, - bfloat16 = 16 + bfloat16 = 16, } /** @@ -112,50 +112,61 @@ export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type => * get tensor element size in bytes by the given data type * @returns size in integer or undefined if the data type is not supported */ -export const getTensorElementSize = (dateType: number): number| - undefined => [undefined, 4, 1, 1, 2, 2, 4, 8, undefined, 1, 2, 8, 4, 8, undefined, undefined, undefined][dateType]; +export const getTensorElementSize = (dateType: number): number | undefined => + [undefined, 4, 1, 1, 2, 2, 4, 8, undefined, 1, 2, 8, 4, 8, undefined, undefined, undefined][dateType]; /** * get typed array constructor by the given tensor type */ -export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32ArrayConstructor|Uint8ArrayConstructor| - Int8ArrayConstructor|Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor| - Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { - switch (type) { - case 'float16': - // allow Float16Array polyfill. - return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array; - case 'float32': - return Float32Array; - case 'uint8': - return Uint8Array; - case 'int8': - return Int8Array; - case 'uint16': - return Uint16Array; - case 'int16': - return Int16Array; - case 'int32': - return Int32Array; - case 'bool': - return Uint8Array; - case 'float64': - return Float64Array; - case 'uint32': - return Uint32Array; - case 'int64': - return BigInt64Array; - case 'uint64': - return BigUint64Array; - default: - throw new Error(`unsupported type: ${type}`); - } - }; +export const tensorTypeToTypedArrayConstructor = ( + type: Tensor.Type, +): + | Float32ArrayConstructor + | Uint8ArrayConstructor + | Int8ArrayConstructor + | Uint16ArrayConstructor + | Int16ArrayConstructor + | Int32ArrayConstructor + | BigInt64ArrayConstructor + | Uint8ArrayConstructor + | Float64ArrayConstructor + | Uint32ArrayConstructor + | BigUint64ArrayConstructor => { + switch (type) { + case 'float16': + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? 
Float16Array : Uint16Array; + case 'float32': + return Float32Array; + case 'uint8': + return Uint8Array; + case 'int8': + return Int8Array; + case 'uint16': + return Uint16Array; + case 'int16': + return Int16Array; + case 'int32': + return Int32Array; + case 'bool': + return Uint8Array; + case 'float64': + return Float64Array; + case 'uint32': + return Uint32Array; + case 'int64': + return BigInt64Array; + case 'uint64': + return BigUint64Array; + default: + throw new Error(`unsupported type: ${type}`); + } +}; /** * Map string log level to integer value */ -export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'): number => { +export const logLevelStringToEnum = (logLevel?: 'verbose' | 'info' | 'warning' | 'error' | 'fatal'): number => { switch (logLevel) { case 'verbose': return 0; @@ -175,9 +186,14 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro /** * Check whether the given tensor type is supported by GPU buffer */ -export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || - type === 'bool'; +export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => + type === 'float32' || + type === 'float16' || + type === 'int32' || + type === 'int64' || + type === 'uint32' || + type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value @@ -202,5 +218,5 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number /** * Map integer data location to string value */ -export const dataLocationEnumToString = (location: number): Tensor.DataLocation|undefined => - (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; +export const dataLocationEnumToString = (location: number): Tensor.DataLocation | undefined => + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 9fc8786192c5c..8f72a8fcda1c3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -6,15 +6,28 @@ // https://github.com/webmachinelearning/webnn/issues/677 /// -import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; - -import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; -import {setRunOptions} from './run-options'; -import {setSessionOptions} from './session-options'; -import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; -import {getInstance} from './wasm-factory'; -import {allocWasmString, checkLastError} from './wasm-utils'; -import {loadFile} from './wasm-utils-load-file'; +import { Env, InferenceSession, Tensor } from 'onnxruntime-common'; + +import { + SerializableInternalBuffer, + SerializableSessionMetadata, + SerializableTensorMetadata, + TensorMetadata, +} from './proxy-messages'; +import { setRunOptions } from './run-options'; +import { setSessionOptions } from './session-options'; +import { + dataLocationStringToEnum, + getTensorElementSize, + isGpuBufferSupportedType, + logLevelStringToEnum, + tensorDataTypeEnumToString, + tensorDataTypeStringToEnum, + tensorTypeToTypedArrayConstructor, +} 
from './wasm-common'; +import { getInstance } from './wasm-factory'; +import { allocWasmString, checkLastError } from './wasm-utils'; +import { loadFile } from './wasm-utils-load-file'; // #region Initializations @@ -69,7 +82,7 @@ import {loadFile} from './wasm-utils-load-file'; const initOrt = (numThreads: number, loggingLevel: number): void => { const errorCode = getInstance()._OrtInit(numThreads, loggingLevel); if (errorCode !== 0) { - checkLastError('Can\'t initialize onnxruntime.'); + checkLastError("Can't initialize onnxruntime."); } }; @@ -77,7 +90,7 @@ const initOrt = (numThreads: number, loggingLevel: number): void => { * initialize runtime environment. * @param env passed in the environment config object. */ -export const initRuntime = async(env: Env): Promise => { +export const initRuntime = async (env: Env): Promise => { // init ORT initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel)); }; @@ -88,7 +101,7 @@ export const initRuntime = async(env: Env): Promise => { * @param env * @param epName */ -export const initEp = async(env: Env, epName: string): Promise => { +export const initEp = async (env: Env, epName: string): Promise => { if (!BUILD_DEFS.DISABLE_JSEP) { // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; @@ -103,24 +116,31 @@ export const initEp = async(env: Env, epName: string): Promise => { if (!adapter) { // if adapter is not set, request a new adapter. const powerPreference = env.webgpu.powerPreference; - if (powerPreference !== undefined && powerPreference !== 'low-power' && - powerPreference !== 'high-performance') { + if ( + powerPreference !== undefined && + powerPreference !== 'low-power' && + powerPreference !== 'high-performance' + ) { throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); } const forceFallbackAdapter = env.webgpu.forceFallbackAdapter; if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); } - adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + adapter = await navigator.gpu.requestAdapter({ powerPreference, forceFallbackAdapter }); if (!adapter) { throw new Error( - 'Failed to get GPU adapter. ' + - 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + 'Failed to get GPU adapter. ' + + 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.', + ); } } else { // if adapter is set, validate it. - if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' || - typeof adapter.requestDevice !== 'function') { + if ( + typeof adapter.limits !== 'object' || + typeof adapter.features !== 'object' || + typeof adapter.requestDevice !== 'function' + ) { throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.'); } } @@ -129,7 +149,7 @@ export const initEp = async(env: Env, epName: string): Promise => { } if (epName === 'webnn') { // perform WebNN availability check - if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) { + if (typeof navigator === 'undefined' || !(navigator as unknown as { ml: unknown }).ml) { throw new Error('WebNN is not supported in current environment'); } @@ -143,7 +163,7 @@ export const initEp = async(env: Env, epName: string): Promise => { /** * valid data locations for input/output tensors. 
*/ -type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer'; +type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer'; type IOBindingState = { /** @@ -168,8 +188,12 @@ type IOBindingState = { * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState */ type SessionMetadata = [ - inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean + inferenceSessionId: number, + inputNamesUTF8Encoded: number[], + outputNamesUTF8Encoded: number[], + bindingState: IOBindingState | null, + enableGraphCapture: boolean, + inputOutputBound: boolean, ]; const activeSessions = new Map(); @@ -186,7 +210,7 @@ const getSessionInputOutputCount = (sessionHandle: number): [number, number] => const dataOffset = wasm.stackAlloc(8); const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); + checkLastError("Can't get session input/output count."); } return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } finally { @@ -218,9 +242,10 @@ export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => { * @param options an optional session options object. * @returns a 3-elements tuple containing [session handle, input names, output names] */ -export const createSession = async( - modelData: Uint8Array|SerializableInternalBuffer, - options?: InferenceSession.SessionOptions): Promise => { +export const createSession = async ( + modelData: Uint8Array | SerializableInternalBuffer, + options?: InferenceSession.SessionOptions, +): Promise => { let modelDataOffset: number, modelDataLength: number; const wasm = getInstance(); @@ -249,9 +274,11 @@ export const createSession = async( const loadingPromises = []; for (const file of options.externalData) { const path = typeof file === 'string' ? file : file.path; - loadingPromises.push(loadFile(typeof file === 'string' ? file : file.data).then(data => { - wasm.mountExternalData!(path, data); - })); + loadingPromises.push( + loadFile(typeof file === 'string' ? 
file : file.data).then((data) => { + wasm.mountExternalData!(path, data); + }), + ); } // wait for all external data files to be loaded @@ -276,7 +303,7 @@ export const createSession = async( } else if (gpuDevice) { wasm.currentContext = await navigator.ml.createContext(gpuDevice); } else { - wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference}); + wasm.currentContext = await navigator.ml.createContext({ deviceType, numThreads, powerPreference }); } } else { wasm.currentContext = await navigator.ml.createContext(); @@ -287,7 +314,7 @@ export const createSession = async( sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { - checkLastError('Can\'t create a session.'); + checkLastError("Can't create a session."); } // clear current MLContext after session creation @@ -305,7 +332,7 @@ export const createSession = async( for (let i = 0; i < inputCount; i++) { const name = wasm._OrtGetInputName(sessionHandle, i); if (name === 0) { - checkLastError('Can\'t get an input name.'); + checkLastError("Can't get an input name."); } inputNamesUTF8Encoded.push(name); inputNames.push(wasm.UTF8ToString(name)); @@ -313,7 +340,7 @@ export const createSession = async( for (let i = 0; i < outputCount; i++) { const name = wasm._OrtGetOutputName(sessionHandle, i); if (name === 0) { - checkLastError('Can\'t get an output name.'); + checkLastError("Can't get an output name."); } outputNamesUTF8Encoded.push(name); const nameString = wasm.UTF8ToString(name); @@ -324,42 +351,51 @@ export const createSession = async( outputPreferredLocations.push('gpu-buffer'); continue; } - const location = typeof options?.preferredOutputLocation === 'string' ? - options.preferredOutputLocation : - options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + const location = + typeof options?.preferredOutputLocation === 'string' + ? options.preferredOutputLocation + : (options?.preferredOutputLocation?.[nameString] ?? 'cpu'); if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } if (enableGraphCapture && location !== 'gpu-buffer') { - throw new Error(`Not supported preferred output location: ${ - location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); + throw new Error( + `Not supported preferred output location: ${ + location + }. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`, + ); } outputPreferredLocations.push(location); } } // use IO binding only when at least one output is preffered to be on GPU. 
- let bindingState: IOBindingState|null = null; - if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some(l => l === 'gpu-buffer')) { + let bindingState: IOBindingState | null = null; + if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer')) { ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); if (ioBindingHandle === 0) { - checkLastError('Can\'t create IO binding.'); + checkLastError("Can't create IO binding."); } bindingState = { handle: ioBindingHandle, outputPreferredLocations, - outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), + outputPreferredLocationsEncoded: outputPreferredLocations.map((l) => dataLocationStringToEnum(l)), }; } - activeSessions.set( - sessionHandle, - [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); + activeSessions.set(sessionHandle, [ + sessionHandle, + inputNamesUTF8Encoded, + outputNamesUTF8Encoded, + bindingState, + enableGraphCapture, + false, + ]); return [sessionHandle, inputNames, outputNames]; } catch (e) { - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); if (ioBindingHandle !== 0) { wasm._OrtReleaseBinding(ioBindingHandle); @@ -374,7 +410,7 @@ export const createSession = async( if (sessionOptionsHandle !== 0) { wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } - allocs.forEach(alloc => wasm._free(alloc)); + allocs.forEach((alloc) => wasm._free(alloc)); // unmount external data if necessary wasm.unmountExternalData?.(); @@ -398,94 +434,110 @@ export const releaseSession = (sessionId: number): void => { wasm.jsepOnReleaseSession?.(sessionId); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); wasm._OrtReleaseSession(sessionHandle); activeSessions.delete(sessionId); }; -export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, - enableGraphCapture = false): void => { - if (!tensor) { - tensorHandles.push(0); - return; - } +export const prepareInputOutputTensor = ( + tensor: TensorMetadata | null, + tensorHandles: number[], + allocs: number[], + sessionId: number, + index: number, + enableGraphCapture = false, +): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (enableGraphCapture && location !== 'gpu-buffer') { - throw new Error( - `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); - } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error( + `External buffer 
must be provided for input/output index ${index} when enableGraphCapture is true.`, + ); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - const registerBuffer = wasm.jsepRegisterBuffer; - if (!registerBuffer) { - throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); - } - rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + const registerBuffer = wasm.jsepRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach((d) => (wasm.HEAP32[dimIndex++] = d)); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), + rawData, + dataByteLength, + dimsOffset, + dims.length, + dataLocationStringToEnum(location), + ); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. 
session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } +}; /** * perform inference run */ -export const run = async( - sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], - outputTensors: Array, options: InferenceSession.RunOptions): Promise => { +export const run = async ( + sessionId: number, + inputIndices: number[], + inputTensors: TensorMetadata[], + outputIndices: number[], + outputTensors: Array, + options: InferenceSession.RunOptions, +): Promise => { const wasm = getInstance(); const session = activeSessions.get(sessionId); if (!session) { @@ -520,14 +572,25 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { prepareInputOutputTensor( - inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); + inputTensors[i], + inputTensorHandles, + inputOutputAllocs, + sessionId, + inputIndices[i], + enableGraphCapture, + ); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], - enableGraphCapture); + outputTensors[i], + outputTensorHandles, + inputOutputAllocs, + sessionId, + inputCount + outputIndices[i], + enableGraphCapture, + ); } let inputValuesIndex = inputValuesOffset / 4; @@ -544,11 +607,14 @@ export const run = async( } if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) { - const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + const { handle, outputPreferredLocations, outputPreferredLocationsEncoded } = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { - throw new Error(`input count from feeds (${ - inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); + throw new Error( + `input count from feeds (${ + inputCount + }) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`, + ); } // process inputs @@ -563,7 +629,7 @@ export const run = async( // process pre-allocated outputs for (let i = 0; i < outputCount; i++) { const index = outputIndices[i]; - const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. if (location) { // output is pre-allocated. bind the tensor. @@ -573,27 +639,48 @@ export const run = async( } } else { // output is not pre-allocated. reset preferred location. 
- const errorCode = - wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + const errorCode = wasm._OrtBindOutput( + handle, + outputNamesUTF8Encoded[index], + 0, + outputPreferredLocationsEncoded[index], + ); if (errorCode !== 0) { checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); } } } - activeSessions.set( - sessionId, - [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); + activeSessions.set(sessionId, [ + sessionHandle, + inputNamesUTF8Encoded, + outputNamesUTF8Encoded, + ioBindingState, + enableGraphCapture, + true, + ]); } wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( - sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); + sessionHandle, + ioBindingState.handle, + outputCount, + outputValuesOffset, + runOptionsHandle, + ); } else { errorCode = await wasm._OrtRun( - sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, - outputValuesOffset, runOptionsHandle); + sessionHandle, + inputNamesOffset, + inputValuesOffset, + inputCount, + outputNamesOffset, + outputCount, + outputValuesOffset, + runOptionsHandle, + ); } if (errorCode !== 0) { @@ -615,10 +702,16 @@ export const run = async( const tensorDataOffset = wasm.stackAlloc(4 * 4); let keepOutputTensor = false; - let type: Tensor.Type|undefined, dataOffset = 0; + let type: Tensor.Type | undefined, + dataOffset = 0; try { const errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + tensor, + tensorDataOffset, + tensorDataOffset + 4, + tensorDataOffset + 8, + tensorDataOffset + 12, + ); if (errorCode !== 0) { checkLastError(`Can't access output tensor data on index ${i}.`); } @@ -668,20 +761,23 @@ export const run = async( keepOutputTensor = true; output.push([ - type, dims, { + type, + dims, + { gpuBuffer, download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type), dispose: () => { wasm._OrtReleaseTensor(tensor); - } + }, }, - 'gpu-buffer' + 'gpu-buffer', ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( + wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength), + ); output.push([type, dims, data, 'cpu']); } } @@ -698,22 +794,27 @@ export const run = async( if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); - activeSessions.set( - sessionId, - [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); + activeSessions.set(sessionId, [ + sessionHandle, + inputNamesUTF8Encoded, + outputNamesUTF8Encoded, + ioBindingState, + enableGraphCapture, + false, + ]); } return output; } finally { wasm.stackRestore(beforeRunStack); - inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach(p => wasm._free(p)); + inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); + 
inputOutputAllocs.forEach((p) => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); } - runOptionsAllocs.forEach(p => wasm._free(p)); + runOptionsAllocs.forEach((p) => wasm._free(p)); } }; @@ -731,7 +832,7 @@ export const endProfiling = (sessionId: number): void => { // profile file name is not used yet, but it must be freed. const profileFileName = wasm._OrtEndProfiling(sessionHandle); if (profileFileName === 0) { - checkLastError('Can\'t get an profile file name.'); + checkLastError("Can't get an profile file name."); } wasm._OrtFree(profileFileName); }; diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 0f5f10716a00b..316adf6706074 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env} from 'onnxruntime-common'; +import { Env } from 'onnxruntime-common'; -import type {OrtWasmModule} from './wasm-types'; -import {importWasmModule} from './wasm-utils-import'; +import type { OrtWasmModule } from './wasm-types'; +import { importWasmModule } from './wasm-utils-import'; -let wasm: OrtWasmModule|undefined; +let wasm: OrtWasmModule | undefined; let initialized = false; let initializing = false; let aborted = false; @@ -26,10 +26,12 @@ const isMultiThreadSupported = (): boolean => { // Test for WebAssembly threads capability (for both browsers and Node.js) // This typed array is a WebAssembly program containing threaded instructions. - return WebAssembly.validate(new Uint8Array([ - 0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 5, - 4, 1, 3, 1, 1, 10, 11, 1, 9, 0, 65, 0, 254, 16, 2, 0, 26, 11 - ])); + return WebAssembly.validate( + new Uint8Array([ + 0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 5, 4, 1, 3, 1, 1, 10, 11, 1, 9, 0, 65, 0, 254, 16, + 2, 0, 26, 11, + ]), + ); } catch (e) { return false; } @@ -51,24 +53,26 @@ const isSimdSupported = (): boolean => { // (i32.const 0)) // (v128.const i32x4 0x00000000 0x00000000 0x00000000 0x00000000))))) - return WebAssembly.validate(new Uint8Array([ - 0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 10, 30, 1, 28, 0, 65, 0, - 253, 15, 253, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 253, 186, 1, 26, 11 - ])); + return WebAssembly.validate( + new Uint8Array([ + 0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 10, 30, 1, 28, 0, 65, 0, 253, 15, 253, 12, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 253, 186, 1, 26, 11, + ]), + ); } catch (e) { return false; } }; -export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise => { +export const initializeWebAssembly = async (flags: Env.WebAssemblyFlags): Promise => { if (initialized) { return Promise.resolve(); } if (initializing) { - throw new Error('multiple calls to \'initializeWebAssembly()\' detected.'); + throw new Error("multiple calls to 'initializeWebAssembly()' detected."); } if (aborted) { - throw new Error('previous call to \'initializeWebAssembly()\' failed.'); + throw new Error("previous call to 'initializeWebAssembly()' failed."); } initializing = true; @@ -88,15 +92,17 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise if (typeof self !== 'undefined' && !self.crossOriginIsolated) { // eslint-disable-next-line no-console console.warn( - 'env.wasm.numThreads is set to ' + numThreads + + 'env.wasm.numThreads is set to ' + + numThreads + ', but this will 
not work unless you enable crossOriginIsolated mode. ' + - 'See https://web.dev/cross-origin-isolation-guide/ for more info.'); + 'See https://web.dev/cross-origin-isolation-guide/ for more info.', + ); } // eslint-disable-next-line no-console console.warn( - 'WebAssembly multi-threading is not supported in the current environment. ' + - 'Falling back to single-threading.'); + 'WebAssembly multi-threading is not supported in the current environment. ' + 'Falling back to single-threading.', + ); // set flags.numThreads to 1 so that OrtInit() will not create a global thread pool. flags.numThreads = numThreads = 1; @@ -110,7 +116,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? wasmPathOverrideFlag; const wasmBinaryOverride = flags.wasmBinary; - const [objectUrl, ortWasmFactory] = (await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1)); + const [objectUrl, ortWasmFactory] = await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1); let isTimeout = false; @@ -118,42 +124,45 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise // promise for timeout if (timeout > 0) { - tasks.push(new Promise((resolve) => { - setTimeout(() => { - isTimeout = true; - resolve(); - }, timeout); - })); + tasks.push( + new Promise((resolve) => { + setTimeout(() => { + isTimeout = true; + resolve(); + }, timeout); + }), + ); } // promise for module initialization - tasks.push(new Promise((resolve, reject) => { - const config: Partial = { - /** - * The number of threads. WebAssembly will create (Module.numThreads - 1) workers. If it is 1, no worker will be - * created. - */ - numThreads, - }; - - if (wasmBinaryOverride) { - /** - * Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching. - */ - config.wasmBinary = wasmBinaryOverride; - } else if (wasmPathOverride || wasmPrefixOverride) { - /** - * A callback function to locate the WebAssembly file. The function should return the full path of the file. - * - * Since Emscripten 3.1.58, this function is only called for the .wasm file. - */ - config.locateFile = (fileName, scriptDirectory) => + tasks.push( + new Promise((resolve, reject) => { + const config: Partial = { + /** + * The number of threads. WebAssembly will create (Module.numThreads - 1) workers. If it is 1, no worker will be + * created. + */ + numThreads, + }; + + if (wasmBinaryOverride) { + /** + * Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching. + */ + config.wasmBinary = wasmBinaryOverride; + } else if (wasmPathOverride || wasmPrefixOverride) { + /** + * A callback function to locate the WebAssembly file. The function should return the full path of the file. + * + * Since Emscripten 3.1.58, this function is only called for the .wasm file. + */ + config.locateFile = (fileName, scriptDirectory) => wasmPathOverride ?? (wasmPrefixOverride ?? 
scriptDirectory) + fileName; - } + } - ortWasmFactory(config).then( + ortWasmFactory(config).then( // wasm module initialized successfully - module => { + (module) => { initializing = false; initialized = true; wasm = module; @@ -167,8 +176,10 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise initializing = false; aborted = true; reject(what); - }); - })); + }, + ); + }), + ); await Promise.race(tasks); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index c65178e2358d2..22cd6ec30732c 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -1,20 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession, Tensor} from 'onnxruntime-common'; - -import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; -import {setRunOptions} from './run-options'; -import {setSessionOptions} from './session-options'; -import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; -import {prepareInputOutputTensor} from './wasm-core-impl'; -import {getInstance} from './wasm-factory'; -import {checkLastError} from './wasm-utils'; +import { InferenceSession, Tensor } from 'onnxruntime-common'; + +import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; +import { setRunOptions } from './run-options'; +import { setSessionOptions } from './session-options'; +import { + dataLocationStringToEnum, + tensorDataTypeEnumToString, + tensorDataTypeStringToEnum, + tensorTypeToTypedArrayConstructor, +} from './wasm-common'; +import { prepareInputOutputTensor } from './wasm-core-impl'; +import { getInstance } from './wasm-factory'; +import { checkLastError } from './wasm-utils'; const NO_TRAIN_FUNCS_MSG = - 'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' + - 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + - 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; + "Built without training API's enabled. Use the onnxruntime-web/training import for training " + + 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + + 'using a custom build. 
Check https://onnxruntime.ai/docs/build/web.html for more information.'; /** * Runs the checkLastError function which will throw an error, if the provided error code matches the specified @@ -64,9 +69,13 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea try { const dataOffset = wasm.stackAlloc(8); if (wasm._OrtTrainingGetModelInputOutputCount) { - const errorCode = - wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); - ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); + const errorCode = wasm._OrtTrainingGetModelInputOutputCount( + trainingSessionId, + dataOffset, + dataOffset + 4, + isEvalModel, + ); + ifErrCodeCheckLastError(errorCode, "Can't get session input/output count."); return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); @@ -76,24 +85,28 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea } }; -const getModelInputOutputNamesLoop = - (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => { - const names = []; - const wasm = getInstance(); +const getModelInputOutputNamesLoop = ( + trainingSessionId: number, + count: number, + isInput: boolean, + isEvalModel: boolean, +): string[] => { + const names = []; + const wasm = getInstance(); - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetModelInputOutputName) { - const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetModelInputOutputName) { + const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); + ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - names.push(wasm.UTF8ToString(name)); - wasm._free(name); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return names; - }; + names.push(wasm.UTF8ToString(name)); + wasm._free(name); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return names; +}; export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { let inputNames: string[] = []; @@ -107,43 +120,54 @@ export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: return [inputNames, outputNames]; }; -export const createTrainingSessionHandle = - (checkpointHandle: number, trainModelData: SerializableInternalBuffer, evalModelData: SerializableInternalBuffer, - optimizerModelData: SerializableInternalBuffer, options: InferenceSession.SessionOptions): number => { - const wasm = getInstance(); - - let trainingSessionHandle = 0; - let sessionOptionsHandle = 0; - let allocs: number[] = []; - - try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); - if (wasm._OrtTrainingCreateSession) { - trainingSessionHandle = wasm._OrtTrainingCreateSession( - sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], - evalModelData[1], optimizerModelData[0], optimizerModelData[1]); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } +export const createTrainingSessionHandle = ( + checkpointHandle: number, + trainModelData: SerializableInternalBuffer, + evalModelData: SerializableInternalBuffer, + optimizerModelData: 
SerializableInternalBuffer, + options: InferenceSession.SessionOptions, +): number => { + const wasm = getInstance(); - ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - return trainingSessionHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { - wasm._OrtTrainingReleaseSession(trainingSessionHandle); - } - throw e; - } finally { - wasm._free(trainModelData[0]); - wasm._free(evalModelData[0]); - wasm._free(optimizerModelData[0]); - - if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); - } - allocs.forEach(alloc => wasm._free(alloc)); - } - }; + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: number[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, + checkpointHandle, + trainModelData[0], + trainModelData[1], + evalModelData[0], + evalModelData[1], + optimizerModelData[0], + optimizerModelData[1], + ); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); + return trainingSessionHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach((alloc) => wasm._free(alloc)); + } +}; /** * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the @@ -157,27 +181,31 @@ export const createTrainingSessionHandle = * @param inputOutputAllocs modified in-place by this method * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor */ -const createAndAllocateTensors = - (trainingSessionId: number, indices: number[], tensors: Array, tensorHandles: number[], - inputOutputAllocs: number[], indexAdd: number) => { - const count = indices.length; - - // creates the tensors - for (let i = 0; i < count; i++) { - prepareInputOutputTensor( - tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); - } +const createAndAllocateTensors = ( + trainingSessionId: number, + indices: number[], + tensors: Array, + tensorHandles: number[], + inputOutputAllocs: number[], + indexAdd: number, +) => { + const count = indices.length; + + // creates the tensors + for (let i = 0; i < count; i++) { + prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); + } - // moves to heap - const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; - for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; - } + // moves to heap + const wasm = getInstance(); + const valuesOffset = wasm.stackAlloc(count * 4); + let valuesIndex = valuesOffset / 4; + for (let i = 0; i < count; i++) { + wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + } - return valuesOffset; - }; + return valuesOffset; +}; /** * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information @@ -187,86 
+215,101 @@ const createAndAllocateTensors = * @param outputCount * @returns list of TensorMetadata retrieved from the output handles. */ -const moveOutputToTensorMetadataArr = - (outputValuesOffset: number, outputCount: number, outputTensorHandles: number[], - outputTensors: Array) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - if (tensor === outputTensorHandles[i]) { - // output tensor is pre-allocated. no need to copy data. - output.push(outputTensors[i]!); - continue; - } +const moveOutputToTensorMetadataArr = ( + outputValuesOffset: number, + outputCount: number, + outputTensorHandles: number[], + outputTensors: Array, +) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; + } - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let type: Tensor.Type|undefined, dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? 
undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - wasm._OrtReleaseTensor(tensor); + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + let type: Tensor.Type | undefined, + dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, + tensorDataOffset, + tensorDataOffset + 4, + tensorDataOffset + 8, + tensorDataOffset + 12, + ); + ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); + + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? 
undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( + wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength), + ); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); } + wasm._OrtReleaseTensor(tensor); + } + } - return output; - }; + return output; +}; -export const lazyResetGrad = async(trainingSessionId: number): Promise => { +export const lazyResetGrad = async (trainingSessionId: number): Promise => { const wasm = getInstance(); if (wasm._OrtTrainingLazyResetGrad) { const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId); - ifErrCodeCheckLastError(errorCode, 'Can\'t call lazyResetGrad.'); + ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad."); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } }; -export const runTrainStep = async( - trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], - outputTensors: Array, options: InferenceSession.RunOptions): Promise => { +export const runTrainStep = async ( + trainingSessionId: number, + inputIndices: number[], + inputTensors: TensorMetadata[], + outputIndices: number[], + outputTensors: Array, + options: InferenceSession.RunOptions, +): Promise => { const wasm = getInstance(); const inputCount = inputIndices.length; @@ -287,15 +330,33 @@ export const runTrainStep = async( // handle inputs -- you don't want anything added to the index const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + trainingSessionId, + inputIndices, + inputTensors, + inputTensorHandles, + inputOutputAllocs, + 0, + ); // handle outputs // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + trainingSessionId, + outputIndices, + outputTensors, + outputTensorHandles, + inputOutputAllocs, + inputCount, + ); if (wasm._OrtTrainingRunTrainStep) { const errorCode = wasm._OrtTrainingRunTrainStep( - trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + trainingSessionId, + inputValuesOffset, + inputCount, + outputValuesOffset, + outputCount, + runOptionsHandle, + ); ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); } else { throw new Error(NO_TRAIN_FUNCS_MSG); @@ -305,19 +366,21 @@ export const runTrainStep = async( } finally { wasm.stackRestore(beforeRunStack); - inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach(p => wasm._free(p)); + inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach((p) => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); } - runOptionsAllocs.forEach(p => wasm._free(p)); + runOptionsAllocs.forEach((p) => wasm._free(p)); } 
}; -export const runOptimizerStep = - async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise => { +export const runOptimizerStep = async ( + trainingSessionId: number, + options: InferenceSession.RunOptions, +): Promise => { const wasm = getInstance(); let runOptionsHandle = 0; @@ -336,13 +399,18 @@ export const runOptimizerStep = if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); } - runOptionsAllocs.forEach(p => wasm._free(p)); + runOptionsAllocs.forEach((p) => wasm._free(p)); } }; -export const runEvalStep = async( - trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], - outputTensors: Array, options: InferenceSession.RunOptions): Promise => { +export const runEvalStep = async ( + trainingSessionId: number, + inputIndices: number[], + inputTensors: TensorMetadata[], + outputIndices: number[], + outputTensors: Array, + options: InferenceSession.RunOptions, +): Promise => { const wasm = getInstance(); const inputCount = inputIndices.length; @@ -363,15 +431,33 @@ export const runEvalStep = async( // handle inputs -- you don't want anything added to the index const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + trainingSessionId, + inputIndices, + inputTensors, + inputTensorHandles, + inputOutputAllocs, + 0, + ); // handle outputs // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + trainingSessionId, + outputIndices, + outputTensors, + outputTensorHandles, + inputOutputAllocs, + inputCount, + ); if (wasm._OrtTrainingEvalStep) { const errorCode = wasm._OrtTrainingEvalStep( - trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + trainingSessionId, + inputValuesOffset, + inputCount, + outputValuesOffset, + outputCount, + runOptionsHandle, + ); ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); } else { @@ -382,14 +468,14 @@ export const runEvalStep = async( } finally { wasm.stackRestore(beforeRunStack); - inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach(p => wasm._free(p)); + inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach((p) => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); } - runOptionsAllocs.forEach(p => wasm._free(p)); + runOptionsAllocs.forEach((p) => wasm._free(p)); } }; @@ -401,7 +487,7 @@ export const getParametersSize = (trainingSessionId: number, trainableOnly: bool const sizeOffset = wasm.stackAlloc(4); if (wasm._OrtTrainingGetParametersSize) { const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); - ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size'); + ifErrCodeCheckLastError(errorCode, "Can't get parameters size"); return wasm.HEAP32[sizeOffset / 4]; } else { @@ -412,8 +498,10 @@ export const getParametersSize = (trainingSessionId: number, trainableOnly: bool } }; -export const getContiguousParameters = - async(trainingSessionId: number, trainableOnly: boolean): Promise => 
{ +export const getContiguousParameters = async ( + trainingSessionId: number, + trainableOnly: boolean, +): Promise => { const wasm = getInstance(); const stack = wasm.stackSave(); @@ -437,15 +525,22 @@ export const getContiguousParameters = try { // wraps allocated array in a tensor tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(locationAsString)); + tensorDataTypeStringToEnum(tensorTypeAsString), + paramsOffset, + paramsByteLength, + dimsOffset, + dims.length, + dataLocationStringToEnum(locationAsString), + ); ifErrCodeCheckLastError( - tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false); + tensor, + `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, + false, + ); if (wasm._OrtTrainingCopyParametersToBuffer) { const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); - ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.'); - + ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters."); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -454,8 +549,9 @@ export const getContiguousParameters = const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); const data = new typedArrayConstructor(parametersSize); const output: TensorMetadata[] = []; - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( + wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength), + ); output.push([tensorTypeAsString, dims, data, locationAsString]); if (output.length !== 1) { throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of @@ -473,8 +569,11 @@ export const getContiguousParameters = } }; -export const loadParametersBuffer = - async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise => { +export const loadParametersBuffer = async ( + trainingSessionId: number, + buffer: Uint8Array, + trainableOnly: boolean, +): Promise => { const wasm = getInstance(); const stack = wasm.stackSave(); @@ -495,13 +594,18 @@ export const loadParametersBuffer = try { tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength, - dataLocationStringToEnum(locationAsString)); + tensorDataTypeStringToEnum(tensorTypeAsString), + bufferOffset, + bufferByteLength, + dimsOffset, + dimsLength, + dataLocationStringToEnum(locationAsString), + ); ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. 
session=${trainingSessionId}`, false); if (wasm._OrtTrainingCopyParametersFromBuffer) { const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); - ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.'); + ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters."); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70728c82e7753..70b6cceab0eef 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -6,7 +6,7 @@ // https://github.com/webmachinelearning/webnn/issues/677 /// -import type {Tensor} from 'onnxruntime-common'; +import type { Tensor } from 'onnxruntime-common'; /* eslint-disable @typescript-eslint/naming-convention */ @@ -18,8 +18,12 @@ export declare namespace JSEP { type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; type ReleaseKernelFunction = (kernel: number) => void; - type RunFunction = - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type RunFunction = ( + kernel: number, + contextDataOffset: number, + sessionHandle: number, + errors: Array>, + ) => number; type CaptureBeginFunction = () => void; type CaptureEndFunction = () => void; type ReplayFunction = () => void; @@ -42,11 +46,22 @@ export declare namespace JSEP { * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and * registers a few callbacks that will be called in C++ code. */ - jsepInit(name: 'webgpu', initParams: [ - backend: BackendType, alloc: AllocFunction, free: FreeFunction, upload: UploadFunction, - download: DownloadFunction, createKernel: CreateKernelFunction, releaseKernel: ReleaseKernelFunction, - run: RunFunction, captureBegin: CaptureBeginFunction, captureEnd: CaptureEndFunction, replay: ReplayFunction - ]): void; + jsepInit( + name: 'webgpu', + initParams: [ + backend: BackendType, + alloc: AllocFunction, + free: FreeFunction, + upload: UploadFunction, + download: DownloadFunction, + createKernel: CreateKernelFunction, + releaseKernel: ReleaseKernelFunction, + run: RunFunction, + captureBegin: CaptureBeginFunction, + captureEnd: CaptureEndFunction, + replay: ReplayFunction, + ], + ): void; jsepInit(name: 'webnn', initParams?: never): void; } @@ -94,9 +109,11 @@ export declare namespace JSEP { * @param type - specify the tensor type. * @returns the generated downloader function. */ - jsepCreateDownloader: - (gpuBuffer: GPUBuffer, size: number, - type: Tensor.GpuBufferDataTypes) => () => Promise; + jsepCreateDownloader: ( + gpuBuffer: GPUBuffer, + size: number, + type: Tensor.GpuBufferDataTypes, + ) => () => Promise; /** * [exported from pre-jsep.js] Called when InferenceSession.run started. This function will be called before * _OrtRun[WithBinding]() is called. 
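// A minimal usage sketch for the `jsepCreateDownloader` factory declared above, assuming an
// initialized OrtWasmModule with the JSEP (WebGPU) bindings present. The names `module`,
// `gpuBuffer`, `byteSize`, and the 'float32' data type are illustrative placeholders; real
// usage lives in the WebGPU-related backend code, not in this type declaration file.
import type { OrtWasmModule } from './wasm-types';

async function readBackFromGpu(module: OrtWasmModule, gpuBuffer: GPUBuffer, byteSize: number) {
  // The factory returns a thunk; invoking the thunk performs the asynchronous GPU -> CPU copy.
  const download = module.jsepCreateDownloader!(gpuBuffer, byteSize, 'float32');
  return download();
}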
@@ -134,10 +151,20 @@ export interface OrtInferenceAPIs { _OrtFree(stringHandle: number): void; _OrtCreateTensor( - dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number, - dataLocation: number): number; - _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): - number; + dataType: number, + dataOffset: number, + dataLength: number, + dimsOffset: number, + dimsLength: number, + dataLocation: number, + ): number; + _OrtGetTensorData( + tensorHandle: number, + dataType: number, + dataOffset: number, + dimsOffset: number, + dimsLength: number, + ): number; _OrtReleaseTensor(tensorHandle: number): void; _OrtCreateBinding(sessionHandle: number): number; _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; @@ -145,16 +172,35 @@ export interface OrtInferenceAPIs { _OrtClearBoundOutputs(ioBindingHandle: number): void; _OrtReleaseBinding(ioBindingHandle: number): void; _OrtRunWithBinding( - sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number, - runOptionsHandle: number): Promise; + sessionHandle: number, + ioBindingHandle: number, + outputCount: number, + outputsOffset: number, + runOptionsHandle: number, + ): Promise; _OrtRun( - sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, - outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise; + sessionHandle: number, + inputNamesOffset: number, + inputsOffset: number, + inputCount: number, + outputNamesOffset: number, + outputCount: number, + outputsOffset: number, + runOptionsHandle: number, + ): Promise; _OrtCreateSessionOptions( - graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, - enableProfiling: boolean, profileFilePrefix: number, logId: number, logSeverityLevel: number, - logVerbosityLevel: number, optimizedModelFilePath: number): number; + graphOptimizationLevel: number, + enableCpuMemArena: boolean, + enableMemPattern: boolean, + executionMode: number, + enableProfiling: boolean, + profileFilePrefix: number, + logId: number, + logSeverityLevel: number, + logVerbosityLevel: number, + optimizedModelFilePath: number, + ): number; _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; @@ -173,33 +219,66 @@ export interface OrtTrainingAPIs { _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; _OrtTrainingCreateSession( - sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, - evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number; + sessionOptionsHandle: number, + checkpointHandle: number, + trainOffset: number, + trainLength: number, + evalOffset: number, + evalLength: number, + optimizerOffset: number, + optimizerLength: number, + ): number; _OrtTrainingLazyResetGrad(trainingHandle: number): number; _OrtTrainingRunTrainStep( - trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + trainingHandle: number, + inputsOffset: number, + inputCount: number, + outputsOffset: number, + outputCount: number, + 
runOptionsHandle: number, + ): number; _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; _OrtTrainingEvalStep( - trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + trainingHandle: number, + inputsOffset: number, + inputCount: number, + outputsOffset: number, + outputCount: number, + runOptionsHandle: number, + ): number; _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; _OrtTrainingCopyParametersToBuffer( - trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + trainingHandle: number, + parametersBuffer: number, + parameterCount: number, + trainableOnly: boolean, + ): number; _OrtTrainingCopyParametersFromBuffer( - trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + trainingHandle: number, + parametersBuffer: number, + parameterCount: number, + trainableOnly: boolean, + ): number; _OrtTrainingGetModelInputOutputCount( - trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; - _OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): - number; + trainingHandle: number, + inputCount: number, + outputCount: number, + isEvalModel: boolean, + ): number; + _OrtTrainingGetModelInputOutputName( + trainingHandle: number, + index: number, + isInput: boolean, + isEvalModel: boolean, + ): number; _OrtTrainingReleaseSession(trainingHandle: number): void; } @@ -207,8 +286,11 @@ export interface OrtTrainingAPIs { /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, - Partial { +export interface OrtWasmModule + extends EmscriptenModule, + OrtInferenceAPIs, + Partial, + Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index f80bd7195d456..008b9b41b1592 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import type {OrtWasmModule} from './wasm-types'; -import {isNode} from './wasm-utils-env'; +import type { OrtWasmModule } from './wasm-types'; +import { isNode } from './wasm-utils-env'; /** * The classic script source URL. This is not always available in non ESModule environments. @@ -10,14 +10,18 @@ import {isNode} from './wasm-utils-env'; * In Node.js, this is undefined. */ export const scriptSrc = - // if Nodejs, return undefined - isNode ? undefined : - // if It's ESM, use import.meta.url - BUILD_DEFS.ESM_IMPORT_META_URL ?? - // use `document.currentScript.src` if available - (typeof document !== 'undefined' ? (document.currentScript as HTMLScriptElement)?.src : - // use `self.location.href` if available - (typeof self !== 'undefined' ? self.location?.href : undefined)); + // if Nodejs, return undefined + isNode + ? undefined + : // if It's ESM, use import.meta.url + (BUILD_DEFS.ESM_IMPORT_META_URL ?? + // use `document.currentScript.src` if available + (typeof document !== 'undefined' + ? 
(document.currentScript as HTMLScriptElement)?.src + : // use `self.location.href` if available + typeof self !== 'undefined' + ? self.location?.href + : undefined)); /** * The origin of the current location. @@ -69,8 +73,8 @@ const fallbackUrl = (filename: string, prefixOverride?: string) => `${prefixOver * * @returns - A promise that resolves to a new Blob URL */ -const preload = async(absoluteUrl: string): Promise => { - const response = await fetch(absoluteUrl, {credentials: 'same-origin'}); +const preload = async (absoluteUrl: string): Promise => { + const response = await fetch(absoluteUrl, { credentials: 'same-origin' }); const blob = await response.blob(); return URL.createObjectURL(blob); }; @@ -84,16 +88,17 @@ const preload = async(absoluteUrl: string): Promise => { * * @returns - A promise that resolves to the default export of the module. */ -const dynamicImportDefault = async(url: string): Promise => (await import(/* webpackIgnore: true */ url)).default; +const dynamicImportDefault = async (url: string): Promise => + (await import(/* webpackIgnore: true */ url)).default; /** * The proxy worker factory imported from the proxy worker module. * * This is only available when the WebAssembly proxy is not disabled. */ -const createProxyWorker: ((urlOverride?: string) => Worker)|undefined = - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - BUILD_DEFS.DISABLE_WASM_PROXY ? undefined : require('./proxy-worker/main').default; +const createProxyWorker: ((urlOverride?: string) => Worker) | undefined = + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + BUILD_DEFS.DISABLE_WASM_PROXY ? undefined : require('./proxy-worker/main').default; /** * Import the proxy worker. @@ -106,7 +111,7 @@ const createProxyWorker: ((urlOverride?: string) => Worker)|undefined = * - The object URL of the preloaded module, or undefined if no preload is needed. * - The proxy worker. */ -export const importProxyWorker = async(): Promise<[undefined | string, Worker]> => { +export const importProxyWorker = async (): Promise<[undefined | string, Worker]> => { if (!scriptSrc) { throw new Error('Failed to load proxy worker: cannot determine the script source URL.'); } @@ -126,15 +131,17 @@ export const importProxyWorker = async(): Promise<[undefined | string, Worker]> * * This is only available in ESM and when embedding is not disabled. */ -const embeddedWasmModule: EmscriptenModuleFactory|undefined = - BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT ? - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - require( - !BUILD_DEFS.DISABLE_TRAINING ? '../../dist/ort-training-wasm-simd-threaded.mjs' : - !BUILD_DEFS.DISABLE_JSEP ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' : - '../../dist/ort-wasm-simd-threaded.mjs') - .default : - undefined; +const embeddedWasmModule: EmscriptenModuleFactory | undefined = + BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT + ? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + require( + !BUILD_DEFS.DISABLE_TRAINING + ? '../../dist/ort-training-wasm-simd-threaded.mjs' + : !BUILD_DEFS.DISABLE_JSEP + ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', + ).default + : undefined; /** * Import the WebAssembly module. 
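// A hedged usage sketch for `importWasmModule` (reformatted in the next hunk), assuming it is
// called from module-loading code such as wasm-factory.ts. The returned tuple is [preloaded
// object URL or undefined, Emscripten module factory]; revoking the temporary Blob URL after
// instantiation is shown only as one plausible cleanup, not as the actual caller's behavior.
import { importWasmModule } from './wasm-utils-import';

async function instantiateOrtWasm() {
  const [objectUrl, moduleFactory] = await importWasmModule(undefined, undefined, /* isMultiThreaded */ true);
  try {
    // The factory accepts optional Emscripten module overrides and resolves with the module.
    return await moduleFactory();
  } finally {
    if (objectUrl) {
      URL.revokeObjectURL(objectUrl);
    }
  }
}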
@@ -148,15 +155,19 @@ const embeddedWasmModule: EmscriptenModuleFactory|undefined = * - The object URL of the preloaded module, or undefined if no preload is needed. * - The default export of the module, which is a factory function to create the WebAssembly module. */ -export const importWasmModule = async( - urlOverride: string|undefined, prefixOverride: string|undefined, - isMultiThreaded: boolean): Promise<[undefined | string, EmscriptenModuleFactory]> => { +export const importWasmModule = async ( + urlOverride: string | undefined, + prefixOverride: string | undefined, + isMultiThreaded: boolean, +): Promise<[undefined | string, EmscriptenModuleFactory]> => { if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) { return [undefined, embeddedWasmModule!]; } else { - const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING ? 'ort-training-wasm-simd-threaded.mjs' : - !BUILD_DEFS.DISABLE_JSEP ? 'ort-wasm-simd-threaded.jsep.mjs' : - 'ort-wasm-simd-threaded.mjs'; + const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING + ? 'ort-training-wasm-simd-threaded.mjs' + : !BUILD_DEFS.DISABLE_JSEP + ? 'ort-wasm-simd-threaded.jsep.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. @@ -169,8 +180,9 @@ export const importWasmModule = async( // 4. the worker URL is not from the same origin. // - If the worker URL is from the same origin, we can create the worker directly. const needPreload = !isNode && isMultiThreaded && wasmModuleUrl && !isSameOrigin(wasmModuleUrl, prefixOverride); - const url = needPreload ? (await preload(wasmModuleUrl)) : - (wasmModuleUrl ?? fallbackUrl(wasmModuleFilename, prefixOverride)); + const url = needPreload + ? await preload(wasmModuleUrl) + : (wasmModuleUrl ?? fallbackUrl(wasmModuleFilename, prefixOverride)); return [needPreload ? url : undefined, await dynamicImportDefault>(url)]; } }; diff --git a/js/web/lib/wasm/wasm-utils-load-file.ts b/js/web/lib/wasm/wasm-utils-load-file.ts index 75c4df74a8af2..53cba46eeac2b 100644 --- a/js/web/lib/wasm/wasm-utils-load-file.ts +++ b/js/web/lib/wasm/wasm-utils-load-file.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {isNode} from './wasm-utils-env'; +import { isNode } from './wasm-utils-env'; /** * Load a file into a Uint8Array. @@ -9,17 +9,17 @@ import {isNode} from './wasm-utils-env'; * @param file - the file to load. Can be a URL/path, a Blob, an ArrayBuffer, or a Uint8Array. * @returns a Uint8Array containing the file data. 
*/ -export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Promise => { +export const loadFile = async (file: string | Blob | ArrayBufferLike | Uint8Array): Promise => { if (typeof file === 'string') { if (isNode) { // load file into ArrayBuffer in Node.js try { - const {readFile} = require('node:fs/promises'); + const { readFile } = require('node:fs/promises'); return new Uint8Array(await readFile(file)); } catch (e) { if (e.code === 'ERR_FS_FILE_TOO_LARGE') { // file is too large, use fs.createReadStream instead - const {createReadStream} = require('node:fs'); + const { createReadStream } = require('node:fs'); const stream = createReadStream(file); const chunks: Uint8Array[] = []; for await (const chunk of stream) { @@ -56,7 +56,7 @@ export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Pro if (e instanceof RangeError) { // use WebAssembly Memory to allocate larger ArrayBuffer const pages = Math.ceil(fileSize / 65536); - buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + buffer = new WebAssembly.Memory({ initial: pages, maximum: pages }).buffer; } else { throw e; } @@ -65,7 +65,7 @@ export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Pro let offset = 0; // eslint-disable-next-line no-constant-condition while (true) { - const {done, value} = await reader.read(); + const { done, value } = await reader.read(); if (done) { break; } @@ -77,7 +77,6 @@ export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Pro return new Uint8Array(buffer, 0, fileSize); } } - } else if (file instanceof Blob) { return new Uint8Array(await file.arrayBuffer()); } else if (file instanceof Uint8Array) { diff --git a/js/web/lib/wasm/wasm-utils.ts b/js/web/lib/wasm/wasm-utils.ts index 37762b353f575..a820fd216ee03 100644 --- a/js/web/lib/wasm/wasm-utils.ts +++ b/js/web/lib/wasm/wasm-utils.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {getInstance} from './wasm-factory'; +import { getInstance } from './wasm-factory'; export const allocWasmString = (data: string, allocs: number[]): number => { const wasm = getInstance(); @@ -18,30 +18,33 @@ interface ExtraOptionsHandler { (name: string, value: string): void; } -export const iterateExtraOptions = - (options: Record, prefix: string, seen: WeakSet>, - handler: ExtraOptionsHandler): void => { - if (typeof options == 'object' && options !== null) { - if (seen.has(options)) { - throw new Error('Circular reference in options'); - } else { - seen.add(options); - } - } +export const iterateExtraOptions = ( + options: Record, + prefix: string, + seen: WeakSet>, + handler: ExtraOptionsHandler, +): void => { + if (typeof options == 'object' && options !== null) { + if (seen.has(options)) { + throw new Error('Circular reference in options'); + } else { + seen.add(options); + } + } - Object.entries(options).forEach(([key, value]) => { - const name = (prefix) ? prefix + key : key; - if (typeof value === 'object') { - iterateExtraOptions(value as Record, name + '.', seen, handler); - } else if (typeof value === 'string' || typeof value === 'number') { - handler(name, value.toString()); - } else if (typeof value === 'boolean') { - handler(name, (value) ? '1' : '0'); - } else { - throw new Error(`Can't handle extra config type: ${typeof value}`); - } - }); - }; + Object.entries(options).forEach(([key, value]) => { + const name = prefix ? 
prefix + key : key; + if (typeof value === 'object') { + iterateExtraOptions(value as Record, name + '.', seen, handler); + } else if (typeof value === 'string' || typeof value === 'number') { + handler(name, value.toString()); + } else if (typeof value === 'boolean') { + handler(name, value ? '1' : '0'); + } else { + throw new Error(`Can't handle extra config type: ${typeof value}`); + } + }); +}; /** * check web assembly API's last error and throw error if any error occurred. diff --git a/js/web/script/build.ts b/js/web/script/build.ts index eba5efa3f11e0..6d1b3bdb65068 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -5,7 +5,7 @@ import * as esbuild from 'esbuild'; import minimist from 'minimist'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; -import {SourceMapConsumer, SourceMapGenerator} from 'source-map'; +import { SourceMapConsumer, SourceMapGenerator } from 'source-map'; console.time('BUILD'); @@ -27,7 +27,7 @@ const args = minimist(process.argv.slice(2)); * --bundle-mode=node * Build a single ort-web bundle for nodejs. */ -const BUNDLE_MODE: 'prod'|'dev'|'perf'|'node' = args['bundle-mode'] || 'prod'; +const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = args['bundle-mode'] || 'prod'; /** * --debug @@ -41,7 +41,7 @@ const BUNDLE_MODE: 'prod'|'dev'|'perf'|'node' = args['bundle-mode'] || 'prod'; * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Full bundle analysis will be saved to a * file as JSON. */ -const DEBUG = args.debug; // boolean|'verbose'|'save' +const DEBUG = args.debug; // boolean|'verbose'|'save' /** * Root folder of the source code: `/js/` @@ -72,7 +72,7 @@ const COPYRIGHT_HEADER = `/*! interface OrtBuildOptions { readonly isProduction?: boolean; readonly isNode?: boolean; - readonly format: 'iife'|'cjs'|'esm'; + readonly format: 'iife' | 'cjs' | 'esm'; readonly outputName: string; readonly define?: Record; } @@ -116,7 +116,7 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { const TIME_TAG = `BUILD:terserMinify:${filepath}`; console.time(TIME_TAG); - const contents = await fs.readFile(filepath, {encoding: 'utf-8'}); + const contents = await fs.readFile(filepath, { encoding: 'utf-8' }); // Find the first and the only occurrence of minified function implementation of "_emscripten_thread_set_strongref": // ```js @@ -145,8 +145,11 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { // If it is not the original source file, we need to find the minified function call. const matches = [...contents.matchAll(/\{[_a-zA-Z][_a-zA-Z0-9]*&&([_a-zA-Z][_a-zA-Z0-9]*\[.+?]\.ref)\(\)}/g)]; if (matches.length !== 1) { - throw new Error(`Unexpected number of matches for minified "PThread.pthreads[thread].ref()" in "${filepath}": ${ - matches.length}.`); + throw new Error( + `Unexpected number of matches for minified "PThread.pthreads[thread].ref()" in "${filepath}": ${ + matches.length + }.`, + ); } // matches[0] is the first and the only match. // matches[0][0] is the full matched string and matches[0][1] is the first capturing group. 
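// A small self-contained illustration of the `pure_funcs` option that appears in the next hunk:
// terser treats calls to the listed names as side-effect free, so a call whose result is unused
// can be dropped from the minified output. This assumes terser ^5; the function name below is a
// placeholder, not the minified Emscripten helper matched above.
import { minify } from 'terser';

async function demoPureFuncs() {
  const { code } = await minify('keepAlive(); const x = 1; console.log(x);', {
    compress: { pure_funcs: ['keepAlive'] },
  });
  // The minified output no longer contains the keepAlive() call, since its result is discarded.
  return code;
}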
@@ -158,7 +161,7 @@ async function minifyWasmModuleJsForBrowser(filepath: string): Promise { module: true, compress: { passes: 2, - global_defs: {'process': undefined, 'globalThis.process': undefined}, + global_defs: { process: undefined, 'globalThis.process': undefined }, pure_funcs: markedAsPure, }, }); @@ -195,8 +198,10 @@ async function buildBundle(options: esbuild.BuildOptions) { // (see: https://github.com/evanw/esbuild/pull/2067#issuecomment-1981642558) const NODE_ESM_FIX_MIN = 'import{createRequire}from"module";const require=createRequire(import.meta.url);'; const banner = { - js: options.platform === 'node' && options.format === 'esm' ? COPYRIGHT_HEADER + '\n' + NODE_ESM_FIX_MIN : - COPYRIGHT_HEADER + js: + options.platform === 'node' && options.format === 'esm' + ? COPYRIGHT_HEADER + '\n' + NODE_ESM_FIX_MIN + : COPYRIGHT_HEADER, }; // Patch footer: @@ -211,7 +216,7 @@ async function buildBundle(options: esbuild.BuildOptions) { // see also: https://github.com/evanw/esbuild/issues/507 // const COMMONJS_FOOTER_MIN = 'typeof exports=="object"&&typeof module=="object"&&(module.exports=ort);'; - const footer = options.format === 'iife' ? {js: COMMONJS_FOOTER_MIN} : undefined; + const footer = options.format === 'iife' ? { js: COMMONJS_FOOTER_MIN } : undefined; // set BUILD_DEFS for ESM. if (options.format === 'esm') { @@ -229,14 +234,16 @@ async function buildBundle(options: esbuild.BuildOptions) { bundle: true, banner, footer, - ...options + ...options, }); if (DEBUG) { if (DEBUG === 'save') { await fs.writeFile( - `${path.basename(options.outfile!)}.esbuild.metafile.json`, JSON.stringify(result.metafile!, null, 2)); + `${path.basename(options.outfile!)}.esbuild.metafile.json`, + JSON.stringify(result.metafile!, null, 2), + ); } else { - console.log(await esbuild.analyzeMetafile(result.metafile!, {verbose: DEBUG === 'verbose'})); + console.log(await esbuild.analyzeMetafile(result.metafile!, { verbose: DEBUG === 'verbose' })); } } } @@ -256,8 +263,9 @@ async function buildOrt({ define = DEFAULT_DEFINE, }: OrtBuildOptions) { const platform = isNode ? 'node' : 'browser'; - const external = - isNode ? ['onnxruntime-common'] : ['node:fs/promises', 'node:fs', 'node:os', 'module', 'worker_threads']; + const external = isNode + ? ['onnxruntime-common'] + : ['node:fs/promises', 'node:fs', 'node:os', 'module', 'worker_threads']; const plugins: esbuild.Plugin[] = []; const defineOverride: Record = {}; if (!isNode) { @@ -269,10 +277,10 @@ async function buildOrt({ plugins.push({ name: 'emscripten-mjs-handler', setup(build: esbuild.PluginBuild) { - build.onLoad( - {filter: /dist[\\/]ort-.*wasm.*\.mjs$/}, - async args => ({contents: await minifyWasmModuleJsForBrowser(args.path)})); - } + build.onLoad({ filter: /dist[\\/]ort-.*wasm.*\.mjs$/ }, async (args) => ({ + contents: await minifyWasmModuleJsForBrowser(args.path), + })); + }, }); } @@ -284,7 +292,7 @@ async function buildOrt({ globalName: 'ort', plugins, external, - define: {...define, ...defineOverride}, + define: { ...define, ...defineOverride }, sourcemap: isProduction ? 
'linked' : 'inline', minify: isProduction, }); @@ -306,25 +314,25 @@ async function buildTest() { external: ['../../node'], plugins: [ // polyfill nodejs modules - require('esbuild-plugin-polyfill-node').polyfillNode({globals: false}), + require('esbuild-plugin-polyfill-node').polyfillNode({ globals: false }), // make "ort" external { name: 'make-ort-external', setup(build: esbuild.PluginBuild) { - build.onResolve( - {filter: /^onnxruntime-common$/}, - _args => ({path: 'onnxruntime-common', namespace: 'make-ort-external'})); - build.onLoad( - {filter: /.*/, namespace: 'make-ort-external'}, - _args => ({contents: 'module.exports = globalThis.ort;'})); - } - } + build.onResolve({ filter: /^onnxruntime-common$/ }, (_args) => ({ + path: 'onnxruntime-common', + namespace: 'make-ort-external', + })); + build.onLoad({ filter: /.*/, namespace: 'make-ort-external' }, (_args) => ({ + contents: 'module.exports = globalThis.ort;', + })); + }, + }, ], minify: isProduction, }); } - /** * Perform the post-process step after ESBuild finishes the build. * @@ -375,7 +383,9 @@ async function postProcess() { const jsFileLines = (await fs.readFile(jsFilePath, 'utf-8')).split('\n'); - let line = -1, column = -1, found = false; + let line = -1, + column = -1, + found = false; for (let i = 0; i < jsFileLines.length; i++) { const importColumnIndex = jsFileLines[i].indexOf(IMPORT_ORIGINAL); if (importColumnIndex !== -1) { @@ -414,9 +424,9 @@ async function postProcess() { } updatedSourceMap.addMapping({ - generated: {line: mapping.generatedLine, column: mapping.generatedColumn}, + generated: { line: mapping.generatedLine, column: mapping.generatedColumn }, source: mapping.source, - original: {line: mapping.originalLine, column: mapping.originalColumn}, + original: { line: mapping.originalLine, column: mapping.originalColumn }, name: mapping.name, }); }); @@ -427,9 +437,11 @@ async function postProcess() { const originalSourcemap = JSON.parse(originalSourcemapString); const updatedSourcemap = JSON.parse(updatedSourcemapString); - if (originalSourcemap.sources.length !== updatedSourcemap.sources.length || - originalSourcemap.sourcesContent.length !== updatedSourcemap.sourcesContent.length || - new Set(originalSourcemap.names).size !== new Set(updatedSourcemap.names).size) { + if ( + originalSourcemap.sources.length !== updatedSourcemap.sources.length || + originalSourcemap.sourcesContent.length !== updatedSourcemap.sourcesContent.length || + new Set(originalSourcemap.names).size !== new Set(updatedSourcemap.names).size + ) { throw new Error('Failed to update source map: source map length mismatch.'); } const originalMappingsCount = originalSourcemap.mappings.split(/[;,]/); @@ -444,8 +456,11 @@ async function postProcess() { await fs.writeFile(jsFilePath, jsFileLines.join('\n')); const newJsFileSize = (await fs.stat(jsFilePath)).size; if (newJsFileSize - originalJsFileSize !== IMPORT_MAGIC_COMMENT.length) { - throw new Error(`Failed to insert magic comment to file "${file}". Original size: ${ - originalJsFileSize}, New size: ${newJsFileSize}`); + throw new Error( + `Failed to insert magic comment to file "${file}". 
Original size: ${ + originalJsFileSize + }, New size: ${newJsFileSize}`, + ); } } } @@ -551,7 +566,7 @@ async function main() { if (BUNDLE_MODE === 'dev') { // ort.all.js - await buildOrt({outputName: 'ort.all', format: 'iife', define: {...DEFAULT_DEFINE}}); + await buildOrt({ outputName: 'ort.all', format: 'iife', define: { ...DEFAULT_DEFINE } }); } if (BUNDLE_MODE === 'perf') { @@ -565,45 +580,45 @@ async function main() { if (BUNDLE_MODE === 'prod') { // ort.all[.min].[m]js - await addAllWebBuildTasks({outputName: 'ort.all'}); + await addAllWebBuildTasks({ outputName: 'ort.all' }); // ort.all.bundle.min.mjs await buildOrt({ isProduction: true, outputName: 'ort.all.bundle', format: 'esm', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true'}, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true' }, }); // ort[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true'}, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true' }, }); // ort.bundle.min.mjs await buildOrt({ isProduction: true, outputName: 'ort.bundle', format: 'esm', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true'}, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true' }, }); // ort.webgpu[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort.webgpu', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, }); // ort.webgpu.bundle.min.mjs await buildOrt({ isProduction: true, outputName: 'ort.webgpu.bundle', format: 'esm', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true'}, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'true' }, }); // ort.wasm[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort.wasm', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, + define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, }); // ort.webgl[.min].[m]js await addAllWebBuildTasks({ diff --git a/js/web/script/generate-webgl-operator-md.ts b/js/web/script/generate-webgl-operator-md.ts index 878a4c9a4008b..5cc43eb903527 100644 --- a/js/web/script/generate-webgl-operator-md.ts +++ b/js/web/script/generate-webgl-operator-md.ts @@ -3,19 +3,19 @@ import * as assert from 'assert'; import * as fs from 'fs'; -import {EOL} from 'os'; +import { EOL } from 'os'; import * as path from 'path'; -import {Attribute} from '../lib/onnxjs/attribute'; -import {WEBGL_OP_RESOLVE_RULES} from '../lib/onnxjs/backends/webgl/op-resolve-rules'; -import {OpSet, resolveOperator} from '../lib/onnxjs/opset'; -import {Tensor} from '../lib/onnxjs/tensor'; +import { Attribute } from '../lib/onnxjs/attribute'; +import { WEBGL_OP_RESOLVE_RULES } from '../lib/onnxjs/backends/webgl/op-resolve-rules'; +import { OpSet, resolveOperator } from '../lib/onnxjs/opset'; +import { Tensor } from '../lib/onnxjs/tensor'; function checkSupport(type: string, range: [number, number], rules: readonly OpSet.ResolveRule[]) { - const node = {name: '', opType: type, inputs: [], outputs: [], attributes: new Attribute(undefined)}; + const node = { name: '', opType: type, inputs: [], outputs: [], attributes: new Attribute(undefined) }; for (let i = range[0]; i <= range[1]; i++) { 
try { - resolveOperator(node, [{domain: '', version: i}], rules); + resolveOperator(node, [{ domain: '', version: i }], rules); } catch (_e) { return false; } @@ -36,34 +36,35 @@ function dummyOpImpl(): Tensor[] { } const ops = new Map>(); -const webglCheckOnlyRules = - WEBGL_OP_RESOLVE_RULES.map(rule => [rule[0], rule[1], rule[2], dummyOpImpl] as OpSet.ResolveRule); +const webglCheckOnlyRules = WEBGL_OP_RESOLVE_RULES.map( + (rule) => [rule[0], rule[1], rule[2], dummyOpImpl] as OpSet.ResolveRule, +); fs.readFileSync(path.join(__dirname, '../../../cmake/external/onnx/onnx/defs/operator_sets.h'), 'utf8') - .split(/\r?\n/) - .forEach(line => { - const matcher = /class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME\(\s*(\w+),\s*(\d+),\s*(\w+)\)/; - const matches = matcher.exec(line); - if (matches) { - const opset = matches[1]; - const version = Number.parseInt(matches[2], 10); - const opType = matches[3]; - - let currentSet = ops.get(opset); - if (currentSet === undefined) { - currentSet = new Map(); - ops.set(opset, currentSet); - } - - let currentOp = currentSet.get(opType); - if (currentOp === undefined) { - currentOp = []; - currentSet.set(opType, currentOp); - } - - currentOp.push(version); + .split(/\r?\n/) + .forEach((line) => { + const matcher = /class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME\(\s*(\w+),\s*(\d+),\s*(\w+)\)/; + const matches = matcher.exec(line); + if (matches) { + const opset = matches[1]; + const version = Number.parseInt(matches[2], 10); + const opType = matches[3]; + + let currentSet = ops.get(opset); + if (currentSet === undefined) { + currentSet = new Map(); + ops.set(opset, currentSet); } - }); + + let currentOp = currentSet.get(opType); + if (currentOp === undefined) { + currentOp = []; + currentSet.set(opType, currentOp); + } + + currentOp.push(version); + } + }); const opsets = Array.from(ops.keys()); assert.ok(opsets.length === 1 && opsets[0] === 'Onnx'); @@ -84,8 +85,8 @@ doc.write(`| Operator | WebGl Backend |${EOL}`); doc.write(`|:--------:|:-------------:|${EOL}`); let VERSION_MAX = 0; -onnxOpset.forEach(versions => { - versions.forEach(version => VERSION_MAX = Math.max(VERSION_MAX, version)); +onnxOpset.forEach((versions) => { + versions.forEach((version) => (VERSION_MAX = Math.max(VERSION_MAX, version))); }); for (const type of opTypes) { @@ -99,7 +100,10 @@ for (const type of opTypes) { webgl.push(formatDesc(type, versionRange, checkSupport(type, versionRange, webglCheckOnlyRules), last)); } - doc.write(`| [${type}](https://github.com/onnx/onnx/blob/main/docs/Operators.md#${type}) | ${ - webgl.filter(d => d.length > 0).join(', ')} |${EOL}`); + doc.write( + `| [${type}](https://github.com/onnx/onnx/blob/main/docs/Operators.md#${type}) | ${webgl + .filter((d) => d.length > 0) + .join(', ')} |${EOL}`, + ); } doc.end(); diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index eab8175a941bd..5e9a7152bf185 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -2,22 +2,22 @@ // Licensed under the MIT License. 
import fs from 'fs'; -import {EOL} from 'os'; +import { EOL } from 'os'; import path from 'path'; // The following variable allows to insert comments per operator const COMMENTS: Record = { - 'AveragePool': 'need perf optimization; need implementing activation', - 'MaxPool': 'need perf optimization; need implementing activation', - 'Conv': 'need perf optimization; conv3d is not supported; need implementing activation', - 'ConvTranspose': 'need perf optimization; ConvTranspose3d is not supported; need implementing activation', - 'Transpose': 'need perf optimization', - 'Reshape': 'no GPU kernel', - 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', - 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', - 'Attention': 'need implementing mask and past/present', - 'MultiHeadAttention': 'need implementing mask and past/present', + AveragePool: 'need perf optimization; need implementing activation', + MaxPool: 'need perf optimization; need implementing activation', + Conv: 'need perf optimization; conv3d is not supported; need implementing activation', + ConvTranspose: 'need perf optimization; ConvTranspose3d is not supported; need implementing activation', + Transpose: 'need perf optimization', + Reshape: 'no GPU kernel', + Shape: 'no GPU kernel; an ORT warning is generated - need to fix', + Resize: 'CoordinateTransformMode align_corners is not supported with downsampling', + Attention: 'need implementing mask and past/present', + MultiHeadAttention: 'need implementing mask and past/present', }; /* eslint-disable max-len */ @@ -29,20 +29,22 @@ const MATCHERS = [ ]; /* eslint-enable max-len */ -const ALL_REGISTERED_OPERATORS: Map < string, { - opset: Map>; - comments: string; -} +const ALL_REGISTERED_OPERATORS: Map< + string, + { + opset: Map>; + comments: string; + } > = new Map(); // parse js_execution_provider.cc const JS_EXECUTION_PROVIDER_CONTENTS = - fs.readFileSync(path.join(__dirname, '../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8') + - fs.readFileSync(path.join(__dirname, '../../../onnxruntime/contrib_ops/js/js_contrib_kernels.cc'), 'utf8'); -MATCHERS.forEach(m => { + fs.readFileSync(path.join(__dirname, '../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8') + + fs.readFileSync(path.join(__dirname, '../../../onnxruntime/contrib_ops/js/js_contrib_kernels.cc'), 'utf8'); +MATCHERS.forEach((m) => { for (const match of JS_EXECUTION_PROVIDER_CONTENTS.matchAll(m)) { const groups = match.groups!; - const {ep, opsetDomain, opsetVersion, opsetVersionStart, opsetVersionEnd, op} = groups; + const { ep, opsetDomain, opsetVersion, opsetVersionStart, opsetVersionEnd, op } = groups; if (ep !== 'kJsExecutionProvider') { throw new Error(`invalid EP registration for EP name: ${ep}`); @@ -64,10 +66,10 @@ MATCHERS.forEach(m => { let opInfo = ALL_REGISTERED_OPERATORS.get(op); if (!opInfo) { - opInfo = {opset: new Map(), comments: COMMENTS[op]}; + opInfo = { opset: new Map(), comments: COMMENTS[op] }; ALL_REGISTERED_OPERATORS.set(op, opInfo); } - const {opset} = opInfo; + const { opset } = opInfo; let currentDomainInfo = opset.get(domain); if (!currentDomainInfo) { currentDomainInfo = []; @@ -93,17 +95,23 @@ Do not modify directly.*${EOL}${EOL}`); doc.write(`| Operator | Opset | Comments |${EOL}`); doc.write(`|:--------:|:-------------:|-----|${EOL}`); -Array.from(ALL_REGISTERED_OPERATORS.keys()).sort().forEach(op => { - const {opset, comments} = ALL_REGISTERED_OPERATORS.get(op)!; - const opsetString = - 
Array.from(opset.keys()) - .sort() - .map( - domain => `${domain}(${ - [...new Set(opset.get(domain)!.map( - ver => ver[1] ? (ver[0] === ver[1] ? `${ver[0]}` : `${ver[0]}-${ver[1]}`) : `${ver[0]}+`))] - .join(',')})`) - .join('; '); - doc.write(`| ${op} | ${opsetString} | ${comments ?? ''} |${EOL}`); -}); +Array.from(ALL_REGISTERED_OPERATORS.keys()) + .sort() + .forEach((op) => { + const { opset, comments } = ALL_REGISTERED_OPERATORS.get(op)!; + const opsetString = Array.from(opset.keys()) + .sort() + .map( + (domain) => + `${domain}(${[ + ...new Set( + opset + .get(domain)! + .map((ver) => (ver[1] ? (ver[0] === ver[1] ? `${ver[0]}` : `${ver[0]}-${ver[1]}`) : `${ver[0]}+`)), + ), + ].join(',')})`, + ) + .join('; '); + doc.write(`| ${op} | ${opsetString} | ${comments ?? ''} |${EOL}`); + }); doc.end(); diff --git a/js/web/script/parse-profiler.ts b/js/web/script/parse-profiler.ts index 674be5cf8eeb3..95053bab161bd 100644 --- a/js/web/script/parse-profiler.ts +++ b/js/web/script/parse-profiler.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - /* eslint-disable @typescript-eslint/restrict-plus-operands */ // parse-profiler @@ -13,15 +12,14 @@ // STEP.2 - parse // > node script/parse-profiler < profile.raw.log > profile.parsed.log - import * as readline from 'readline'; -const lines = readline.createInterface({input: process.stdin, output: process.stdout, terminal: false}); +const lines = readline.createInterface({ input: process.stdin, output: process.stdout, terminal: false }); // eslint-disable-next-line no-control-regex const matcher = /Profiler\.([^[\s\x1b]+)(\x1b\[0m)? (\d.+Z)\|([\d.]+)ms on event '([^']+)' at (\d*\.*\d*)/; const allEvents: any[] = []; -lines.on('line', input => { +lines.on('line', (input) => { const matches = matcher.exec(input); if (matches) { // console.log(matches); @@ -30,13 +28,16 @@ lines.on('line', input => { const ms = Number.parseFloat(matches[4]); const event = matches[5]; const endTimeInNumber = matches[6]; - allEvents.push({event, ms, logTimeStamp, category, endTimeInNumber}); + allEvents.push({ event, ms, logTimeStamp, category, endTimeInNumber }); } }); lines.on('close', () => { for (const i of allEvents) { - console.log(`${(i.category + ' ').substring(0, 12)} ${((i.ms) + ' ').substring(0, 12)} ${ - (i.event + ' ').substring(0, 40)} ${i.endTimeInNumber}`); + console.log( + `${(i.category + ' ').substring(0, 12)} ${(i.ms + ' ').substring(0, 12)} ${( + i.event + ' ' + ).substring(0, 40)} ${i.endTimeInNumber}`, + ); } }); diff --git a/js/web/script/prepack.ts b/js/web/script/prepack.ts index 4c5941d8dae12..d7c0ff3959fc6 100644 --- a/js/web/script/prepack.ts +++ b/js/web/script/prepack.ts @@ -12,7 +12,7 @@ function updatePackageJson() { const packageSelf = fs.readJSONSync(selfPackageJsonPath); const version = packageCommon.version; packageSelf.dependencies['onnxruntime-common'] = `${version}`; - fs.writeJSONSync(selfPackageJsonPath, packageSelf, {spaces: 2}); + fs.writeJSONSync(selfPackageJsonPath, packageSelf, { spaces: 2 }); console.log('=== finished updating package.json.'); } diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index 3e9042bf9fb3f..b1b2fa26b2351 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -34,8 +34,12 @@ Usage: const argv = process.argv.slice(2); -if (argv.indexOf('--help') !== -1 || argv.indexOf('-h') !== -1 || argv.indexOf('help') !== -1 || - 
argv.indexOf('h') !== -1) { +if ( + argv.indexOf('--help') !== -1 || + argv.indexOf('-h') !== -1 || + argv.indexOf('help') !== -1 || + argv.indexOf('h') !== -1 +) { console.log(HELP_MESSAGE); process.exit(); } @@ -48,8 +52,8 @@ const buildId = arg0isInteger ? argv[0] : (argv[1] ?? ''); const folderName = config === 'release' ? 'Release_wasm' : 'Debug_wasm'; function downloadJson(url: string, onSuccess: (data: any) => void) { - https.get(url, res => { - const {statusCode} = res; + https.get(url, (res) => { + const { statusCode } = res; const contentType = res.headers['content-type']; if (statusCode !== 200) { @@ -70,8 +74,8 @@ function downloadJson(url: string, onSuccess: (data: any) => void) { } function downloadZip(url: string, onSuccess: (data: Buffer) => void) { - https.get(url, res => { - const {statusCode} = res; + https.get(url, (res) => { + const { statusCode } = res; const contentType = res.headers['content-type']; if (statusCode !== 200) { @@ -92,59 +96,67 @@ function downloadZip(url: string, onSuccess: (data: Buffer) => void) { } function extractFile(zip: jszip, folder: string, file: string, artifactName: string) { - zip.file(`${artifactName}/${file}`)!.nodeStream() - .pipe(fs.createWriteStream(path.join(folder, file))) - .on('finish', () => { - console.log('# file downloaded and extracted: ' + file); - }); + zip + .file(`${artifactName}/${file}`)! + .nodeStream() + .pipe(fs.createWriteStream(path.join(folder, file))) + .on('finish', () => { + console.log('# file downloaded and extracted: ' + file); + }); } -console.log(`=== Start to pull ${config} WebAssembly artifacts from CI for ${ - buildId ? `build "${buildId}"` : 'latest "main" branch'} ===`); - -const filter = buildId ? `&buildIds=${buildId}` : - '&definitions=161' + - '&resultFilter=succeeded%2CpartiallySucceeded' + - '&$top=1' + - '&repositoryId=Microsoft/onnxruntime' + - '&repositoryType=GitHub' + - '&branchName=refs/heads/main'; +console.log( + `=== Start to pull ${config} WebAssembly artifacts from CI for ${ + buildId ? `build "${buildId}"` : 'latest "main" branch' + } ===`, +); + +const filter = buildId + ? 
`&buildIds=${buildId}` + : '&definitions=161' + + '&resultFilter=succeeded%2CpartiallySucceeded' + + '&$top=1' + + '&repositoryId=Microsoft/onnxruntime' + + '&repositoryType=GitHub' + + '&branchName=refs/heads/main'; // API reference: https://docs.microsoft.com/en-us/rest/api/azure/devops/build/builds/list downloadJson( - `https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/builds?api-version=6.1-preview.6${filter}`, data => { - const buildId = data.value[0].id; - - console.log(`=== Found latest build on main branch: ${buildId} ===`); - - // API reference: https://docs.microsoft.com/en-us/rest/api/azure/devops/build/artifacts/get%20artifact - downloadJson( - `https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/builds/${ - buildId}/artifacts?api-version=6.1-preview.5`, - data => { - let zipLink; - for (const v of data.value) { - if (v.name === folderName) { - zipLink = v.resource.downloadUrl; - } - } - - console.log('=== Ready to download zip files ==='); - - const WASM_FOLDER = path.join(__dirname, '../dist'); - if (!fs.existsSync(WASM_FOLDER)) { - fs.mkdirSync(WASM_FOLDER); - } - downloadZip(zipLink, buffer => { - void jszip.loadAsync(buffer).then(zip => { - extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName); - extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName); - - extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName); - extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName); - }); - }); + `https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/builds?api-version=6.1-preview.6${filter}`, + (data) => { + const buildId = data.value[0].id; + + console.log(`=== Found latest build on main branch: ${buildId} ===`); + + // API reference: https://docs.microsoft.com/en-us/rest/api/azure/devops/build/artifacts/get%20artifact + downloadJson( + `https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/builds/${buildId}/artifacts?api-version=6.1-preview.5`, + (data) => { + let zipLink; + for (const v of data.value) { + if (v.name === folderName) { + zipLink = v.resource.downloadUrl; + } + } + + console.log('=== Ready to download zip files ==='); + + const WASM_FOLDER = path.join(__dirname, '../dist'); + if (!fs.existsSync(WASM_FOLDER)) { + fs.mkdirSync(WASM_FOLDER); + } + downloadZip(zipLink, (buffer) => { + void jszip.loadAsync(buffer).then((zip) => { + extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName); + extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName); + extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName); + + extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName); + extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName); + extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName); }); - }); + }); + }, + ); + }, +); diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index adcd940178e07..506b6e54e2102 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -3,10 +3,10 @@ import minimist from 'minimist'; import npmlog from 'npmlog'; -import {Env, InferenceSession} from 'onnxruntime-common'; +import { Env, InferenceSession } from 'onnxruntime-common'; -import {Logger} from '../lib/onnxjs/instrument'; -import 
{Test} from '../test/test-types'; +import { Logger } from '../lib/onnxjs/instrument'; +import { Test } from '../test/test-types'; /* eslint-disable max-len */ const HELP_MESSAGE = ` @@ -129,11 +129,11 @@ Examples: /* eslint-enable max-len */ export declare namespace TestRunnerCliArgs { - type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'webnn'; - type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; - type BundleMode = 'dev'|'perf'; - type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; + type Mode = 'suite0' | 'suite1' | 'model' | 'unittest' | 'op'; + type Backend = 'cpu' | 'webgl' | 'webgpu' | 'wasm' | 'onnxruntime' | 'webnn'; + type Environment = 'chrome' | 'edge' | 'firefox' | 'electron' | 'safari' | 'node' | 'bs'; + type BundleMode = 'dev' | 'perf'; + type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location'; } export interface TestRunnerCliArgs { @@ -187,7 +187,7 @@ export interface TestRunnerCliArgs { /** * Specify graph optimization level */ - graphOptimizationLevel: 'disabled'|'basic'|'extended'|'all'; + graphOptimizationLevel: 'disabled' | 'basic' | 'extended' | 'all'; cpuOptions?: InferenceSession.CpuExecutionProviderOption; cudaOptions?: InferenceSession.CudaExecutionProviderOption; @@ -200,10 +200,9 @@ export interface TestRunnerCliArgs { chromiumFlags: string[]; } - function parseBooleanArg(arg: unknown, defaultValue: boolean): boolean; -function parseBooleanArg(arg: unknown): boolean|undefined; -function parseBooleanArg(arg: unknown, defaultValue?: boolean): boolean|undefined { +function parseBooleanArg(arg: unknown): boolean | undefined; +function parseBooleanArg(arg: unknown, defaultValue?: boolean): boolean | undefined { if (typeof arg === 'undefined') { return defaultValue; } @@ -229,7 +228,7 @@ function parseBooleanArg(arg: unknown, defaultValue?: boolean): boolean|undefine } function parseLogLevel(arg: T) { - let v: string[]|boolean; + let v: string[] | boolean; if (typeof arg === 'string') { v = arg.split(','); } else if (Array.isArray(arg)) { @@ -244,61 +243,61 @@ function parseLogLevel(arg: T) { } function parseLogConfig(args: minimist.ParsedArgs) { - const config: Array<{category: string; config: Logger.Config}> = []; + const config: Array<{ category: string; config: Logger.Config }> = []; const verbose = parseLogLevel(args['log-verbose']); const info = parseLogLevel(args['log-info']); const warning = parseLogLevel(args['log-warning']); const error = parseLogLevel(args['log-error']); if (typeof error === 'boolean' && error) { - config.push({category: '*', config: {minimalSeverity: 'error'}}); + config.push({ category: '*', config: { minimalSeverity: 'error' } }); } else if (typeof warning === 'boolean' && warning) { - config.push({category: '*', config: {minimalSeverity: 'warning'}}); + config.push({ category: '*', config: { minimalSeverity: 'warning' } }); } else if (typeof info === 'boolean' && info) { - config.push({category: '*', config: {minimalSeverity: 'info'}}); + config.push({ category: '*', config: { minimalSeverity: 'info' } }); } else if (typeof verbose === 'boolean' && verbose) { - config.push({category: '*', config: {minimalSeverity: 'verbose'}}); + config.push({ category: '*', config: { minimalSeverity: 'verbose' } }); } if (Array.isArray(error)) { - config.push(...error.map(i => ({category: i, config: {minimalSeverity: 'error' as Logger.Severity}}))); + config.push(...error.map((i) => ({ category: i, config: { minimalSeverity: 'error' as Logger.Severity 
} }))); } if (Array.isArray(warning)) { - config.push(...warning.map(i => ({category: i, config: {minimalSeverity: 'warning' as Logger.Severity}}))); + config.push(...warning.map((i) => ({ category: i, config: { minimalSeverity: 'warning' as Logger.Severity } }))); } if (Array.isArray(info)) { - config.push(...info.map(i => ({category: i, config: {minimalSeverity: 'info' as Logger.Severity}}))); + config.push(...info.map((i) => ({ category: i, config: { minimalSeverity: 'info' as Logger.Severity } }))); } if (Array.isArray(verbose)) { - config.push(...verbose.map(i => ({category: i, config: {minimalSeverity: 'verbose' as Logger.Severity}}))); + config.push(...verbose.map((i) => ({ category: i, config: { minimalSeverity: 'verbose' as Logger.Severity } }))); } return config; } function parseCpuOptions(_args: minimist.ParsedArgs): InferenceSession.CpuExecutionProviderOption { - return {name: 'cpu'}; + return { name: 'cpu' }; } function parseWasmOptions(_args: minimist.ParsedArgs): InferenceSession.WebAssemblyExecutionProviderOption { - return {name: 'wasm'}; + return { name: 'wasm' }; } function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags { const wasm = args.wasm || {}; - const numThreads = wasm.numThreads = wasm.numThreads ?? (args.x ?? args['wasm-number-threads']); + const numThreads = (wasm.numThreads = wasm.numThreads ?? args.x ?? args['wasm-number-threads']); if (typeof numThreads !== 'undefined' && typeof numThreads !== 'number') { throw new Error('Flag "wasm.numThreads"/"x"/"wasm-number-threads" must be a number value'); } - const initTimeout = wasm.initTimeout = wasm.initTimeout ?? args['wasm-init-timeout']; + const initTimeout = (wasm.initTimeout = wasm.initTimeout ?? args['wasm-init-timeout']); if (typeof initTimeout !== 'undefined' && typeof initTimeout !== 'number') { throw new Error('Flag "wasm.initTimeout"/"wasm-init-timeout" must be a number value'); } - const simd = wasm.simd = parseBooleanArg(wasm.simd ?? args['wasm-enable-simd']); + const simd = (wasm.simd = parseBooleanArg(wasm.simd ?? args['wasm-enable-simd'])); if (typeof simd !== 'undefined' && typeof simd !== 'boolean') { throw new Error('Flag "wasm.simd"/"wasm-enable-simd" must be a boolean value'); } - const proxy = wasm.proxy = parseBooleanArg(wasm.proxy ?? args['wasm-enable-proxy']); + const proxy = (wasm.proxy = parseBooleanArg(wasm.proxy ?? args['wasm-enable-proxy'])); if (typeof proxy !== 'undefined' && typeof proxy !== 'boolean') { throw new Error('Flag "wasm.proxy"/"wasm-enable-proxy" must be a boolean value'); } @@ -306,28 +305,29 @@ function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags { } function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLExecutionProviderOption { - return {name: 'webgl'}; + return { name: 'webgl' }; } function parseWebglFlags(args: minimist.ParsedArgs): Partial { const webgl = args.webgl || {}; - const contextId = webgl.contextId = webgl.contextId ?? args['webgl-context-id']; + const contextId = (webgl.contextId = webgl.contextId ?? args['webgl-context-id']); if (contextId !== undefined && contextId !== 'webgl' && contextId !== 'webgl2') { throw new Error('Flag "webgl.contextId"/"webgl-context-id" is invalid'); } - const matmulMaxBatchSize = webgl.matmulMaxBatchSize = webgl.matmulMaxBatchSize ?? args['webgl-matmul-max-batch-size']; + const matmulMaxBatchSize = (webgl.matmulMaxBatchSize = + webgl.matmulMaxBatchSize ?? 
args['webgl-matmul-max-batch-size']); if (matmulMaxBatchSize !== undefined && typeof matmulMaxBatchSize !== 'number') { throw new Error('Flag "webgl.matmulMaxBatchSize"/"webgl-matmul-max-batch-size" must be a number value'); } - const textureCacheMode = webgl.textureCacheMode = webgl.textureCacheMode ?? args['webgl-texture-cache-mode']; + const textureCacheMode = (webgl.textureCacheMode = webgl.textureCacheMode ?? args['webgl-texture-cache-mode']); if (textureCacheMode !== undefined && textureCacheMode !== 'initializerOnly' && textureCacheMode !== 'full') { throw new Error('Flag "webgl.textureCacheMode"/"webgl-texture-cache-mode" is invalid'); } - const pack = webgl.pack = parseBooleanArg(webgl.pack ?? args['webgl-texture-pack-mode']); + const pack = (webgl.pack = parseBooleanArg(webgl.pack ?? args['webgl-texture-pack-mode'])); if (pack !== undefined && typeof pack !== 'boolean') { throw new Error('Flag "webgl.pack"/"webgl-texture-pack-mode" is invalid'); } - const async = webgl.async = parseBooleanArg(webgl.async ?? args['webgl-async']); + const async = (webgl.async = parseBooleanArg(webgl.async ?? args['webgl-async'])); if (async !== undefined && typeof async !== 'boolean') { throw new Error('Flag "webgl.async"/"webgl-async" is invalid'); } @@ -336,13 +336,14 @@ function parseWebglFlags(args: minimist.ParsedArgs): Partial { function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { const webgpu = args.webgpu || {}; - const profilingMode = (webgpu.profiling = webgpu.profiling ?? {}).mode = - webgpu?.profiling?.mode ?? webgpu.profilingMode ?? args['webgpu-profiling-mode']; + const profilingMode = ((webgpu.profiling = webgpu.profiling ?? {}).mode = + webgpu?.profiling?.mode ?? webgpu.profilingMode ?? args['webgpu-profiling-mode']); if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - const validateInputContent = webgpu.validateInputContent = - parseBooleanArg(webgpu.validateInputContent ?? args['webgpu-validate-input-content']); + const validateInputContent = (webgpu.validateInputContent = parseBooleanArg( + webgpu.validateInputContent ?? args['webgpu-validate-input-content'], + )); if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { throw new Error('Flag "webgpu-validate-input-content" is invalid'); } @@ -354,14 +355,14 @@ function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExe if (deviceType !== undefined && !['cpu', 'gpu', 'npu'].includes(deviceType)) { throw new Error('Flag "webnn-device-type" is invalid'); } - return {name: 'webnn', deviceType}; + return { name: 'webnn', deviceType }; } function parseGlobalEnvFlags(args: minimist.ParsedArgs) { const wasm = parseWasmFlags(args); const webgl = parseWebglFlags(args); const webgpu = parseWebgpuFlags(args); - return {webgl, wasm, webgpu}; + return { webgl, wasm, webgpu }; } export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs { @@ -383,7 +384,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // Option: -e=<...>, --env=<...> const envArg = args.env || args.e; - const env = (typeof envArg !== 'string') ? 'chrome' : envArg; + const env = typeof envArg !== 'string' ? 
'chrome' : envArg; if (['chrome', 'edge', 'firefox', 'electron', 'safari', 'node', 'bs'].indexOf(env) === -1) { throw new Error(`not supported env ${env}`); } @@ -398,8 +399,12 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const defaultBrowserBackends = ['webgl', 'webgpu', 'wasm' /*, 'webnn'*/]; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; - const backend = (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : defaultBrowserBackends) : - backendArgs.split(','); + const backend = + typeof backendArgs !== 'string' + ? env === 'node' + ? nodejsBackends + : defaultBrowserBackends + : backendArgs.split(','); for (const b of backend) { if ((env !== 'node' && browserBackends.indexOf(b) === -1) || (env === 'node' && nodejsBackends.indexOf(b) === -1)) { throw new Error(`backend ${b} is not supported in env ${env}`); @@ -415,12 +420,12 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs let logLevel = logConfig[0]?.config.minimalSeverity; // Option: -p, --profile - const profile = (args.profile || args.p) ? true : false; + const profile = args.profile || args.p ? true : false; if (profile) { - logConfig.push({category: 'Profiler.session', config: {minimalSeverity: 'verbose'}}); - logConfig.push({category: 'Profiler.node', config: {minimalSeverity: 'verbose'}}); - logConfig.push({category: 'Profiler.op', config: {minimalSeverity: 'verbose'}}); - logConfig.push({category: 'Profiler.backend', config: {minimalSeverity: 'verbose'}}); + logConfig.push({ category: 'Profiler.session', config: { minimalSeverity: 'verbose' } }); + logConfig.push({ category: 'Profiler.node', config: { minimalSeverity: 'verbose' } }); + logConfig.push({ category: 'Profiler.op', config: { minimalSeverity: 'verbose' } }); + logConfig.push({ category: 'Profiler.backend', config: { minimalSeverity: 'verbose' } }); logLevel = 'verbose'; } @@ -431,25 +436,25 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // --wasm.<...>=<...> // --webgl.<...>=<...> // --webgpu.<...>=<...> - const globalEnvFlags = {...parseGlobalEnvFlags(args), debug, trace, logLevel}; + const globalEnvFlags = { ...parseGlobalEnvFlags(args), debug, trace, logLevel }; // Option: -P[=<...>], --perf[=<...>] - const perfArg = (args.perf || args.P); + const perfArg = args.perf || args.P; const perf = perfArg ? true : false; - const times = (typeof perfArg === 'number') ? perfArg : 10; + const times = typeof perfArg === 'number' ? perfArg : 10; if (debug && perf) { throw new Error('Flag "perf" cannot be used together with flag "debug".'); } - if (perf && (mode !== 'model')) { + if (perf && mode !== 'model') { throw new Error('Flag "perf" can only be used in mode "model".'); } if (perf) { - logConfig.push({category: 'TestRunner.Perf', config: {minimalSeverity: 'verbose'}}); + logConfig.push({ category: 'TestRunner.Perf', config: { minimalSeverity: 'verbose' } }); } // Option: -i=<...>, --io-binding=<...> const ioBindingArg = args['io-binding'] || args.i; - const ioBindingMode = (typeof ioBindingArg !== 'string') ? 'none' : ioBindingArg; + const ioBindingMode = typeof ioBindingArg !== 'string' ? 
'none' : ioBindingArg; if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { throw new Error(`not supported io binding mode ${ioBindingMode}`); } @@ -462,8 +467,10 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // Option: -o, --graph-optimization-level const graphOptimizationLevel = args['graph-optimization-level'] || args.o || 'all'; - if (typeof graphOptimizationLevel !== 'string' || - ['disabled', 'basic', 'extended', 'all'].indexOf(graphOptimizationLevel) === -1) { + if ( + typeof graphOptimizationLevel !== 'string' || + ['disabled', 'basic', 'extended', 'all'].indexOf(graphOptimizationLevel) === -1 + ) { throw new Error(`graph optimization level is invalid: ${graphOptimizationLevel}`); } @@ -492,7 +499,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs throw new Error(`Invalid command line arg: --chromium-flags: ${chromiumFlags}`); } - npmlog.verbose('TestRunnerCli.Init', ` Mode: ${mode}`); npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`); npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`); @@ -521,6 +527,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs globalEnvFlags, noSandbox, userDataDir, - chromiumFlags + chromiumFlags, }; } diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index fbde81524ccec..15df62b30e6c4 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -4,23 +4,23 @@ /* eslint-disable guard-for-in */ /* eslint-disable @typescript-eslint/no-use-before-define */ -import {spawnSync} from 'child_process'; +import { spawnSync } from 'child_process'; import * as fs from 'fs-extra'; -import {default as minimatch} from 'minimatch'; +import { default as minimatch } from 'minimatch'; import npmlog from 'npmlog'; import * as os from 'os'; import * as path from 'path'; -import {inspect} from 'util'; +import { inspect } from 'util'; -import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; -import {bufferToBase64} from '../test/test-shared'; -import {Test} from '../test/test-types'; +import { onnx } from '../lib/onnxjs/ort-schema/protobuf/onnx'; +import { bufferToBase64 } from '../test/test-shared'; +import { Test } from '../test/test-types'; -import {parseTestRunnerCliArgs, TestRunnerCliArgs} from './test-runner-cli-args'; +import { parseTestRunnerCliArgs, TestRunnerCliArgs } from './test-runner-cli-args'; async function main() { // use dynamic import so that we can use ESM only libraries in commonJS. - const {globbySync} = await import('globby'); + const { globbySync } = await import('globby'); const stripJsonComments = (await import('strip-json-comments')).default; npmlog.info('TestRunnerCli', 'Initializing...'); @@ -41,29 +41,30 @@ async function main() { npmlog.verbose('TestRunnerCli.Init', 'Ensure test data folder... DONE'); let testlist: Test.TestList; - const shouldLoadSuiteTestData = (args.mode === 'suite0' || args.mode === 'suite1'); + const shouldLoadSuiteTestData = args.mode === 'suite0' || args.mode === 'suite1'; if (shouldLoadSuiteTestData) { npmlog.verbose('TestRunnerCli.Init', 'Loading testlist...'); // The following is a list of unittests for already implemented operators. // Modify this list to control what node tests to run. 
const jsonWithComments = fs.readFileSync(path.resolve(TEST_ROOT, './suite-test-list.jsonc')).toString(); - const json = stripJsonComments(jsonWithComments, {whitespace: true}); + const json = stripJsonComments(jsonWithComments, { whitespace: true }); testlist = JSON.parse(json) as Test.TestList; npmlog.verbose('TestRunnerCli.Init', 'Loading testlist... DONE'); } // The default backends and opset version lists. Those will be used in suite tests. const DEFAULT_BACKENDS: readonly TestRunnerCliArgs.Backend[] = - args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu', 'webnn']; - const DEFAULT_OPSET_VERSIONS = fs.readdirSync(TEST_DATA_MODEL_NODE_ROOT, {withFileTypes: true}) - .filter(dir => dir.isDirectory() && dir.name.startsWith('opset')) - .map(dir => dir.name.slice(5)); - const MAX_OPSET_VERSION = Math.max(...DEFAULT_OPSET_VERSIONS.map(v => Number.parseInt(v, 10))); - - const FILE_CACHE_ENABLED = args.fileCache; // whether to enable file cache - const FILE_CACHE_MAX_FILE_SIZE = 1 * 1024 * 1024; // The max size of the file that will be put into file cache - const FILE_CACHE_SPLIT_SIZE = 4 * 1024 * 1024; // The min size of the cache file + args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu', 'webnn']; + const DEFAULT_OPSET_VERSIONS = fs + .readdirSync(TEST_DATA_MODEL_NODE_ROOT, { withFileTypes: true }) + .filter((dir) => dir.isDirectory() && dir.name.startsWith('opset')) + .map((dir) => dir.name.slice(5)); + const MAX_OPSET_VERSION = Math.max(...DEFAULT_OPSET_VERSIONS.map((v) => Number.parseInt(v, 10))); + + const FILE_CACHE_ENABLED = args.fileCache; // whether to enable file cache + const FILE_CACHE_MAX_FILE_SIZE = 1 * 1024 * 1024; // The max size of the file that will be put into file cache + const FILE_CACHE_SPLIT_SIZE = 4 * 1024 * 1024; // The min size of the cache file const fileCache: Test.FileCache = {}; const nodeTests = new Map(); @@ -74,16 +75,13 @@ async function main() { npmlog.verbose('TestRunnerCli.Init', 'Loading test groups for suite test...'); // collect all model test folders - const allNodeTestsFolders = - DEFAULT_OPSET_VERSIONS - .map(version => { - const suiteRootFolder = path.join(TEST_DATA_MODEL_NODE_ROOT, `opset${version}`); - if (!fs.existsSync(suiteRootFolder) || !fs.statSync(suiteRootFolder).isDirectory()) { - throw new Error(`model test root folder '${suiteRootFolder}' does not exist.`); - } - return fs.readdirSync(suiteRootFolder).map(f => `opset${version}/${f}`); - }) - .flat(); + const allNodeTestsFolders = DEFAULT_OPSET_VERSIONS.map((version) => { + const suiteRootFolder = path.join(TEST_DATA_MODEL_NODE_ROOT, `opset${version}`); + if (!fs.existsSync(suiteRootFolder) || !fs.statSync(suiteRootFolder).isDirectory()) { + throw new Error(`model test root folder '${suiteRootFolder}' does not exist.`); + } + return fs.readdirSync(suiteRootFolder).map((f) => `opset${version}/${f}`); + }).flat(); for (const backend of DEFAULT_BACKENDS) { if (args.backends.indexOf(backend) !== -1) { @@ -111,8 +109,8 @@ async function main() { case 'suite1': for (const backend of DEFAULT_BACKENDS) { if (args.backends.indexOf(backend) !== -1) { - modelTestGroups.push(...nodeTests.get(backend)!); // model test : node - opTestGroups.push(...opTests.get(backend)!); // operator test + modelTestGroups.push(...nodeTests.get(backend)!); // model test : node + opTestGroups.push(...opTests.get(backend)!); // operator test } } if (args.mode === 'suite0') { @@ -122,12 +120,15 @@ async function main() { case 'model': if (!args.param) { - throw new Error('the test folder 
should be specified in mode \'node\''); + throw new Error("the test folder should be specified in mode 'node'"); } else { const testFolderSearchPattern = args.param; const testFolder = tryLocateModelTestFolder(testFolderSearchPattern); for (const b of args.backends) { - modelTestGroups.push({name: testFolder, tests: [modelTestFromFolder(testFolder, b, undefined, args.times)]}); + modelTestGroups.push({ + name: testFolder, + tests: [modelTestFromFolder(testFolder, b, undefined, args.times)], + }); } } break; @@ -138,7 +139,7 @@ async function main() { case 'op': if (!args.param) { - throw new Error('the test manifest should be specified in mode \'op\''); + throw new Error("the test manifest should be specified in mode 'op'"); } else { const manifestFileSearchPattern = args.param; const manifestFile = tryLocateOpTestManifest(manifestFileSearchPattern); @@ -161,15 +162,17 @@ async function main() { log: args.logConfig, profile: args.profile, options: { - sessionOptions: - {graphOptimizationLevel: args.graphOptimizationLevel, optimizedModelFilePath: args.optimizedModelFilePath}, + sessionOptions: { + graphOptimizationLevel: args.graphOptimizationLevel, + optimizedModelFilePath: args.optimizedModelFilePath, + }, debug: args.debug, cpuOptions: args.cpuOptions, webglOptions: args.webglOptions, webnnOptions: args.webnnOptions, wasmOptions: args.wasmOptions, - globalEnvFlags: args.globalEnvFlags - } + globalEnvFlags: args.globalEnvFlags, + }, }); npmlog.info('TestRunnerCli', 'Tests completed successfully'); @@ -181,11 +184,12 @@ async function main() { const testCaseName = typeof testCase === 'string' ? testCase : testCase.name; let found = false; for (const testGroup of nodeTest) { - found ||= minimatch - .match( - testGroup.tests.map(test => test.modelUrl).filter(url => url !== ''), - path.join('**', testCaseName, '*.+(onnx|ort)').replace(/\\/g, '/'), {matchBase: true}) - .length > 0; + found ||= + minimatch.match( + testGroup.tests.map((test) => test.modelUrl).filter((url) => url !== ''), + path.join('**', testCaseName, '*.+(onnx|ort)').replace(/\\/g, '/'), + { matchBase: true }, + ).length > 0; } if (!found) { throw new Error(`node model test case '${testCaseName}' in test list does not exist.`); @@ -195,7 +199,7 @@ async function main() { const onnxTest = onnxTests.get(backend); if (onnxTest) { - const onnxModelTests = onnxTest.tests.map(i => i.name); + const onnxModelTests = onnxTest.tests.map((i) => i.name); for (const testCase of testlist[backend].onnx) { const testCaseName = typeof testCase === 'string' ? testCase : testCase.name; if (onnxModelTests.indexOf(testCaseName) === -1) { @@ -206,7 +210,7 @@ async function main() { const opTest = opTests.get(backend); if (opTest) { - const opTests = opTest.map(i => i.name); + const opTests = opTest.map((i) => i.name); for (const testCase of testlist[backend].ops) { const testCaseName = typeof testCase === 'string' ? testCase : testCase.name; if (opTests.indexOf(testCaseName) === -1) { @@ -221,14 +225,14 @@ async function main() { const allTests = testlist[backend]?.node; // key is folder name, value is test index array - const folderTestMatchCount = new Map(allFolders.map(f => [f, []])); + const folderTestMatchCount = new Map(allFolders.map((f) => [f, []])); // key is test category, value is a list of model test const opsetTests = new Map(); allTests.forEach((test, i) => { const testName = typeof test === 'string' ? 
test : test.name; const matches = minimatch.match(allFolders, path.join('**', testName).replace(/\\/g, '/')); - matches.forEach(m => folderTestMatchCount.get(m)!.push(i)); + matches.forEach((m) => folderTestMatchCount.get(m)!.push(i)); }); for (const folder of allFolders) { @@ -249,23 +253,33 @@ async function main() { opsetTests.set(category, modelTests); } modelTests.push( - modelTestFromFolder(path.resolve(TEST_DATA_MODEL_NODE_ROOT, folder), backend, platformCondition, times)); + modelTestFromFolder(path.resolve(TEST_DATA_MODEL_NODE_ROOT, folder), backend, platformCondition, times), + ); } - return Array.from(opsetTests.keys()).map(category => ({name: category, tests: opsetTests.get(category)!})); + return Array.from(opsetTests.keys()).map((category) => ({ name: category, tests: opsetTests.get(category)! })); } function modelTestFromFolder( - testDataRootFolder: string, backend: string, platformCondition?: Test.PlatformCondition, - times?: number): Test.ModelTest { + testDataRootFolder: string, + backend: string, + platformCondition?: Test.PlatformCondition, + times?: number, + ): Test.ModelTest { if (times === 0) { npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`); - return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: [], ioBinding: args.ioBindingMode}; + return { + name: path.basename(testDataRootFolder), + backend, + modelUrl: '', + cases: [], + ioBinding: args.ioBindingMode, + }; } - let modelUrl: string|null = null; + let modelUrl: string | null = null; let cases: Test.ModelTestCase[] = []; - let externalData: Array<{data: string; path: string}>|undefined; + let externalData: Array<{ data: string; path: string }> | undefined; npmlog.verbose('TestRunnerCli.Init.Model', `Start to prepare test data from folder: ${testDataRootFolder}`); @@ -297,14 +311,17 @@ async function main() { if (ext.toLowerCase() === '.pb') { const dataFileUrl = path.join(TEST_DATA_BASE, path.relative(TEST_ROOT, dataFileFullPath)); dataFiles.push(dataFileUrl); - if (FILE_CACHE_ENABLED && !fileCache[dataFileUrl] && - fs.lstatSync(dataFileFullPath).size <= FILE_CACHE_MAX_FILE_SIZE) { + if ( + FILE_CACHE_ENABLED && + !fileCache[dataFileUrl] && + fs.lstatSync(dataFileFullPath).size <= FILE_CACHE_MAX_FILE_SIZE + ) { fileCache[dataFileUrl] = bufferToBase64(fs.readFileSync(dataFileFullPath)); } } } if (dataFiles.length > 0) { - cases.push({dataFiles, name: thisPath}); + cases.push({ dataFiles, name: thisPath }); } } } @@ -318,8 +335,9 @@ async function main() { // (e.g., model file is "model_abc.onnx", and there is a file "model_abc.pb" or "model_abc.onnx.data") // 2. 
the file size is larger than 1GB const likelyToHaveExternalData = maybeExternalDataFiles.some( - ([fileNameWithoutExtension, size]) => - path.basename(modelUrl!).startsWith(fileNameWithoutExtension) || size >= 1 * 1024 * 1024 * 1024); + ([fileNameWithoutExtension, size]) => + path.basename(modelUrl!).startsWith(fileNameWithoutExtension) || size >= 1 * 1024 * 1024 * 1024, + ); if (likelyToHaveExternalData) { const model = onnx.ModelProto.decode(fs.readFileSync(path.join(testDataRootFolder, path.basename(modelUrl!)))); const externalDataPathSet = new Set(); @@ -337,7 +355,7 @@ async function main() { for (const dataPath of externalDataPaths) { const fullPath = path.resolve(testDataRootFolder, dataPath); const url = path.join(TEST_DATA_BASE, path.relative(TEST_ROOT, fullPath)); - externalData.push({data: url, path: dataPath}); + externalData.push({ data: url, path: dataPath }); } } } catch (e) { @@ -350,7 +368,10 @@ async function main() { if (times > caseCount) { for (let i = 0; cases.length < times; i++) { const origin = cases[i % caseCount]; - const duplicated = {name: `${origin.name} - copy ${Math.floor(i / caseCount)}`, dataFiles: origin.dataFiles}; + const duplicated = { + name: `${origin.name} - copy ${Math.floor(i / caseCount)}`, + dataFiles: origin.dataFiles, + }; cases.push(duplicated); } } else { @@ -361,13 +382,14 @@ async function main() { let ioBinding: Test.IOBindingMode; if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { npmlog.warn( - 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + 'TestRunnerCli.Init.Model', + `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`, + ); ioBinding = 'none'; } else { ioBinding = args.ioBindingMode; } - npmlog.verbose('TestRunnerCli.Init.Model', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); @@ -388,7 +410,7 @@ async function main() { backend, cases, ioBinding, - externalData + externalData, }; } @@ -401,17 +423,22 @@ async function main() { // 2 - check the globby result of searchPattern // 3 - check the globby result of ONNX root combined with searchPattern - const globbyPattern = - [searchPattern, path.join(TEST_DATA_MODEL_NODE_ROOT, '**', searchPattern).replace(/\\/g, '/')]; + const globbyPattern = [ + searchPattern, + path.join(TEST_DATA_MODEL_NODE_ROOT, '**', searchPattern).replace(/\\/g, '/'), + ]; // 4 - check the globby result of NODE root combined with opset versions and searchPattern - globbyPattern.push(...DEFAULT_OPSET_VERSIONS.map( - v => path.join(TEST_DATA_MODEL_NODE_ROOT, `opset${v}`, '**', searchPattern).replace(/\\/g, '/'))); + globbyPattern.push( + ...DEFAULT_OPSET_VERSIONS.map((v) => + path.join(TEST_DATA_MODEL_NODE_ROOT, `opset${v}`, '**', searchPattern).replace(/\\/g, '/'), + ), + ); - folderCandidates.push(...globbySync(globbyPattern, {onlyDirectories: true, absolute: true})); + folderCandidates.push(...globbySync(globbyPattern, { onlyDirectories: true, absolute: true })); // pick the first folder that matches the pattern for (const folderCandidate of folderCandidates) { - const modelCandidates = globbySync('*.{onnx,ort}', {onlyFiles: true, cwd: folderCandidate}); + const modelCandidates = globbySync('*.{onnx,ort}', { onlyFiles: true, cwd: folderCandidate }); if (modelCandidates && modelCandidates.length === 1) { return folderCandidate; } @@ -443,15 +470,17 @@ 
async function main() { } else { npmlog.verbose('TestRunnerCli.Init.Op', `Start to prepare test data from manifest file: ${filePath}`); const jsonWithComments = fs.readFileSync(filePath).toString(); - const json = stripJsonComments(jsonWithComments, {whitespace: true}); + const json = stripJsonComments(jsonWithComments, { whitespace: true }); tests = JSON.parse(json) as Test.OperatorTest[]; // field 'verbose' and 'backend' is not set for (const test of tests) { test.backend = backend; - test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION}; + test.opset = test.opset || { domain: '', version: MAX_OPSET_VERSION }; if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { npmlog.warn( - 'TestRunnerCli.Init.Op', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + 'TestRunnerCli.Init.Op', + `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`, + ); test.ioBinding = 'none'; } else { test.ioBinding = args.ioBindingMode; @@ -464,17 +493,19 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Op', ` Test case(s): ${tests.length}`); npmlog.verbose('TestRunnerCli.Init.Op', '==============================================================='); } - return {name: path.relative(TEST_DATA_OP_ROOT, filePath), tests}; + return { name: path.relative(TEST_DATA_OP_ROOT, filePath), tests }; } function tryLocateOpTestManifest(searchPattern: string): string { for (const manifestCandidate of globbySync( - [ - searchPattern, path.join(TEST_DATA_OP_ROOT, '**', searchPattern).replace(/\\/g, '/'), - path.join(TEST_DATA_OP_ROOT, '**', searchPattern + '.json').replace(/\\/g, '/'), - path.join(TEST_DATA_OP_ROOT, '**', searchPattern + '.jsonc').replace(/\\/g, '/') - ], - {onlyFiles: true, absolute: true, cwd: TEST_ROOT})) { + [ + searchPattern, + path.join(TEST_DATA_OP_ROOT, '**', searchPattern).replace(/\\/g, '/'), + path.join(TEST_DATA_OP_ROOT, '**', searchPattern + '.json').replace(/\\/g, '/'), + path.join(TEST_DATA_OP_ROOT, '**', searchPattern + '.jsonc').replace(/\\/g, '/'), + ], + { onlyFiles: true, absolute: true, cwd: TEST_ROOT }, + )) { return manifestCandidate; } @@ -489,9 +520,11 @@ async function main() { config.fileCacheUrls = fileCacheUrls; } npmlog.info( - 'TestRunnerCli.Run', - `(1/4) Writing file cache to file: testdata-file-cache-*.json ... ${ - fileCacheUrls.length > 0 ? `DONE, ${fileCacheUrls.length} file(s) generated` : 'SKIPPED'}`); + 'TestRunnerCli.Run', + `(1/4) Writing file cache to file: testdata-file-cache-*.json ... ${ + fileCacheUrls.length > 0 ? `DONE, ${fileCacheUrls.length} file(s) generated` : 'SKIPPED' + }`, + ); // STEP 2. write the config to testdata-config.json npmlog.info('TestRunnerCli.Run', '(2/4) Writing config to file: testdata-config.json ...'); @@ -503,7 +536,7 @@ async function main() { const buildCommand = `node ${path.join(__dirname, 'build')}`; const buildArgs = [`--bundle-mode=${args.env === 'node' ? 'node' : args.bundleMode}`]; npmlog.info('TestRunnerCli.Run', `CMD: ${buildCommand} ${buildArgs.join(' ')}`); - const build = spawnSync(buildCommand, buildArgs, {shell: true, stdio: 'inherit'}); + const build = spawnSync(buildCommand, buildArgs, { shell: true, stdio: 'inherit' }); if (build.status !== 0) { console.error(build.error); process.exit(build.status === null ? undefined : build.status); @@ -513,7 +546,7 @@ async function main() { if (args.env === 'node') { // STEP 5. 
run tsc and run mocha npmlog.info('TestRunnerCli.Run', '(4/4) Running tsc...'); - const tsc = spawnSync('npx', ['tsc'], {shell: true, stdio: 'inherit'}); + const tsc = spawnSync('npx', ['tsc'], { shell: true, stdio: 'inherit' }); if (tsc.status !== 0) { console.error(tsc.error); process.exit(tsc.status === null ? undefined : tsc.status); @@ -530,13 +563,12 @@ async function main() { path.join(TEST_ROOT, 'test-main'), ]; npmlog.info('TestRunnerCli.Run', `CMD: npx ${mochaArgs.join(' ')}`); - const mocha = spawnSync('npx', mochaArgs, {shell: true, stdio: 'inherit'}); + const mocha = spawnSync('npx', mochaArgs, { shell: true, stdio: 'inherit' }); if (mocha.status !== 0) { console.error(mocha.error); process.exit(mocha.status === null ? undefined : mocha.status); } npmlog.info('TestRunnerCli.Run', '(4/4) Running mocha... DONE'); - } else { // STEP 5. use Karma to run test npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); @@ -578,7 +610,7 @@ async function main() { if (args.userDataDir) { karmaArgs.push(`--user-data-dir="${args.userDataDir}"`); } - karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); + karmaArgs.push(...chromiumFlags.map((flag) => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { // There are currently 2 Edge browser launchers: // - karma-edge-launcher: used to launch the old Edge browser @@ -593,13 +625,16 @@ async function main() { // - remove "karma-edge-launcher". // check if we have the latest Edge installed: - if (os.platform() === 'darwin' || - (os.platform() === 'win32' && - require('@chiragrupani/karma-chromium-edge-launcher/dist/Utilities').default.GetEdgeExe('Edge') !== '')) { + if ( + os.platform() === 'darwin' || + (os.platform() === 'win32' && + require('@chiragrupani/karma-chromium-edge-launcher/dist/Utilities').default.GetEdgeExe('Edge') !== '') + ) { // use "@chiragrupani/karma-chromium-edge-launcher" karmaArgs.push( - '--karma-plugins=@chiragrupani/karma-chromium-edge-launcher', - '--karma-plugins=(?!karma-edge-launcher$)karma-*'); + '--karma-plugins=@chiragrupani/karma-chromium-edge-launcher', + '--karma-plugins=(?!karma-edge-launcher$)karma-*', + ); } else { // use "karma-edge-launcher" @@ -622,14 +657,14 @@ async function main() { // delete the files stores in the specific folder to clean up the recovery page list. // see also: https://www.laptopmag.com/articles/edge-browser-stop-tab-restore const deleteEdgeActiveRecoveryCommand = - // eslint-disable-next-line max-len - 'del /F /Q % LOCALAPPDATA %\\Packages\\Microsoft.MicrosoftEdge_8wekyb3d8bbwe\\AC\\MicrosoftEdge\\User\\Default\\Recovery\\Active\\*'; + // eslint-disable-next-line max-len + 'del /F /Q % LOCALAPPDATA %\\Packages\\Microsoft.MicrosoftEdge_8wekyb3d8bbwe\\AC\\MicrosoftEdge\\User\\Default\\Recovery\\Active\\*'; npmlog.info('TestRunnerCli.Run', `CMD: ${deleteEdgeActiveRecoveryCommand}`); - spawnSync(deleteEdgeActiveRecoveryCommand, {shell: true, stdio: 'inherit'}); + spawnSync(deleteEdgeActiveRecoveryCommand, { shell: true, stdio: 'inherit' }); } } npmlog.info('TestRunnerCli.Run', `CMD: npx ${karmaArgs.join(' ')}`); - const karma = spawnSync('npx', karmaArgs, {shell: true, stdio: 'inherit'}); + const karma = spawnSync('npx', karmaArgs, { shell: true, stdio: 'inherit' }); if (karma.status !== 0) { console.error(karma.error); process.exit(karma.status === null ? 
undefined : karma.status); diff --git a/js/web/test/e2e/browser-test-wasm-binary-override.js b/js/web/test/e2e/browser-test-wasm-binary-override.js index 35d427fa3b722..471c26f6990b5 100644 --- a/js/web/test/e2e/browser-test-wasm-binary-override.js +++ b/js/web/test/e2e/browser-test-wasm-binary-override.js @@ -5,7 +5,7 @@ const documentUrl = document.currentScript.src; -it('Browser E2E testing - WebAssembly backend', async function() { +it('Browser E2E testing - WebAssembly backend', async function () { // preload .wasm file binary const wasmUrl = new URL('./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.wasm', documentUrl).href; const response = await fetch(wasmUrl); @@ -18,5 +18,5 @@ it('Browser E2E testing - WebAssembly backend', async function() { const binary = await response.arrayBuffer(); ort.env.wasm.wasmBinary = binary; - await testFunction(ort, {executionProviders: ['wasm']}); + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/browser-test-wasm-image-tensor-image.js b/js/web/test/e2e/browser-test-wasm-image-tensor-image.js index f82fa48fad3ff..c34e571c7445e 100644 --- a/js/web/test/e2e/browser-test-wasm-image-tensor-image.js +++ b/js/web/test/e2e/browser-test-wasm-image-tensor-image.js @@ -3,12 +3,14 @@ 'use strict'; -const IMAGE_HEIGHT = 20 -const IMAGE_WIDTH = 15 +const IMAGE_HEIGHT = 20; +const IMAGE_WIDTH = 15; function getRndColor() { - let r = 255 * Math.random() | 0, g = 255 * Math.random() | 0, b = 255 * Math.random() | 0, - a = 255 * Math.random() | 0; + let r = (255 * Math.random()) | 0, + g = (255 * Math.random()) | 0, + b = (255 * Math.random()) | 0, + a = (255 * Math.random()) | 0; return 'rgb(' + r + ',' + g + ',' + b + ',' + a + ')'; } @@ -30,7 +32,7 @@ function compareTensors(tensorA, tensorB, msg) { // - the test is composed by 3 different test cases. split them to 3 different cases. // - some test cases are wriiten incorrectly. // -it('Browser E2E testing - Tensor <--> Image E2E test', async function() { +it('Browser E2E testing - Tensor <--> Image E2E test', async function () { // Creating Image HTML Image Element let img = new Image(); img.crossOrigin = 'Anonymous'; @@ -54,15 +56,16 @@ it('Browser E2E testing - Tensor <--> Image E2E test', async function() { img.src = canvas.toDataURL(); // Testing HTML Image Element --> Tensor --> ImageData --> Tensor - img.onload = - async () => { + img.onload = async () => { // Image HTML element to tensor API - HTML - const inputTensorHTML = await ort.Tensor.fromImage(img, {norm: {bias: [2, 3, 9, 5], mean: [5, 6, 17, 8]}}); + const inputTensorHTML = await ort.Tensor.fromImage(img, { norm: { bias: [2, 3, 9, 5], mean: [5, 6, 17, 8] } }); // Tensor to ImageDAta API - let newImage = inputTensorHTML.toImageData({norm: {bias: [2 / 5, 3 / 6, 9 / 17, 5 / 8], mean: [5, 6, 17, 8]}}); + let newImage = inputTensorHTML.toImageData({ norm: { bias: [2 / 5, 3 / 6, 9 / 17, 5 / 8], mean: [5, 6, 17, 8] } }); // ImageData to tensor API - let inputTensorImageData = - await ort.Tensor.fromImage(newImage, options = {norm: {bias: [2, 3, 9, 5], mean: [5, 6, 17, 8]}}); + let inputTensorImageData = await ort.Tensor.fromImage( + newImage, + (options = { norm: { bias: [2, 3, 9, 5], mean: [5, 6, 17, 8] } }), + ); // TODO: fix this test case // @@ -71,20 +74,24 @@ it('Browser E2E testing - Tensor <--> Image E2E test', async function() { // is not executed. to fix this, wrap a try-catch to deal with exceptions. 
compareTensors(inputTensorHTML, inputTensorImageData, 'BUG in HTML image element & ImageData use case'); - } + }; // Copying the canavas data to the image as Data URL let image = canvas.toDataURL(); // Testing Data URL --> Tensor --> Data URL --> Tensor // Data URL to tensor API - - const inputTensorDataURL = - await ort.Tensor.fromImage(image, {format: 'RBG', norm: {bias: [1, 10, 5, 0], mean: [5, 7, 11, 0]}}); + const inputTensorDataURL = await ort.Tensor.fromImage(image, { + format: 'RBG', + norm: { bias: [1, 10, 5, 0], mean: [5, 7, 11, 0] }, + }); // Tensor to ImageDAta API - let newImage = inputTensorDataURL.toDataURL({norm: {bias: [1 / 5, 10 / 7, 5 / 11, 0], mean: [5, 7, 11, 0]}}); + let newImage = inputTensorDataURL.toDataURL({ norm: { bias: [1 / 5, 10 / 7, 5 / 11, 0], mean: [5, 7, 11, 0] } }); // ImageData to tensor API - let inputTensorImageData = - await ort.Tensor.fromImage(newImage, {format: 'RGBA', norm: {bias: [1, 10, 5, 0], mean: [5, 7, 11, 0]}}); + let inputTensorImageData = await ort.Tensor.fromImage(newImage, { + format: 'RGBA', + norm: { bias: [1, 10, 5, 0], mean: [5, 7, 11, 0] }, + }); // TODO: fix this // creating tensor from image data should not depend on `options.format`. @@ -97,17 +104,22 @@ it('Browser E2E testing - Tensor <--> Image E2E test', async function() { if (online) { // URL element to tensor API const inputTensorURL = await ort.Tensor.fromImage( - 'https://media.istockphoto.com/id/172859087/photo/square-eggs.jpg?s=2048x2048&w=is&k=20&c=KiBRyyYaoUUSjcJLBh1-qqVu7LW6UQZBopZdva0f5e4=', - {norm: {bias: [2, 3, 9, 0], mean: [5, 6, 17, 0]}}); + 'https://media.istockphoto.com/id/172859087/photo/square-eggs.jpg?s=2048x2048&w=is&k=20&c=KiBRyyYaoUUSjcJLBh1-qqVu7LW6UQZBopZdva0f5e4=', + { norm: { bias: [2, 3, 9, 0], mean: [5, 6, 17, 0] } }, + ); // Tensor to ImageDAta API - let newImage = - inputTensorURL.toImageData({format: 'RGB', norm: {bias: [2 / 5, 3 / 6, 9 / 17, 0], mean: [5, 6, 17, 0]}}); + let newImage = inputTensorURL.toImageData({ + format: 'RGB', + norm: { bias: [2 / 5, 3 / 6, 9 / 17, 0], mean: [5, 6, 17, 0] }, + }); // ImageData to tensor API - let inputTensorImageData = - await ort.Tensor.fromImage(newImage, {format: 'RGB', norm: {bias: [2, 3, 9, 0], mean: [5, 6, 17, 0]}}); + let inputTensorImageData = await ort.Tensor.fromImage(newImage, { + format: 'RGB', + norm: { bias: [2, 3, 9, 0], mean: [5, 6, 17, 0] }, + }); compareTensors(inputTensorURL, inputTensorImageData, 'BUG in ImageData & URL'); } else { - console.log('No internet connection - didn\'t test Image URL to tensor API'); + console.log("No internet connection - didn't test Image URL to tensor API"); } }); diff --git a/js/web/test/e2e/browser-test-wasm-multi-session-create.js b/js/web/test/e2e/browser-test-wasm-multi-session-create.js index 5efc3e712f2ed..1ac7a99b52ceb 100644 --- a/js/web/test/e2e/browser-test-wasm-multi-session-create.js +++ b/js/web/test/e2e/browser-test-wasm-multi-session-create.js @@ -3,7 +3,7 @@ 'use strict'; -it('Browser E2E testing - WebAssembly backend (multiple inference session create calls)', async function() { +it('Browser E2E testing - WebAssembly backend (multiple inference session create calls)', async function () { const sessionPromiseA = createSession(ort); const sessionPromiseB = createSession(ort); await Promise.all([sessionPromiseA, sessionPromiseB]); diff --git a/js/web/test/e2e/browser-test-wasm-path-override-filename.js b/js/web/test/e2e/browser-test-wasm-path-override-filename.js index a6f25548b1433..d2647f03980be 100644 --- 
a/js/web/test/e2e/browser-test-wasm-path-override-filename.js +++ b/js/web/test/e2e/browser-test-wasm-path-override-filename.js @@ -3,7 +3,7 @@ 'use strict'; -it('Browser E2E testing - WebAssembly backend (path override filename)', async function() { +it('Browser E2E testing - WebAssembly backend (path override filename)', async function () { // check base URL port from test args if (typeof __ort_arg_port === 'undefined') { throw new Error('test flag --port= is required'); @@ -24,5 +24,5 @@ it('Browser E2E testing - WebAssembly backend (path override filename)', async f ort.env.wasm.wasmPaths.mjs = overrideMjsUrl; } - await testFunction(ort, {executionProviders: ['wasm']}); + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/browser-test-wasm-path-override-prefix.js b/js/web/test/e2e/browser-test-wasm-path-override-prefix.js index 7a905fbd9d8b9..0b42335883852 100644 --- a/js/web/test/e2e/browser-test-wasm-path-override-prefix.js +++ b/js/web/test/e2e/browser-test-wasm-path-override-prefix.js @@ -3,7 +3,7 @@ 'use strict'; -it('Browser E2E testing - WebAssembly backend (path override prefix)', async function() { +it('Browser E2E testing - WebAssembly backend (path override prefix)', async function () { // check base URL port from test args if (typeof __ort_arg_port === 'undefined') { throw new Error('test flag --port= is required'); @@ -15,5 +15,5 @@ it('Browser E2E testing - WebAssembly backend (path override prefix)', async fun console.log(`ort.env.wasm.wasmPaths = ${JSON.stringify(prefix)};`); ort.env.wasm.wasmPaths = prefix; - await testFunction(ort, {executionProviders: ['wasm']}); + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/browser-test-wasm.js b/js/web/test/e2e/browser-test-wasm.js index 8e4f500d16749..dec40f95b16c3 100644 --- a/js/web/test/e2e/browser-test-wasm.js +++ b/js/web/test/e2e/browser-test-wasm.js @@ -3,6 +3,6 @@ 'use strict'; -it('Browser E2E testing - WebAssembly backend', async function() { - await testFunction(ort, {executionProviders: ['wasm']}); +it('Browser E2E testing - WebAssembly backend', async function () { + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/browser-test-webgl.js b/js/web/test/e2e/browser-test-webgl.js index 974c81d064c89..ff09efcdef258 100644 --- a/js/web/test/e2e/browser-test-webgl.js +++ b/js/web/test/e2e/browser-test-webgl.js @@ -3,14 +3,15 @@ 'use strict'; -it('Browser E2E testing - WebGL backend', async function() { - await testFunction(ort, {executionProviders: ['webgl']}); +it('Browser E2E testing - WebGL backend', async function () { + await testFunction(ort, { executionProviders: ['webgl'] }); }); it('Browser E2E testing - invalid buffer', async () => { try { - await ort.InferenceSession.create( - new Uint8Array(Array.from({length: 100}, () => 42)), {executionProviders: ['webgl']}); + await ort.InferenceSession.create(new Uint8Array(Array.from({ length: 100 }, () => 42)), { + executionProviders: ['webgl'], + }); // Should not reach here. 
assert(false); diff --git a/js/web/test/e2e/browser-test-webgpu-external-data.js b/js/web/test/e2e/browser-test-webgpu-external-data.js index 8fb0b4d6ec545..d293092b7245e 100644 --- a/js/web/test/e2e/browser-test-webgpu-external-data.js +++ b/js/web/test/e2e/browser-test-webgpu-external-data.js @@ -3,13 +3,13 @@ 'use strict'; -it('Browser E2E testing - WebGPU backend with external data', async function() { +it('Browser E2E testing - WebGPU backend with external data', async function () { const session = await ort.InferenceSession.create('./model_with_orig_ext_data.onnx', { executionProviders: ['webgpu'], - externalData: [{data: './model_with_orig_ext_data.bin', path: 'model_with_orig_ext_data.bin'}] + externalData: [{ data: './model_with_orig_ext_data.bin', path: 'model_with_orig_ext_data.bin' }], }); - const fetches = await session.run({X: new ort.Tensor('float32', [1, 1], [1, 2])}); + const fetches = await session.run({ X: new ort.Tensor('float32', [1, 1], [1, 2]) }); const Y = fetches.Y; diff --git a/js/web/test/e2e/bundler.esm.postprocess.js b/js/web/test/e2e/bundler.esm.postprocess.js index 8eadaf04e4121..c675da9bb8546 100644 --- a/js/web/test/e2e/bundler.esm.postprocess.js +++ b/js/web/test/e2e/bundler.esm.postprocess.js @@ -27,7 +27,7 @@ const content = fs.readFileSync(inputFilePath, 'utf8'); // replace all `"file://*/ort.*.mjs"` paths back to `import.meta.url`. Try to keep the same length to make source map // work. -const updatedContent = content.replace(/['"]file:\/\/.+?\/ort\..+?\.mjs['"]/g, match => { +const updatedContent = content.replace(/['"]file:\/\/.+?\/ort\..+?\.mjs['"]/g, (match) => { return 'import.meta.url'.padEnd(match.length, ' '); }); diff --git a/js/web/test/e2e/common.js b/js/web/test/e2e/common.js index c74a7d42a4b51..efaeca1833a92 100644 --- a/js/web/test/e2e/common.js +++ b/js/web/test/e2e/common.js @@ -12,7 +12,7 @@ function createSession(ort, options) { } function delay(ms) { - return new Promise(resolve => setTimeout(resolve, ms)); + return new Promise((resolve) => setTimeout(resolve, ms)); } async function testFunction(ort, options) { @@ -23,8 +23,10 @@ async function testFunction(ort, options) { const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); - const fetches = - await session.run({a: new ort.Tensor('float32', dataA, [3, 4]), b: new ort.Tensor('float32', dataB, [4, 3])}); + const fetches = await session.run({ + a: new ort.Tensor('float32', dataA, [3, 4]), + b: new ort.Tensor('float32', dataB, [4, 3]), + }); const c = fetches.c; diff --git a/js/web/test/e2e/common.mjs b/js/web/test/e2e/common.mjs index 53ba34445cf15..cd0d18bc6905e 100644 --- a/js/web/test/e2e/common.mjs +++ b/js/web/test/e2e/common.mjs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {createRequire} from 'module'; +import { createRequire } from 'module'; const require = createRequire(import.meta.url); const testFunction = require('./common'); diff --git a/js/web/test/e2e/karma.conf.js b/js/web/test/e2e/karma.conf.js index 70ebb136c1ae3..e6dadfaac248d 100644 --- a/js/web/test/e2e/karma.conf.js +++ b/js/web/test/e2e/karma.conf.js @@ -26,28 +26,31 @@ const testArgs = args['test-args']; const normalizedTestArgs = !testArgs || Array.isArray(testArgs) ? 
testArgs : [testArgs]; const files = [ - {pattern: './model.onnx', included: false}, - {pattern: './model_with_orig_ext_data.onnx', included: false}, - {pattern: './model_with_orig_ext_data.bin', included: false}, - {pattern: './test-wasm-path-override/*', included: false, nocache: true, watched: false}, + { pattern: './model.onnx', included: false }, + { pattern: './model_with_orig_ext_data.onnx', included: false }, + { pattern: './model_with_orig_ext_data.bin', included: false }, + { pattern: './test-wasm-path-override/*', included: false, nocache: true, watched: false }, ]; if (ORT_MAIN) { if (ORT_MAIN.endsWith('.mjs')) { - files.push( - {pattern: (SELF_HOST ? './esm-loaders/' : 'http://localhost:8081/esm-loaders/') + ORT_MAIN, type: 'module'}); + files.push({ + pattern: (SELF_HOST ? './esm-loaders/' : 'http://localhost:8081/esm-loaders/') + ORT_MAIN, + type: 'module', + }); } else { - files.push( - {pattern: (SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/') + ORT_MAIN}); + files.push({ + pattern: (SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/') + ORT_MAIN, + }); } } if (FORMAT === 'esm') { - files.push({pattern: TEST_MAIN, type: 'module'}); + files.push({ pattern: TEST_MAIN, type: 'module' }); } else { - files.push({pattern: './common.js'}, {pattern: TEST_MAIN}); + files.push({ pattern: './common.js' }, { pattern: TEST_MAIN }); } -files.push({pattern: './dist/**/*', included: false, nocache: true, watched: false}); +files.push({ pattern: './dist/**/*', included: false, nocache: true, watched: false }); if (SELF_HOST) { - files.push({pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true}); + files.push({ pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true }); } const flags = ['--ignore-gpu-blocklist', '--gpu-vendor-id=0x10de']; @@ -55,7 +58,7 @@ if (ENABLE_SHARED_ARRAY_BUFFER) { flags.push('--enable-features=SharedArrayBuffer'); } -module.exports = function(config) { +module.exports = function (config) { config.set({ frameworks: ['mocha'], files, @@ -66,7 +69,7 @@ module.exports = function(config) { '/model_with_orig_ext_data.bin': '/base/model_with_orig_ext_data.bin', '/test-wasm-path-override/': '/base/test-wasm-path-override/', }, - client: {captureConsole: true, args: normalizedTestArgs, mocha: {expose: ['body'], timeout: 60000}}, + client: { captureConsole: true, args: normalizedTestArgs, mocha: { expose: ['body'], timeout: 60000 } }, reporters: ['mocha'], captureTimeout: 120000, reportSlowerThan: 100, @@ -77,14 +80,14 @@ module.exports = function(config) { hostname: 'localhost', browsers: [], customLaunchers: { - Chrome_default: {base: 'Chrome', flags, chromeDataDir: USER_DATA}, + Chrome_default: { base: 'Chrome', flags, chromeDataDir: USER_DATA }, Chrome_no_threads: { base: 'Chrome', chromeDataDir: USER_DATA, - flags + flags, // TODO: no-thread flags }, - Edge_default: {base: 'Edge', edgeDataDir: USER_DATA} - } + Edge_default: { base: 'Edge', edgeDataDir: USER_DATA }, + }, }); }; diff --git a/js/web/test/e2e/node-test-main-no-threads.js b/js/web/test/e2e/node-test-main-no-threads.js index e586c68ca98a9..15182a197de4d 100644 --- a/js/web/test/e2e/node-test-main-no-threads.js +++ b/js/web/test/e2e/node-test-main-no-threads.js @@ -6,7 +6,7 @@ const ort = require('onnxruntime-web'); const testFunction = require('./common'); -it('Node.js E2E testing - WebAssembly backend (no threads)', async function() { +it('Node.js E2E testing - WebAssembly backend (no 
threads)', async function () { ort.env.wasm.numThreads = 1; - await testFunction(ort, {executionProviders: ['wasm']}); + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/node-test-main-no-threads.mjs b/js/web/test/e2e/node-test-main-no-threads.mjs index b8f50d6db6ae2..99edcd84b62bd 100644 --- a/js/web/test/e2e/node-test-main-no-threads.mjs +++ b/js/web/test/e2e/node-test-main-no-threads.mjs @@ -7,7 +7,7 @@ import * as ort from 'onnxruntime-web'; import testFunction from './common.mjs'; -it('Node.js E2E testing - WebAssembly backend[esm]', async function() { +it('Node.js E2E testing - WebAssembly backend[esm]', async function () { ort.env.wasm.numThreads = 1; - await testFunction(ort, {executionProviders: ['wasm']}); + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/node-test-main.js b/js/web/test/e2e/node-test-main.js index 2f1f8fdcf5ff5..320bdfdc325d2 100644 --- a/js/web/test/e2e/node-test-main.js +++ b/js/web/test/e2e/node-test-main.js @@ -6,6 +6,6 @@ const ort = require('onnxruntime-web'); const testFunction = require('./common'); -it('Node.js E2E testing - WebAssembly backend', async function() { - await testFunction(ort, {executionProviders: ['wasm']}); +it('Node.js E2E testing - WebAssembly backend', async function () { + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/node-test-main.mjs b/js/web/test/e2e/node-test-main.mjs index 11c126e9c817b..a55d4463ddf99 100644 --- a/js/web/test/e2e/node-test-main.mjs +++ b/js/web/test/e2e/node-test-main.mjs @@ -7,6 +7,6 @@ import * as ort from 'onnxruntime-web'; import testFunction from './common.mjs'; -it('Node.js E2E testing - WebAssembly backend[esm]', async function() { - await testFunction(ort, {executionProviders: ['wasm']}); +it('Node.js E2E testing - WebAssembly backend[esm]', async function () { + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/node-test-wasm-path-override-filename.js b/js/web/test/e2e/node-test-wasm-path-override-filename.js index bd9baf6e68dd4..772096d08ae81 100644 --- a/js/web/test/e2e/node-test-wasm-path-override-filename.js +++ b/js/web/test/e2e/node-test-wasm-path-override-filename.js @@ -6,14 +6,14 @@ const path = require('path'); const ort = require('onnxruntime-web'); const testFunction = require('./common'); -const {pathToFileURL} = require('url') +const { pathToFileURL } = require('url'); -it('Node.js E2E testing - WebAssembly backend (path override filename)', async function() { +it('Node.js E2E testing - WebAssembly backend (path override filename)', async function () { // override .wasm file path for 'ort-wasm.wasm' ort.env.wasm.wasmPaths = { - 'mjs': pathToFileURL(path.join(__dirname, 'test-wasm-path-override/renamed.mjs')), - 'wasm': pathToFileURL(path.join(__dirname, 'test-wasm-path-override/renamed.wasm')) + mjs: pathToFileURL(path.join(__dirname, 'test-wasm-path-override/renamed.mjs')), + wasm: pathToFileURL(path.join(__dirname, 'test-wasm-path-override/renamed.wasm')), }; - await testFunction(ort, {executionProviders: ['wasm']}); + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/node-test-wasm-path-override-prefix.js b/js/web/test/e2e/node-test-wasm-path-override-prefix.js index 76a7600a75917..fac3e0b8be97c 100644 --- a/js/web/test/e2e/node-test-wasm-path-override-prefix.js +++ b/js/web/test/e2e/node-test-wasm-path-override-prefix.js @@ -6,9 +6,9 @@ const path = require('path'); const 
ort = require('onnxruntime-web'); const testFunction = require('./common'); -const {pathToFileURL} = require('url') +const { pathToFileURL } = require('url'); -it('Node.js E2E testing - WebAssembly backend (path override prefix)', async function() { +it('Node.js E2E testing - WebAssembly backend (path override prefix)', async function () { // disable SIMD and multi-thread ort.env.wasm.numThreads = 1; ort.env.wasm.simd = false; @@ -16,5 +16,5 @@ it('Node.js E2E testing - WebAssembly backend (path override prefix)', async fun // override .wasm file path prefix ort.env.wasm.wasmPaths = pathToFileURL(path.join(__dirname, 'test-wasm-path-override/')); - await testFunction(ort, {executionProviders: ['wasm']}); + await testFunction(ort, { executionProviders: ['wasm'] }); }); diff --git a/js/web/test/e2e/rollup.config.esm-js.js b/js/web/test/e2e/rollup.config.esm-js.js index 635c52f39d4b1..5ee08aa49a1b8 100644 --- a/js/web/test/e2e/rollup.config.esm-js.js +++ b/js/web/test/e2e/rollup.config.esm-js.js @@ -1,18 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -const {nodeResolve} = require('@rollup/plugin-node-resolve'); +const { nodeResolve } = require('@rollup/plugin-node-resolve'); const terser = require('@rollup/plugin-terser'); const copy = require('rollup-plugin-copy'); module.exports = { - input : 'src/esm-js/main.js', - output : { - file : 'dist/rollup_esm_js/ort-test-e2e.bundle.mjs', - format : 'esm', + input: 'src/esm-js/main.js', + output: { + file: 'dist/rollup_esm_js/ort-test-e2e.bundle.mjs', + format: 'esm', }, - plugins : - [ + plugins: [ // Use "@rollup/plugin-node-resolve" to support conditional import. // (e.g. `import {...} from 'onnxruntime-web/wasm';`) nodeResolve(), @@ -21,6 +20,6 @@ module.exports = { terser(), // Use "rollup-plugin-copy" to copy the onnxruntime-web WebAssembly files to the output directory. - copy({targets : [{src : 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', dest : 'dist/rollup_esm_js'}]}) - ] + copy({ targets: [{ src: 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', dest: 'dist/rollup_esm_js' }] }), + ], }; diff --git a/js/web/test/e2e/rollup.config.umd-js.js b/js/web/test/e2e/rollup.config.umd-js.js index 1aad0092145ae..a6ac16f8cb870 100644 --- a/js/web/test/e2e/rollup.config.umd-js.js +++ b/js/web/test/e2e/rollup.config.umd-js.js @@ -2,30 +2,29 @@ // Licensed under the MIT license. const commonjs = require('@rollup/plugin-commonjs'); -const {nodeResolve} = require('@rollup/plugin-node-resolve'); +const { nodeResolve } = require('@rollup/plugin-node-resolve'); const terser = require('@rollup/plugin-terser'); const copy = require('rollup-plugin-copy'); module.exports = { - input : 'src/cjs-js/main.js', - output : { - name : 'testPackageConsuming', - file : 'dist/rollup_umd_js/ort-test-e2e.bundle.js', - format : 'umd', + input: 'src/cjs-js/main.js', + output: { + name: 'testPackageConsuming', + file: 'dist/rollup_umd_js/ort-test-e2e.bundle.js', + format: 'umd', }, - plugins : - [ + plugins: [ // Use "@rollup/plugin-node-resolve" to support conditional import. // (e.g. `import {...} from 'onnxruntime-web/wasm';`) nodeResolve(), // Use "@rollup/plugin-commonjs" to support CommonJS module resolve. - commonjs({ignoreDynamicRequires : true}), + commonjs({ ignoreDynamicRequires: true }), // Use "@rollup/plugin-terser" to minify the output. terser(), // Use "rollup-plugin-copy" to copy the onnxruntime-web WebAssembly files to the output directory. 
- copy({targets : [{src : 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', dest : 'dist/rollup_umd_js'}]}) - ] + copy({ targets: [{ src: 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', dest: 'dist/rollup_umd_js' }] }), + ], }; diff --git a/js/web/test/e2e/run-data.js b/js/web/test/e2e/run-data.js index 856f29eac6ddf..04079b042bc23 100644 --- a/js/web/test/e2e/run-data.js +++ b/js/web/test/e2e/run-data.js @@ -14,27 +14,27 @@ const NODEJS_TEST_CASES = [ // [test_for_same_origin, test_for_cross_origin, main_js, ort_main_js, [test_args]] const BROWSER_TEST_CASES = [ // IIFE - [true, true, './browser-test-webgl.js', 'ort.min.js'], // webgl - [true, true, './browser-test-webgl.js', 'ort.webgl.min.js'], // webgl - [true, true, './browser-test-wasm.js', 'ort.wasm.min.js'], // wasm, ort.wasm - [true, true, './browser-test-wasm-multi-session-create.js', 'ort.min.js'], // wasm, multi-session create - [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=1']], // wasm, 1 thread - [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=2']], // wasm, 2 threads - [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=2', 'proxy=1']], // wasm, 2 threads, proxy - [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=1', 'proxy=1']], // wasm, 1 thread, proxy + [true, true, './browser-test-webgl.js', 'ort.min.js'], // webgl + [true, true, './browser-test-webgl.js', 'ort.webgl.min.js'], // webgl + [true, true, './browser-test-wasm.js', 'ort.wasm.min.js'], // wasm, ort.wasm + [true, true, './browser-test-wasm-multi-session-create.js', 'ort.min.js'], // wasm, multi-session create + [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=1']], // wasm, 1 thread + [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=2']], // wasm, 2 threads + [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=2', 'proxy=1']], // wasm, 2 threads, proxy + [true, true, './browser-test-wasm.js', 'ort.min.js', ['num_threads=1', 'proxy=1']], // wasm, 1 thread, proxy // ort.min.mjs - [true, true, './browser-test-webgl.js', 'ort.min.mjs'], // webgl - [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=1']], // wasm, 1 thread - [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2']], // wasm, 2 threads - [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2', 'proxy=1']], // wasm, 2 threads, proxy - [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=1', 'proxy=1']], // wasm, 1 thread, proxy + [true, true, './browser-test-webgl.js', 'ort.min.mjs'], // webgl + [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=1']], // wasm, 1 thread + [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2']], // wasm, 2 threads + [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=2', 'proxy=1']], // wasm, 2 threads, proxy + [true, true, './browser-test-wasm.js', 'ort.min.mjs', ['num_threads=1', 'proxy=1']], // wasm, 1 thread, proxy // ort.bundle.min.mjs - [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=1']], // 1 thread - [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=2']], // 2 threads - [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=2', 'proxy=1']], // 2 threads, proxy - [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=1', 'proxy=1']], // 1 thread, proxy + [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', 
['num_threads=1']], // 1 thread + [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=2']], // 2 threads + [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=2', 'proxy=1']], // 2 threads, proxy + [true, false, './browser-test-wasm.js', 'ort.bundle.min.mjs', ['num_threads=1', 'proxy=1']], // 1 thread, proxy // wasm binary override: [true, false, './browser-test-wasm-binary-override.js', 'ort.min.js'], @@ -65,8 +65,8 @@ const BROWSER_TEST_CASES = [ [false, true, './browser-test-wasm-path-override-prefix.js', 'ort.min.js', ['port=8081']], [false, true, './browser-test-wasm-path-override-prefix.js', 'ort.wasm.min.js', ['port=8081']], - [true, true, './browser-test-wasm-image-tensor-image.js', 'ort.min.js'], // pre-post-process - [true, true, './browser-test-webgpu-external-data.js', 'ort.webgpu.min.js'], // external data + [true, true, './browser-test-wasm-image-tensor-image.js', 'ort.min.js'], // pre-post-process + [true, true, './browser-test-webgpu-external-data.js', 'ort.webgpu.min.js'], // external data ]; // [bundle_path, format] diff --git a/js/web/test/e2e/run.js b/js/web/test/e2e/run.js index 5bf31e8d7ac2a..93f9d4a144bf2 100644 --- a/js/web/test/e2e/run.js +++ b/js/web/test/e2e/run.js @@ -5,11 +5,11 @@ const path = require('path'); const fs = require('fs-extra'); -const {spawn} = require('child_process'); +const { spawn } = require('child_process'); const startServer = require('./simple-http-server'); const minimist = require('minimist'); -const {NODEJS_TEST_CASES, BROWSER_TEST_CASES, BUNDLER_TEST_CASES} = require('./run-data'); +const { NODEJS_TEST_CASES, BROWSER_TEST_CASES, BUNDLER_TEST_CASES } = require('./run-data'); // copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file // exists in its parent folder. 
@@ -28,7 +28,7 @@ fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); // always use a new folder as user-data-dir let nextUserDataDirId = 0; function getNextUserDataDir() { - const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()) + const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()); nextUserDataDirId++; fs.emptyDirSync(dir); return dir; @@ -39,10 +39,10 @@ const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; async function main() { // find packed package - const {globbySync} = await import('globby'); + const { globbySync } = await import('globby'); const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); - const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', {cwd: ORT_COMMON_FOLDER}); + const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER }); const PACKAGES_TO_INSTALL = []; @@ -53,7 +53,7 @@ async function main() { } const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); - const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', {cwd: ORT_WEB_FOLDER}); + const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER }); if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { throw new Error('cannot find exactly single package for onnxruntime-web.'); } @@ -65,7 +65,7 @@ async function main() { await runInShell(`npm install`); // npm install with "--cache" to install packed packages with an empty cache folder - await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map(i => `"${i}"`).join(' ')}`); + await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`); // prepare .wasm files for path override testing prepareWasmPathOverrideFiles(); @@ -78,11 +78,15 @@ async function main() { prepareEsmLoaderFiles(); await fs.symlink( - path.resolve(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web', 'dist'), path.join(serverWwwRoot, 'dist'), - 'junction'); + path.resolve(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web', 'dist'), + path.join(serverWwwRoot, 'dist'), + 'junction', + ); await fs.symlink( - path.resolve(TEST_E2E_RUN_FOLDER, 'test-wasm-path-override'), path.join(serverWwwRoot, 'test-wasm-path-override'), - 'junction'); + path.resolve(TEST_E2E_RUN_FOLDER, 'test-wasm-path-override'), + path.join(serverWwwRoot, 'test-wasm-path-override'), + 'junction', + ); // start a HTTP server for hosting .wasm files (for cross-origin testing) const server = startServer(serverWwwRoot, 8081); @@ -94,17 +98,16 @@ async function main() { await testAllNodejsCases(); // test cases with self-host (ort hosted in same origin) - await testAllBrowserCases({hostInKarma: true}); + await testAllBrowserCases({ hostInKarma: true }); // test cases without self-host (ort hosted in different origin) - await testAllBrowserCases({hostInKarma: false}); + await testAllBrowserCases({ hostInKarma: false }); // run bundlers await runInShell(`npm run build`); // test package consuming test await testAllBrowserPackagesConsumingCases(); - } finally { // close the server after all tests await server.close(); @@ -112,25 +115,32 @@ async function main() { } function prepareEsmLoaderFiles() { - const allEsmFiles = [...new Set(BROWSER_TEST_CASES.map(i => i[3]).filter(i => i && i.endsWith('.mjs')))]; + const allEsmFiles = [...new Set(BROWSER_TEST_CASES.map((i) => i[3]).filter((i) => i && i.endsWith('.mjs')))]; 
// self-hosted fs.emptyDirSync(path.join(TEST_E2E_RUN_FOLDER, 'esm-loaders')); fs.emptyDirSync(path.join(TEST_E2E_RUN_FOLDER, 'wwwroot', 'esm-loaders')); - allEsmFiles.forEach(i => { + allEsmFiles.forEach((i) => { fs.writeFileSync( - path.join(TEST_E2E_RUN_FOLDER, 'esm-loaders', i), - `import * as x from '../node_modules/onnxruntime-web/dist/${i}'; globalThis.ort = x;`); + path.join(TEST_E2E_RUN_FOLDER, 'esm-loaders', i), + `import * as x from '../node_modules/onnxruntime-web/dist/${i}'; globalThis.ort = x;`, + ); fs.writeFileSync( - path.join(TEST_E2E_RUN_FOLDER, 'wwwroot', 'esm-loaders', i), - `import * as x from '../dist/${i}'; globalThis.ort = x;`); + path.join(TEST_E2E_RUN_FOLDER, 'wwwroot', 'esm-loaders', i), + `import * as x from '../dist/${i}'; globalThis.ort = x;`, + ); }); } function prepareWasmPathOverrideFiles() { const folder = path.join(TEST_E2E_RUN_FOLDER, 'test-wasm-path-override'); - const sourceFile = - path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web', 'dist', 'ort-wasm-simd-threaded'); + const sourceFile = path.join( + TEST_E2E_RUN_FOLDER, + 'node_modules', + 'onnxruntime-web', + 'dist', + 'ort-wasm-simd-threaded', + ); fs.emptyDirSync(folder); fs.copyFileSync(`${sourceFile}.mjs`, path.join(folder, 'ort-wasm-simd-threaded.mjs')); fs.copyFileSync(`${sourceFile}.wasm`, path.join(folder, 'ort-wasm-simd-threaded.wasm')); @@ -144,23 +154,23 @@ async function testAllNodejsCases() { } } -async function testAllBrowserCases({hostInKarma}) { +async function testAllBrowserCases({ hostInKarma }) { for (const [testForSameOrigin, testForCrossOrigin, main, ortMain, args] of BROWSER_TEST_CASES) { if (hostInKarma && testForSameOrigin) { - await runKarma({hostInKarma, main, ortMain, args}); - await runKarma({hostInKarma, main, ortMain, args, enableSharedArrayBuffer: true}); + await runKarma({ hostInKarma, main, ortMain, args }); + await runKarma({ hostInKarma, main, ortMain, args, enableSharedArrayBuffer: true }); } if (!hostInKarma && testForCrossOrigin) { - await runKarma({hostInKarma, main, ortMain, args}); - await runKarma({hostInKarma, main, ortMain, args, enableSharedArrayBuffer: true}); + await runKarma({ hostInKarma, main, ortMain, args }); + await runKarma({ hostInKarma, main, ortMain, args, enableSharedArrayBuffer: true }); } } } async function testAllBrowserPackagesConsumingCases() { for (const [main, format] of BUNDLER_TEST_CASES) { - await runKarma({hostInKarma: true, main, ortMain: '', format}); - await runKarma({hostInKarma: true, main, ortMain: '', format, enableSharedArrayBuffer: true}); + await runKarma({ hostInKarma: true, main, ortMain: '', format }); + await runKarma({ hostInKarma: true, main, ortMain: '', format, enableSharedArrayBuffer: true }); } } @@ -171,15 +181,17 @@ async function runKarma({ ortMain = 'ort.min.js', format = 'iife', enableSharedArrayBuffer = false, - args = [] + args = [], }) { const selfHostFlag = hostInKarma ? '--self-host' : ''; - const argsStr = args.map(i => `--test-args=${i}`).join(' '); + const argsStr = args.map((i) => `--test-args=${i}`).join(' '); const formatFlag = `--format=${format}`; const enableSharedArrayBufferFlag = enableSharedArrayBuffer ? 
'--enable-shared-array-buffer' : ''; await runInShell( - `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ortMain} --test-main=${ - main} --user-data=${getNextUserDataDir()} ${argsStr} ${formatFlag} ${enableSharedArrayBufferFlag}`); + `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ortMain} --test-main=${ + main + } --user-data=${getNextUserDataDir()} ${argsStr} ${formatFlag} ${enableSharedArrayBufferFlag}`, + ); } async function runInShell(cmd) { @@ -188,8 +200,8 @@ async function runInShell(cmd) { console.log(' > ' + cmd); console.log('==============================================================='); let complete = false; - const childProcess = spawn(cmd, {shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER}); - childProcess.on('close', function(code) { + const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER }); + childProcess.on('close', function (code) { if (code !== 0) { process.exit(code); } else { @@ -202,8 +214,8 @@ async function runInShell(cmd) { } async function delay(ms) { - return new Promise(function(resolve) { - setTimeout(function() { + return new Promise(function (resolve) { + setTimeout(function () { resolve(); }, ms); }); diff --git a/js/web/test/e2e/simple-http-server.js b/js/web/test/e2e/simple-http-server.js index 2faac81969294..bad00ae96f2a5 100644 --- a/js/web/test/e2e/simple-http-server.js +++ b/js/web/test/e2e/simple-http-server.js @@ -15,8 +15,11 @@ const getRequestData = (url, dir) => { let filepath; let mimeType; - if (pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/') || - pathname.startsWith('/esm-loaders/')) { + if ( + pathname.startsWith('/test-wasm-path-override/') || + pathname.startsWith('/dist/') || + pathname.startsWith('/esm-loaders/') + ) { filepath = path.resolve(dir, pathname.substring(1)); } else { return null; @@ -33,35 +36,36 @@ const getRequestData = (url, dir) => { return [filepath, mimeType]; }; -module.exports = function(dir, port) { - const server = http.createServer(function(request, response) { - const url = request.url.replace(/\n|\r/g, ''); - console.log(`request ${url}`); +module.exports = function (dir, port) { + const server = http + .createServer(function (request, response) { + const url = request.url.replace(/\n|\r/g, ''); + console.log(`request ${url}`); - const requestData = getRequestData(url, dir); - if (!request || !requestData) { - response.writeHead(404); - response.end('404'); - } else { - const [filePath, contentType] = requestData; - fs.readFile(path.resolve(dir, filePath), function(error, content) { - if (error) { - if (error.code == 'ENOENT') { - response.writeHead(404); - response.end('404'); - } else { - response.writeHead(500); - response.end('500'); - } - } else { - response.setHeader('access-control-allow-origin', '*'); - response.writeHead(200, {'Content-Type': contentType}); - response.end(content, 'utf-8'); - } - }); - } - }) - .listen(port); + const requestData = getRequestData(url, dir); + if (!request || !requestData) { + response.writeHead(404); + response.end('404'); + } else { + const [filePath, contentType] = requestData; + fs.readFile(path.resolve(dir, filePath), function (error, content) { + if (error) { + if (error.code == 'ENOENT') { + response.writeHead(404); + response.end('404'); + } else { + response.writeHead(500); + response.end('500'); + } + } else { + response.setHeader('access-control-allow-origin', '*'); + response.writeHead(200, { 'Content-Type': 
contentType }); + response.end(content, 'utf-8'); + } + }); + } + }) + .listen(port); console.log(`Server running at http://localhost:${port}/`); return server; }; diff --git a/js/web/test/e2e/src/cjs-js/main.js b/js/web/test/e2e/src/cjs-js/main.js index dac4b92a93c56..c9b8d3e85455d 100644 --- a/js/web/test/e2e/src/cjs-js/main.js +++ b/js/web/test/e2e/src/cjs-js/main.js @@ -4,15 +4,15 @@ 'use strict'; const ort = require('onnxruntime-web/wasm'); -const {setupMultipleThreads, testInferenceAndValidate} = require('./shared'); +const { setupMultipleThreads, testInferenceAndValidate } = require('./shared'); if (typeof SharedArrayBuffer === 'undefined') { - it('Browser package consuming test - single-thread - [js][commonjs]', async function() { - await testInferenceAndValidate(ort, {executionProviders: ['wasm']}); + it('Browser package consuming test - single-thread - [js][commonjs]', async function () { + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); }); } else { - it('Browser package consuming test - multi-thread - [js][commonjs]', async function() { + it('Browser package consuming test - multi-thread - [js][commonjs]', async function () { setupMultipleThreads(ort); - await testInferenceAndValidate(ort, {executionProviders: ['wasm']}); + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); }); } diff --git a/js/web/test/e2e/src/cjs-js/shared.js b/js/web/test/e2e/src/cjs-js/shared.js index ac8d151998712..980587e281ca8 100644 --- a/js/web/test/e2e/src/cjs-js/shared.js +++ b/js/web/test/e2e/src/cjs-js/shared.js @@ -5,7 +5,7 @@ // Model data for "test_abs/model.onnx" const testModelData = - 'CAcSDGJhY2tlbmQtdGVzdDpJCgsKAXgSAXkiA0FicxIIdGVzdF9hYnNaFwoBeBISChAIARIMCgIIAwoCCAQKAggFYhcKAXkSEgoQCAESDAoCCAMKAggECgIIBUIECgAQDQ=='; + 'CAcSDGJhY2tlbmQtdGVzdDpJCgsKAXgSAXkiA0FicxIIdGVzdF9hYnNaFwoBeBISChAIARIMCgIIAwoCCAQKAggFYhcKAXkSEgoQCAESDAoCCAMKAggECgIIBUIECgAQDQ=='; const base64StringToUint8Array = (base64String) => { const charArray = atob(base64String); @@ -31,10 +31,10 @@ const testInferenceAndValidate = async (ort, options) => { const session = await ort.InferenceSession.create(model, options); // test data: [0, -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, ... 58, -59] - const inputData = [...Array(60).keys()].map(i => i % 2 === 0 ? i : -i); - const expectedOutputData = inputData.map(i => Math.abs(i)); + const inputData = [...Array(60).keys()].map((i) => (i % 2 === 0 ? 
i : -i)); + const expectedOutputData = inputData.map((i) => Math.abs(i)); - const fetches = await session.run({x: new ort.Tensor('float32', inputData, [3, 4, 5])}); + const fetches = await session.run({ x: new ort.Tensor('float32', inputData, [3, 4, 5]) }); const y = fetches.y; @@ -48,5 +48,5 @@ const testInferenceAndValidate = async (ort, options) => { module.exports = { setupMultipleThreads, - testInferenceAndValidate + testInferenceAndValidate, }; diff --git a/js/web/test/e2e/src/esm-js/main.js b/js/web/test/e2e/src/esm-js/main.js index abe9a55e1b37a..7687b8b731878 100644 --- a/js/web/test/e2e/src/esm-js/main.js +++ b/js/web/test/e2e/src/esm-js/main.js @@ -4,15 +4,15 @@ 'use strict'; import * as ort from 'onnxruntime-web/wasm'; -import {setupMultipleThreads, default as testInferenceAndValidate} from './shared.js'; +import { setupMultipleThreads, default as testInferenceAndValidate } from './shared.js'; if (typeof SharedArrayBuffer === 'undefined') { - it('Browser package consuming test - single-thread - [js][esm]', async function() { - await testInferenceAndValidate(ort, {executionProviders: ['wasm']}); + it('Browser package consuming test - single-thread - [js][esm]', async function () { + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); }); } else { - it('Browser package consuming test - multi-thread - [js][esm]', async function() { + it('Browser package consuming test - multi-thread - [js][esm]', async function () { setupMultipleThreads(ort); - await testInferenceAndValidate(ort, {executionProviders: ['wasm']}); + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); }); } diff --git a/js/web/test/e2e/src/esm-js/shared.js b/js/web/test/e2e/src/esm-js/shared.js index 54b714d67e0e3..57d19c99c9a1e 100644 --- a/js/web/test/e2e/src/esm-js/shared.js +++ b/js/web/test/e2e/src/esm-js/shared.js @@ -5,7 +5,7 @@ // Model data for "test_abs/model.onnx" const testModelData = - 'CAcSDGJhY2tlbmQtdGVzdDpJCgsKAXgSAXkiA0FicxIIdGVzdF9hYnNaFwoBeBISChAIARIMCgIIAwoCCAQKAggFYhcKAXkSEgoQCAESDAoCCAMKAggECgIIBUIECgAQDQ=='; + 'CAcSDGJhY2tlbmQtdGVzdDpJCgsKAXgSAXkiA0FicxIIdGVzdF9hYnNaFwoBeBISChAIARIMCgIIAwoCCAQKAggFYhcKAXkSEgoQCAESDAoCCAMKAggECgIIBUIECgAQDQ=='; const base64StringToUint8Array = (base64String) => { const charArray = atob(base64String); @@ -31,10 +31,10 @@ const testInferenceAndValidate = async (ort, options) => { const session = await ort.InferenceSession.create(model, options); // test data: [0, -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, ... 58, -59] - const inputData = [...Array(60).keys()].map(i => i % 2 === 0 ? i : -i); - const expectedOutputData = inputData.map(i => Math.abs(i)); + const inputData = [...Array(60).keys()].map((i) => (i % 2 === 0 ? 
i : -i)); + const expectedOutputData = inputData.map((i) => Math.abs(i)); - const fetches = await session.run({x: new ort.Tensor('float32', inputData, [3, 4, 5])}); + const fetches = await session.run({ x: new ort.Tensor('float32', inputData, [3, 4, 5]) }); const y = fetches.y; @@ -47,4 +47,4 @@ const testInferenceAndValidate = async (ort, options) => { }; export default testInferenceAndValidate; -export {setupMultipleThreads}; +export { setupMultipleThreads }; diff --git a/js/web/test/e2e/webpack.config.esm-js.js b/js/web/test/e2e/webpack.config.esm-js.js index 713c27cf04286..fe235ccd361d6 100644 --- a/js/web/test/e2e/webpack.config.esm-js.js +++ b/js/web/test/e2e/webpack.config.esm-js.js @@ -5,19 +5,20 @@ const path = require('node:path'); const CopyPlugin = require('copy-webpack-plugin'); module.exports = { - module : {parser : {javascript : {importMeta : false}}}, - experiments : {outputModule : true}, - target : ['web'], - entry : path.resolve(__dirname, 'src/esm-js/main.js'), - output : { - clean : true, - filename : 'ort-test-e2e.bundle.mjs', - path : path.resolve(__dirname, 'dist/webpack_esm_js'), - library : {type : 'module'}, + module: { parser: { javascript: { importMeta: false } } }, + experiments: { outputModule: true }, + target: ['web'], + entry: path.resolve(__dirname, 'src/esm-js/main.js'), + output: { + clean: true, + filename: 'ort-test-e2e.bundle.mjs', + path: path.resolve(__dirname, 'dist/webpack_esm_js'), + library: { type: 'module' }, }, - plugins : - [ + plugins: [ // Use "copy-webpack-plugin" to copy the onnxruntime-web WebAssembly files to the output directory. - new CopyPlugin({patterns : [{from : 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', to : '[name][ext]'}]}), - ] + new CopyPlugin({ + patterns: [{ from: 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', to: '[name][ext]' }], + }), + ], }; diff --git a/js/web/test/e2e/webpack.config.umd-js.js b/js/web/test/e2e/webpack.config.umd-js.js index d21ec30c91d6f..2b909aa40d7c7 100644 --- a/js/web/test/e2e/webpack.config.umd-js.js +++ b/js/web/test/e2e/webpack.config.umd-js.js @@ -5,17 +5,18 @@ const path = require('node:path'); const CopyPlugin = require('copy-webpack-plugin'); module.exports = { - target : ['web'], - entry : path.resolve(__dirname, 'src/cjs-js/main.js'), - output : { - clean : true, - filename : 'ort-test-e2e.bundle.js', - path : path.resolve(__dirname, 'dist/webpack_umd_js'), - library : {type : 'umd'}, + target: ['web'], + entry: path.resolve(__dirname, 'src/cjs-js/main.js'), + output: { + clean: true, + filename: 'ort-test-e2e.bundle.js', + path: path.resolve(__dirname, 'dist/webpack_umd_js'), + library: { type: 'umd' }, }, - plugins : - [ + plugins: [ // Use "copy-webpack-plugin" to copy the onnxruntime-web WebAssembly files to the output directory. 
- new CopyPlugin({patterns : [{from : 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', to : '[name][ext]'}]}), - ] + new CopyPlugin({ + patterns: [{ from: 'node_modules/onnxruntime-web/dist/ort-*.{js,mjs,wasm}', to: '[name][ext]' }], + }), + ], }; diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 96e374f87aed1..4988da41e802a 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -9,11 +9,11 @@ const ORT_WEB_TEST_CONFIG = require('./testdata-config.json') as Test.Config; import * as platform from 'platform'; -import {Logger} from '../lib/onnxjs/instrument'; +import { Logger } from '../lib/onnxjs/instrument'; -import {Test} from './test-types'; +import { Test } from './test-types'; -if (ORT_WEB_TEST_CONFIG.model.some(testGroup => testGroup.tests.some(test => test.backend === 'cpu'))) { +if (ORT_WEB_TEST_CONFIG.model.some((testGroup) => testGroup.tests.some((test) => test.backend === 'cpu'))) { // require onnxruntime-node require('../../node'); } @@ -26,8 +26,8 @@ for (const logConfig of ORT_WEB_TEST_CONFIG.log) { Logger.set(logConfig.category, logConfig.config); } -import {ModelTestContext, OpTestContext, ProtoOpTestContext, runModelTestSet, runOpTest} from './test-runner'; -import {readJsonFile} from './test-shared'; +import { ModelTestContext, OpTestContext, ProtoOpTestContext, runModelTestSet, runOpTest } from './test-runner'; +import { readJsonFile } from './test-shared'; // Unit test if (ORT_WEB_TEST_CONFIG.unittest) { @@ -37,14 +37,14 @@ if (ORT_WEB_TEST_CONFIG.unittest) { // Set file cache if (ORT_WEB_TEST_CONFIG.fileCacheUrls) { before('prepare file cache', async () => { - const allJsonCache = await Promise.all(ORT_WEB_TEST_CONFIG.fileCacheUrls!.map(readJsonFile)) as Test.FileCache[]; + const allJsonCache = (await Promise.all(ORT_WEB_TEST_CONFIG.fileCacheUrls!.map(readJsonFile))) as Test.FileCache[]; for (const cache of allJsonCache) { ModelTestContext.setCache(cache); } }); } -function shouldSkipTest(test: Test.ModelTest|Test.OperatorTest) { +function shouldSkipTest(test: Test.ModelTest | Test.OperatorTest) { if (!test.cases || test.cases.length === 0) { return true; } @@ -95,11 +95,12 @@ for (const group of ORT_WEB_TEST_CONFIG.op) { const backend = test.backend!; const useProtoOpTest = backend !== 'webgl'; describeTest(`[${backend}]${test.operator} - ${test.name}`, () => { - let context: ProtoOpTestContext|OpTestContext; + let context: ProtoOpTestContext | OpTestContext; before('Initialize Context', async () => { - context = useProtoOpTest ? new ProtoOpTestContext(test, ORT_WEB_TEST_CONFIG.options.sessionOptions) : - new OpTestContext(test); + context = useProtoOpTest + ? new ProtoOpTestContext(test, ORT_WEB_TEST_CONFIG.options.sessionOptions) + : new OpTestContext(test); await context.init(); if (ORT_WEB_TEST_CONFIG.profile) { if (context instanceof ProtoOpTestContext) { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index bc782a18c55f2..84f3d8d9fca2b 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -1,25 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Float16Array as Float16ArrayPolyfill} from '@petamoriken/float16'; -import {expect} from 'chai'; +import { Float16Array as Float16ArrayPolyfill } from '@petamoriken/float16'; +import { expect } from 'chai'; import * as ort from 'onnxruntime-common'; -import {extname} from 'path'; -import {inspect} from 'util'; - -import {Attribute} from '../lib/onnxjs/attribute'; -import {InferenceHandler, resolveBackend, SessionHandler} from '../lib/onnxjs/backend'; -import {createWebGLContext} from '../lib/onnxjs/backends/webgl/webgl-context-factory'; -import {Logger, Profiler} from '../lib/onnxjs/instrument'; -import {Operator} from '../lib/onnxjs/operators'; -import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; -import {Tensor} from '../lib/onnxjs/tensor'; -import {ProtoUtil} from '../lib/onnxjs/util'; -import {createView} from '../lib/wasm/jsep/tensor-view'; -import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; - -import {base64toBuffer, createMockGraph, readFile} from './test-shared'; -import {Test} from './test-types'; +import { extname } from 'path'; +import { inspect } from 'util'; + +import { Attribute } from '../lib/onnxjs/attribute'; +import { InferenceHandler, resolveBackend, SessionHandler } from '../lib/onnxjs/backend'; +import { createWebGLContext } from '../lib/onnxjs/backends/webgl/webgl-context-factory'; +import { Logger, Profiler } from '../lib/onnxjs/instrument'; +import { Operator } from '../lib/onnxjs/operators'; +import { onnx } from '../lib/onnxjs/ort-schema/protobuf/onnx'; +import { Tensor } from '../lib/onnxjs/tensor'; +import { ProtoUtil } from '../lib/onnxjs/util'; +import { createView } from '../lib/wasm/jsep/tensor-view'; +import { getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum } from '../lib/wasm/wasm-common'; + +import { base64toBuffer, createMockGraph, readFile } from './test-shared'; +import { Test } from './test-types'; // the threshold that used to compare 2 float numbers. See above for TensorResultValidator.floatEqual(). const CPU_THRESHOLD_ABSOLUTE_ERROR = 1.0e-4; @@ -38,31 +38,41 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001; /** * returns a number to represent the current timestamp in a resolution as high as possible. */ -const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now; +const now = typeof performance !== 'undefined' && performance.now ? () => performance.now() : Date.now; function fromInternalTensor(tensor: Tensor): ort.Tensor { return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims); } -async function loadTensorProto(uriOrData: string|Uint8Array, allowInt64 = false): Promise { - const buf = (typeof uriOrData === 'string') ? await readFile(uriOrData) : uriOrData; +async function loadTensorProto(uriOrData: string | Uint8Array, allowInt64 = false): Promise { + const buf = typeof uriOrData === 'string' ? await readFile(uriOrData) : uriOrData; const tensorProto = onnx.TensorProto.decode(buf); let tensor: ort.Tensor; // by default, we don't allow (u)int64. this is for backward compatibility. 
- if (allowInt64 && tensorProto && tensorProto.dataType && - ((tensorProto.dataType === onnx.TensorProto.DataType.INT64 || - tensorProto.dataType === onnx.TensorProto.DataType.UINT64))) { + if ( + allowInt64 && + tensorProto && + tensorProto.dataType && + (tensorProto.dataType === onnx.TensorProto.DataType.INT64 || + tensorProto.dataType === onnx.TensorProto.DataType.UINT64) + ) { const signed = tensorProto.dataType === onnx.TensorProto.DataType.INT64; const dataConstructor = signed ? BigInt64Array : BigUint64Array; const length = tensorProto.rawData.byteLength / 8; const data = new dataConstructor(length); - if (tensorProto.rawData && typeof tensorProto.rawData.byteLength === 'number' && - tensorProto.rawData.byteLength > 0) { - const dataSource = - new DataView(tensorProto.rawData.buffer, tensorProto.rawData.byteOffset, tensorProto.rawData.byteLength); + if ( + tensorProto.rawData && + typeof tensorProto.rawData.byteLength === 'number' && + tensorProto.rawData.byteLength > 0 + ) { + const dataSource = new DataView( + tensorProto.rawData.buffer, + tensorProto.rawData.byteOffset, + tensorProto.rawData.byteLength, + ); for (let i = 0; i < length; i++) { data[i] = signed ? dataSource.getBigInt64(i * 8, true) : dataSource.getBigUint64(i * 8, true); } @@ -82,16 +92,19 @@ async function loadTensorProto(uriOrData: string|Uint8Array, allowInt64 = false) return namedTensor; } -async function loadMlProto(_uriOrData: string|Uint8Array): Promise { +async function loadMlProto(_uriOrData: string | Uint8Array): Promise { return Promise.reject('not supported'); } async function loadTensors( - modelMetaData: {inputNames: readonly string[]; outputNames: readonly string[]}, testCase: Test.ModelTestCase, - backendName: string, fileCache?: FileCacheBuffer) { + modelMetaData: { inputNames: readonly string[]; outputNames: readonly string[] }, + testCase: Test.ModelTestCase, + backendName: string, + fileCache?: FileCacheBuffer, +) { const inputs: Test.NamedTensor[] = []; const outputs: Test.NamedTensor[] = []; - let dataFileType: 'none'|'pb'|'npy' = 'none'; + let dataFileType: 'none' | 'pb' | 'npy' = 'none'; const allowInt64 = ['wasm', 'webgpu', 'webnn'].includes(backendName); @@ -106,8 +119,10 @@ async function loadTensors( } const uriOrData = fileCache && fileCache[dataFile] ? fileCache[dataFile] : dataFile; - const t = ext.toLowerCase() === '.pb' ? await loadTensorProto(uriOrData, allowInt64) : // onnx.TensorProto - await loadMlProto(uriOrData); + const t = + ext.toLowerCase() === '.pb' + ? await loadTensorProto(uriOrData, allowInt64) // onnx.TensorProto + : await loadMlProto(uriOrData); const dataFileBasename = dataFile.split(/[/\\]/).pop()!; @@ -134,24 +149,31 @@ async function loadTensors( } async function initializeSession( - modelFilePath: string, backendHint: ort.InferenceSession.ExecutionProviderConfig, ioBindingMode: Test.IOBindingMode, - profile: boolean, externalData: ort.InferenceSession.SessionOptions['externalData'], - sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { - const preloadModelData: Uint8Array|undefined = - fileCache && fileCache[modelFilePath] ? 
fileCache[modelFilePath] : undefined; + modelFilePath: string, + backendHint: ort.InferenceSession.ExecutionProviderConfig, + ioBindingMode: Test.IOBindingMode, + profile: boolean, + externalData: ort.InferenceSession.SessionOptions['externalData'], + sessionOptions: ort.InferenceSession.SessionOptions, + fileCache?: FileCacheBuffer, +): Promise { + const preloadModelData: Uint8Array | undefined = + fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( - 'TestRunner', - `Start to load model from file: ${modelFilePath}${ - preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`); + 'TestRunner', + `Start to load model from file: ${modelFilePath}${ + preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : '' + }`, + ); - const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined; + const profilerConfig = profile ? { maxNumberEvents: 65536 } : undefined; const sessionConfig = { ...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile, preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, - externalData + externalData, }; let session: ort.InferenceSession; @@ -165,9 +187,9 @@ async function initializeSession( } } catch (e) { Logger.error( - 'TestRunner', - `Failed to load model from file: ${modelFilePath}. ` + - `Error: ${e.message} @ ${e.fileName}:${e.lineNumber}`); + 'TestRunner', + `Failed to load model from file: ${modelFilePath}. ` + `Error: ${e.message} @ ${e.fileName}:${e.lineNumber}`, + ); throw e; } @@ -188,11 +210,11 @@ type FileCacheBuffer = { */ export class ModelTestContext { private constructor( - readonly session: ort.InferenceSession, - readonly backend: string, - readonly perfData: ModelTestContext.ModelTestPerfData, - readonly ioBinding: Test.IOBindingMode, - private readonly profile: boolean, + readonly session: ort.InferenceSession, + readonly backend: string, + readonly perfData: ModelTestContext.ModelTestPerfData, + readonly ioBinding: Test.IOBindingMode, + private readonly profile: boolean, ) {} /** @@ -206,7 +228,7 @@ export class ModelTestContext { Logger.verbose('TestRunner.Perf', ` * FirstRun : ${data.firstRun.toFixed(2)}`); const runs = data.runs; if (runs.length > 0) { - Logger.verbose('TestRunner.Perf', ` * Runs : ${runs.map(r => r.toFixed(2)).join(', ')}`); + Logger.verbose('TestRunner.Perf', ` * Runs : ${runs.map((r) => r.toFixed(2)).join(', ')}`); if (runs.length > 1) { const sorted = runs.sort((a, b) => a - b); @@ -232,8 +254,11 @@ export class ModelTestContext { /** * create a ModelTestContext object that used in every test cases in the given ModelTest. */ - static async create(modelTest: Test.ModelTest, profile: boolean, testOptions?: Test.Options): - Promise { + static async create( + modelTest: Test.ModelTest, + profile: boolean, + testOptions?: Test.Options, + ): Promise { if (this.initializing) { throw new Error('cannot create a ModelTestContext object when the previous creation is not done'); } @@ -243,10 +268,16 @@ export class ModelTestContext { const initStart = now(); const executionProviderConfig = - modelTest.backend === 'webnn' ? (testOptions?.webnnOptions || 'webnn') : modelTest.backend!; + modelTest.backend === 'webnn' ? 
testOptions?.webnnOptions || 'webnn' : modelTest.backend!; const session = await initializeSession( - modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, modelTest.externalData, - testOptions?.sessionOptions || {}, this.cache); + modelTest.modelUrl, + executionProviderConfig, + modelTest.ioBinding, + profile, + modelTest.externalData, + testOptions?.sessionOptions || {}, + this.cache, + ); const initEnd = now(); @@ -255,11 +286,11 @@ export class ModelTestContext { } return new ModelTestContext( - session, - modelTest.backend!, - {init: initEnd - initStart, firstRun: -1, runs: [], count: 0}, - modelTest.ioBinding, - profile, + session, + modelTest.backend!, + { init: initEnd - initStart, firstRun: -1, runs: [], count: 0 }, + modelTest.ioBinding, + profile, ); } finally { this.initializing = false; @@ -293,9 +324,9 @@ export declare namespace ModelTestContext { export class TensorResultValidator { private readonly absoluteThreshold: number; private readonly relativeThreshold: number; - private readonly maxFloatValue: number = 3.4028234663852886e+38; + private readonly maxFloatValue: number = 3.4028234663852886e38; - private static isHalfFloat: boolean|undefined; + private static isHalfFloat: boolean | undefined; constructor(backend: string) { if (backend === 'cpu') { @@ -340,10 +371,11 @@ export class TensorResultValidator { const match = this.areEqual(actual[i], expected[i]); if (!match) { Logger.error( - 'TestRunner', - `Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${ - actual[i].data}]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${ - expected[i].data}]`); + 'TestRunner', + `Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${ + actual[i].data + }]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${expected[i].data}]`, + ); } expect(match, 'tensor data should match').to.be.true; } @@ -358,7 +390,10 @@ export class TensorResultValidator { expect(actual, 'keys of output tensors').to.contain.keys(expectedOneOutput.name); } - this.checkApiTensorResult(expected.map(i => actual[i.name]!), expected); + this.checkApiTensorResult( + expected.map((i) => actual[i.name]!), + expected, + ); } // This function check whether 2 tensors should be considered as 'match' or not @@ -397,15 +432,17 @@ export class TensorResultValidator { const actualDataBuffer = actualData.buffer; const actualDataByteOffset = actualData.byteOffset; const actualDataLength = actualData.length; - const actualDataFloat32Array = - new Float32Array(new Float16ArrayPolyfill(actualDataBuffer, actualDataByteOffset, actualDataLength)); + const actualDataFloat32Array = new Float32Array( + new Float16ArrayPolyfill(actualDataBuffer, actualDataByteOffset, actualDataLength), + ); const expectedData = expected.data as Uint16Array; const expectedDataBuffer = expectedData.buffer; const expectedDataByteOffset = expectedData.byteOffset; const expectedDataLength = expectedData.length; - const expectedDataFloat32Array = - new Float32Array(new Float16ArrayPolyfill(expectedDataBuffer, expectedDataByteOffset, expectedDataLength)); + const expectedDataFloat32Array = new Float32Array( + new Float16ArrayPolyfill(expectedDataBuffer, expectedDataByteOffset, expectedDataLength), + ); return this.floatEqual(actualDataFloat32Array, expectedDataFloat32Array); } @@ -413,8 +450,9 @@ export class TensorResultValidator { case 'float32': case 'float64': return this.floatEqual( - actual.data as number[] | Float32Array | 
Float64Array, - expected.data as number[] | Float32Array | Float64Array); + actual.data as number[] | Float32Array | Float64Array, + expected.data as number[] | Float32Array | Float64Array, + ); case 'uint8': case 'int8': @@ -425,8 +463,9 @@ export class TensorResultValidator { case 'int64': case 'bool': return TensorResultValidator.integerEqual( - actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, - expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array); + actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + ); default: throw new Error('type not implemented or not supported'); @@ -440,7 +479,10 @@ export class TensorResultValidator { return false; } } - floatEqual(actual: number[]|Float32Array|Float64Array, expected: number[]|Float32Array|Float64Array): boolean { + floatEqual( + actual: number[] | Float32Array | Float64Array, + expected: number[] | Float32Array | Float64Array, + ): boolean { if (actual.length !== expected.length) { return false; } @@ -450,24 +492,24 @@ export class TensorResultValidator { let b = expected[i]; if (a === b) { - continue; // exact the same value, treat as equal + continue; // exact the same value, treat as equal } // check for NaN // if (Number.isNaN(a) && Number.isNaN(b)) { - continue; // 2 numbers are NaN, treat as equal + continue; // 2 numbers are NaN, treat as equal } if (Number.isNaN(a) || Number.isNaN(b)) { Logger.error('Validator', `a or b isNan -- index:${i}: actual=${actual[i]},expected=${expected[i]}`); - return false; // one is NaN and the other is not + return false; // one is NaN and the other is not } // check for Infinity // if (!Number.isFinite(a) || !Number.isFinite(b)) { Logger.error('Validator', `a or b is Infinity -- index:${i}: actual=${actual[i]},expected=${expected[i]}`); - return false; // at least one is Infinity and the other is not or their sign is different + return false; // at least one is Infinity and the other is not or their sign is different } // normalize value of b @@ -482,10 +524,10 @@ export class TensorResultValidator { // endif // if (Math.abs(actual[i] - expected[i]) < this.absoluteThreshold) { - continue; // absolute error check pass + continue; // absolute error check pass } if (a !== 0 && b !== 0 && a / b < this.relativeThreshold && b / a < this.relativeThreshold) { - continue; // relative error check pass + continue; // relative error check pass } // if code goes here, it means both (abs/rel) check failed. 
@@ -496,8 +538,9 @@ export class TensorResultValidator { return true; } static integerEqual( - actual: number[]|Uint8Array|Int8Array|Uint16Array|Int16Array|Uint32Array|Int32Array, - expected: number[]|Uint8Array|Int8Array|Uint16Array|Int16Array|Uint32Array|Int32Array): boolean { + actual: number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected: number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + ): boolean { if (actual.length !== expected.length) { return false; } @@ -521,17 +564,21 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, size: Math.ceil(cpuTensor.data.byteLength / 16) * 16, - mappedAtCreation: true + mappedAtCreation: true, }); const arrayBuffer = gpuBuffer.getMappedRange(); - new Uint8Array(arrayBuffer) - .set(new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength)); + new Uint8Array(arrayBuffer).set( + new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength), + ); gpuBuffer.unmap(); // TODO: how to "await" for the copy to finish, so that we can get more accurate performance data? - return ort.Tensor.fromGpuBuffer( - gpuBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => gpuBuffer.destroy()}); + return ort.Tensor.fromGpuBuffer(gpuBuffer, { + dataType: cpuTensor.type, + dims: cpuTensor.dims, + dispose: () => gpuBuffer.destroy(), + }); } function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { @@ -546,7 +593,7 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[] const gpuBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, - size: Math.ceil(size / 16) * 16 + size: Math.ceil(size / 16) * 16, }); return ort.Tensor.fromGpuBuffer(gpuBuffer, { @@ -557,7 +604,7 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[] const stagingBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, - size: gpuBuffer.size + size: gpuBuffer.size, }); const encoder = device.createCommandEncoder(); encoder.copyBufferToBuffer(gpuBuffer, 0, stagingBuffer, 0, gpuBuffer.size); @@ -568,13 +615,14 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[] stagingBuffer.destroy(); return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.GpuBufferDataTypes]; - } + }, }); } export async function sessionRun(options: { - session: ort.InferenceSession; feeds: Record; - outputsMetaInfo: Record>; + session: ort.InferenceSession; + feeds: Record; + outputsMetaInfo: Record>; ioBinding: Test.IOBindingMode; }): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { const session = options.session; @@ -603,8 +651,8 @@ export async function sessionRun(options: { if (shouldUploadOutput) { for (const name in options.outputsMetaInfo) { if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { - const {type, dims} = options.outputsMetaInfo[name]; - if (dims.some(d => d === 0)) { + const { type, dims } = options.outputsMetaInfo[name]; + if (dims.some((d) => d === 0)) { fetches[name] = new ort.Tensor(type, [], dims); } else { fetches[name] = createGpuTensorForOutput(type, dims); @@ -615,9 +663,9 @@ 
export async function sessionRun(options: { const start = now(); Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); - const outputs = await ( - shouldUploadOutput ? session.run(feeds, fetches) : - session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo))); + const outputs = await (shouldUploadOutput + ? session.run(feeds, fetches) + : session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo))); const end = now(); Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); @@ -646,17 +694,24 @@ export async function sessionRun(options: { * run a single model test case. the inputs/outputs tensors should already been prepared. */ export async function runModelTestSet( - context: ModelTestContext, testCase: Test.ModelTestCase, testName: string): Promise { + context: ModelTestContext, + testCase: Test.ModelTestCase, + testName: string, +): Promise { Logger.verbose('TestRunner', `Start to run test data from folder: ${testName}/${testCase.name}`); Logger.verbose('TestRunner', `Start to run test data from folder: ${testCase.name}`); const validator = new TensorResultValidator(context.backend); try { const feeds: Record = {}; const outputsMetaInfo: Record = {}; - testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor); - testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor); - const [start, end, outputs] = - await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); + testCase.inputs!.forEach((tensor) => (feeds[tensor.name] = tensor)); + testCase.outputs!.forEach((tensor) => (outputsMetaInfo[tensor.name] = tensor)); + const [start, end, outputs] = await sessionRun({ + session: context.session, + feeds, + outputsMetaInfo, + ioBinding: context.ioBinding, + }); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; } else { @@ -667,7 +722,7 @@ export async function runModelTestSet( Logger.verbose('TestRunner', `Finished running model from file: ${testCase.name}`); Logger.verbose('TestRunner', ' Stats:'); Logger.verbose('TestRunner', ` Input(s): ${testCase.inputs!.length}`); - testCase.inputs!.forEach(i => { + testCase.inputs!.forEach((i) => { Logger.verbose('TestRunner', ` '${i.name}': ${i.type}[${i.dims.join(',')}]`); }); Logger.verbose('TestRunner', ` Output(s): ${Object.keys(outputs).length}`); @@ -689,10 +744,13 @@ export async function runModelTestSet( } function initializeOperator( - sessionHandler: SessionHandler, opType: string, attributeValues: readonly Test.AttributeValue[], - opsetImports: readonly Test.OperatorTestOpsetImport[]): Operator { + sessionHandler: SessionHandler, + opType: string, + attributeValues: readonly Test.AttributeValue[], + opsetImports: readonly Test.OperatorTestOpsetImport[], +): Operator { const attributes = new Attribute(undefined); - attributeValues.forEach(value => attributes.set(value.name, value.type, value.data)); + attributeValues.forEach((value) => attributes.set(value.name, value.type, value.data)); const graph = createMockGraph(opType, attributes); return sessionHandler.resolve(graph.getNodes()[0], opsetImports, graph); } @@ -711,9 +769,9 @@ export class OpTestContext { this.backendHint = opTest.backend ?? 'cpu'; } createOperator(): Operator { - return initializeOperator( - this.sessionHandler, this.opTest.operator, this.opTest.attributes || [], - [this.opTest.opset ?? 
{domain: '', version: 7}]); + return initializeOperator(this.sessionHandler, this.opTest.operator, this.opTest.attributes || [], [ + this.opTest.opset ?? { domain: '', version: 7 }, + ]); } async dispose(): Promise { @@ -723,7 +781,7 @@ export class OpTestContext { async init(): Promise { const backend = await resolveBackend(this.backendHint); - this.sessionHandler = backend.createSessionHandler({profiler: OpTestContext.profiler}); + this.sessionHandler = backend.createSessionHandler({ profiler: OpTestContext.profiler }); this.inferenceHandler = this.sessionHandler.createInferenceHandler(); } } @@ -732,15 +790,18 @@ export class OpTestContext { * a ProtoOpTestContext uses a protobuf model for operator test. used for ORT based backend. */ export class ProtoOpTestContext { - private readonly loadedData: Uint8Array; // model data, inputs, outputs + private readonly loadedData: Uint8Array; // model data, inputs, outputs session: ort.InferenceSession; readonly backendHint: string; readonly ioBindingMode: Test.IOBindingMode; - constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) { + constructor( + test: Test.OperatorTest, + private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}, + ) { const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; - const attribute = (test.attributes || []).map(attr => { - const protoAttr = onnx.AttributeProto.create({name: attr.name}); + const attribute = (test.attributes || []).map((attr) => { + const protoAttr = onnx.AttributeProto.create({ name: attr.name }); switch (attr.type) { case 'float': protoAttr.type = onnx.AttributeProto.AttributeType.FLOAT; @@ -764,7 +825,7 @@ export class ProtoOpTestContext { break; case 'strings': protoAttr.type = onnx.AttributeProto.AttributeType.STRINGS; - protoAttr.strings = (attr.data as string[]).map(s => new TextEncoder().encode(s)); + protoAttr.strings = (attr.data as string[]).map((s) => new TextEncoder().encode(s)); break; default: throw new Error(`Unsupported attribute type: ${attr.type}`); @@ -777,27 +838,37 @@ export class ProtoOpTestContext { } const inputCount = test.cases[0].inputs!.length; const outputCount = test.cases[0].outputs!.length; - if (test.cases.some( - testCase => testCase.inputs!.length !== inputCount || testCase.outputs!.length !== outputCount)) { + if ( + test.cases.some((testCase) => testCase.inputs!.length !== inputCount || testCase.outputs!.length !== outputCount) + ) { throw new Error( - `Test cases for test: ${test.name} [${test.operator}] must have the same number of inputs and outputs`); + `Test cases for test: ${test.name} [${test.operator}] must have the same number of inputs and outputs`, + ); } - const inputsOmitted = test.cases[0].inputs.map(input => !input.data); - const outputsOmitted = test.cases[0].outputs.map(output => !output.data); + const inputsOmitted = test.cases[0].inputs.map((input) => !input.data); + const outputsOmitted = test.cases[0].outputs.map((output) => !output.data); for (let caseIndex = 1; caseIndex < test.cases.length; caseIndex++) { const testCase = test.cases[caseIndex]; for (let i = 0; i < inputCount; i++) { if (inputsOmitted[i] !== !testCase.inputs![i].data) { - throw new Error(`Test cases for test: ${test.name} [${ - test.operator}] must have consistent inputs data availability. 
Data of input[${i}] in testCase #0 and #${ - caseIndex} should be both available or both omitted.`); + throw new Error( + `Test cases for test: ${test.name} [${ + test.operator + }] must have consistent inputs data availability. Data of input[${i}] in testCase #0 and #${ + caseIndex + } should be both available or both omitted.`, + ); } } for (let i = 0; i < outputCount; i++) { if (outputsOmitted[i] !== !testCase.outputs![i].data) { - throw new Error(`Test cases for test: ${test.name} [${ - test.operator}] must have consistent outputs data availability. Data of output[${ - i}] in testCase #0 and #${caseIndex} should be both available or both omitted.`); + throw new Error( + `Test cases for test: ${test.name} [${ + test.operator + }] must have consistent outputs data availability. Data of output[${ + i + }] in testCase #0 and #${caseIndex} should be both available or both omitted.`, + ); } } } @@ -807,97 +878,119 @@ export class ProtoOpTestContext { model.opsetImport.push(opsetImport); model.graph = onnx.GraphProto.create(); - model.graph.node = [onnx.NodeProto.create({ - input: test.cases[0].inputs!.map((t, i) => t.data ? `input_${i}` : ''), - output: test.cases[0].outputs!.map((t, i) => t.data ? `output_${i}` : ''), - opType: operator, - domain: test.opset?.domain, - name: operator, - attribute - })]; + model.graph.node = [ + onnx.NodeProto.create({ + input: test.cases[0].inputs!.map((t, i) => (t.data ? `input_${i}` : '')), + output: test.cases[0].outputs!.map((t, i) => (t.data ? `output_${i}` : '')), + opType: operator, + domain: test.opset?.domain, + name: operator, + attribute, + }), + ]; // normalize input shape definitions - let normalizedInputShapeDefinitions: ReadonlyArray; + let normalizedInputShapeDefinitions: ReadonlyArray; if (!test.inputShapeDefinitions || test.inputShapeDefinitions === 'none') { // if inputShapeDefinitions is not specified, use undefined for all inputs normalizedInputShapeDefinitions = new Array(inputCount).fill(undefined); } else if (test.inputShapeDefinitions === 'rankOnly') { // check if all test cases have data - if (test.cases.some(testCase => testCase.inputs!.some(input => !input.data || !input.dims))) { - throw new Error(`Test cases for test: ${test.name} [${ - test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`); + if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) { + throw new Error( + `Test cases for test: ${test.name} [${ + test.operator + }] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`, + ); } // if inputShapeDefinitions is 'rankOnly', use semantic names for all inputs. This means only rank is specified. 
- normalizedInputShapeDefinitions = - test.cases[0].inputs!.map((input: Test.TensorValue, i) => input.dims.map((_, j) => `_input_${i}_d${j}`)); + normalizedInputShapeDefinitions = test.cases[0].inputs!.map((input: Test.TensorValue, i) => + input.dims.map((_, j) => `_input_${i}_d${j}`), + ); // check if all test cases have the same rank for each inputs - if (test.cases.some( - testCase => testCase.inputs!.some( - (input: Test.TensorValue, i) => - input.dims.length !== (test.cases[0].inputs![i] as Test.TensorValue).dims.length))) { - throw new Error(`Test cases for test: ${test.name} [${ - test.operator}] must have the same rank for each inputs in different test cases`); + if ( + test.cases.some((testCase) => + testCase.inputs!.some( + (input: Test.TensorValue, i) => + input.dims.length !== (test.cases[0].inputs![i] as Test.TensorValue).dims.length, + ), + ) + ) { + throw new Error( + `Test cases for test: ${test.name} [${ + test.operator + }] must have the same rank for each inputs in different test cases`, + ); } } else if (test.inputShapeDefinitions === 'static') { // check if all test cases have data - if (test.cases.some(testCase => testCase.inputs!.some(input => !input.data || !input.dims))) { - throw new Error(`Test cases for test: ${test.name} [${ - test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`); + if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) { + throw new Error( + `Test cases for test: ${test.name} [${ + test.operator + }] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`, + ); } // if inputShapeDefinitions is 'static', use the shape of the first test case for all inputs. normalizedInputShapeDefinitions = test.cases[0].inputs!.map((input: Test.TensorValue) => input.dims); // check if all test cases have the same shape for each inputs - if (test.cases.some( - testCase => testCase.inputs!.some( - (input: Test.TensorValue, i) => TensorResultValidator.integerEqual( - input.dims, (test.cases[0].inputs![i] as Test.TensorValue).dims)))) { - throw new Error(`Test cases for test: ${test.name} [${ - test.operator}] must have the same shape for each inputs in different test cases`); + if ( + test.cases.some((testCase) => + testCase.inputs!.some((input: Test.TensorValue, i) => + TensorResultValidator.integerEqual(input.dims, (test.cases[0].inputs![i] as Test.TensorValue).dims), + ), + ) + ) { + throw new Error( + `Test cases for test: ${test.name} [${ + test.operator + }] must have the same shape for each inputs in different test cases`, + ); } } else { // if inputShapeDefinitions is specified as an array, use it as is. // check if inputShapeDefinitions has the same number of inputs as test cases if (test.inputShapeDefinitions && test.inputShapeDefinitions.length !== inputCount) { throw new Error( - `Input shape definitions for test: ${test.name} [${test.operator}] must have the same number of inputs`); + `Input shape definitions for test: ${test.name} [${test.operator}] must have the same number of inputs`, + ); } normalizedInputShapeDefinitions = test.inputShapeDefinitions; } - model.graph.input = - test.cases[0] - .inputs! - .map((input, i) => { - const shapeDefinition = normalizedInputShapeDefinitions[i]; - const shape = shapeDefinition ? onnx.TensorShapeProto.create({ - dim: shapeDefinition.map( - dim => onnx.TensorShapeProto.Dimension.create( - typeof dim === 'string' ? 
{dimParam: dim} : {dimValue: dim})) - }) : - undefined; - return onnx.ValueInfoProto.create({ - name: `input_${i}`, - type: onnx.TypeProto.create({ - tensorType: onnx.TypeProto.Tensor.create({elemType: tensorDataTypeStringToEnum(input.type), shape}), - }), - }); + model.graph.input = test.cases[0] + .inputs!.map((input, i) => { + const shapeDefinition = normalizedInputShapeDefinitions[i]; + const shape = shapeDefinition + ? onnx.TensorShapeProto.create({ + dim: shapeDefinition.map((dim) => + onnx.TensorShapeProto.Dimension.create(typeof dim === 'string' ? { dimParam: dim } : { dimValue: dim }), + ), }) - .filter((_, i) => test.cases[0].inputs![i].data); - - model.graph.output = - test.cases[0] - .outputs! - .map((output, i) => onnx.ValueInfoProto.create({ - name: `output_${i}`, - type: onnx.TypeProto.create({ - tensorType: onnx.TypeProto.Tensor.create({elemType: tensorDataTypeStringToEnum(output.type)}), - }), - })) - .filter((_, i) => test.cases[0].outputs![i].data); + : undefined; + return onnx.ValueInfoProto.create({ + name: `input_${i}`, + type: onnx.TypeProto.create({ + tensorType: onnx.TypeProto.Tensor.create({ elemType: tensorDataTypeStringToEnum(input.type), shape }), + }), + }); + }) + .filter((_, i) => test.cases[0].inputs![i].data); + + model.graph.output = test.cases[0] + .outputs!.map((output, i) => + onnx.ValueInfoProto.create({ + name: `output_${i}`, + type: onnx.TypeProto.create({ + tensorType: onnx.TypeProto.Tensor.create({ elemType: tensorDataTypeStringToEnum(output.type) }), + }), + }), + ) + .filter((_, i) => test.cases[0].outputs![i].data); model.graph.name = test.name; @@ -907,8 +1000,9 @@ export class ProtoOpTestContext { // in debug mode, open a new tab in browser for the generated onnx model. if (ort.env.debug) { - const modelFile = - new File([this.loadedData], `op_test_generated_model_${test.name}.onnx`, {type: 'application/octet-stream'}); + const modelFile = new File([this.loadedData], `op_test_generated_model_${test.name}.onnx`, { + type: 'application/octet-stream', + }); const modelTempUrl = URL.createObjectURL(modelFile); const a = document.createElement('a'); a.href = modelTempUrl; @@ -922,7 +1016,7 @@ export class ProtoOpTestContext { this.session = await ort.InferenceSession.create(this.loadedData, { executionProviders: [this.backendHint], preferredOutputLocation: this.ioBindingMode === 'gpu-location' ? 
('gpu-buffer' as const) : undefined, - ...this.sessionOptions + ...this.sessionOptions, }); } @@ -932,13 +1026,16 @@ export class ProtoOpTestContext { } async function runProtoOpTestcase( - session: ort.InferenceSession, testCase: Test.OperatorTestCase, ioBindingMode: Test.IOBindingMode, - validator: TensorResultValidator): Promise { + session: ort.InferenceSession, + testCase: Test.OperatorTestCase, + ioBindingMode: Test.IOBindingMode, + validator: TensorResultValidator, +): Promise { const feeds: Record = {}; - const fetches: Record> = {}; + const fetches: Record> = {}; testCase.inputs.forEach((input, i) => { if (input.data) { - let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = input.data; + let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = input.data; if (input.type === 'uint64') { data = BigUint64Array.from(input.data.map(BigInt)); } else if (input.type === 'int64') { @@ -955,7 +1052,7 @@ async function runProtoOpTestcase( const expectedOutputNames: string[] = []; testCase.outputs.forEach((output, i) => { if (output.data) { - let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = output.data; + let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = output.data; if (output.type === 'uint64') { data = BigUint64Array.from(output.data.map(BigInt)); } else if (output.type === 'int64') { @@ -966,17 +1063,17 @@ async function runProtoOpTestcase( } outputs.push(new ort.Tensor(output.type, data, output.dims)); expectedOutputNames.push(`output_${i}`); - fetches[`output_${i}`] = {dims: output.dims, type: output.type}; + fetches[`output_${i}`] = { dims: output.dims, type: output.type }; } }); - const [, , results] = await sessionRun({session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode}); + const [, , results] = await sessionRun({ session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode }); const actualOutputNames = Object.getOwnPropertyNames(results); expect(actualOutputNames.length).to.equal(expectedOutputNames.length); expect(actualOutputNames).to.have.members(expectedOutputNames); - const actualOutputs = actualOutputNames.map(name => results[name]); + const actualOutputs = actualOutputNames.map((name) => results[name]); validator.checkApiTensorResult(actualOutputs, outputs); } @@ -989,13 +1086,17 @@ function createTensor(dims: number[], type: Tensor.DataType, data: number[]): Te } async function runOpTestcase( - inferenceHandler: InferenceHandler, operator: Operator, testcase: Test.OperatorTestCase, - validator: TensorResultValidator): Promise { + inferenceHandler: InferenceHandler, + operator: Operator, + testcase: Test.OperatorTestCase, + validator: TensorResultValidator, +): Promise { testcase.inputs.forEach((input: Test.TensorValue, i) => { Logger.verbose('TestOpRunner', ` Input '${i}': ${input.type}[${input.dims.join(',')}]`); }); - const inputTensors = testcase.inputs.map( - (input: Test.TensorValue) => createTensor(input.dims, input.type as Tensor.DataType, input.data)); + const inputTensors = testcase.inputs.map((input: Test.TensorValue) => + createTensor(input.dims, input.type as Tensor.DataType, input.data), + ); const results = operator.impl(inferenceHandler, inputTensors, operator.context); @@ -1003,15 +1104,15 @@ async function runOpTestcase( for (const result of results) { try { await result.getData(); - } catch { - } + } catch {} } results.forEach((output, i) => { Logger.verbose('TestOpRunner', ` Result'${i}': ${output.type}[${output.dims.join(',')}]`); }); - const expectedTensors = testcase.outputs.map( - 
(output: Test.TensorValue) => createTensor(output.dims, output.type as Tensor.DataType, output.data)); + const expectedTensors = testcase.outputs.map((output: Test.TensorValue) => + createTensor(output.dims, output.type as Tensor.DataType, output.data), + ); validator.checkTensorResult(results, expectedTensors); } @@ -1019,12 +1120,22 @@ async function runOpTestcase( * run a single operator test case. */ export async function runOpTest( - testcase: Test.OperatorTestCase, context: ProtoOpTestContext|OpTestContext): Promise { + testcase: Test.OperatorTestCase, + context: ProtoOpTestContext | OpTestContext, +): Promise { if (context instanceof ProtoOpTestContext) { await runProtoOpTestcase( - context.session, testcase, context.ioBindingMode, new TensorResultValidator(context.backendHint)); + context.session, + testcase, + context.ioBindingMode, + new TensorResultValidator(context.backendHint), + ); } else { await runOpTestcase( - context.inferenceHandler, context.createOperator(), testcase, new TensorResultValidator(context.backendHint)); + context.inferenceHandler, + context.createOperator(), + testcase, + new TensorResultValidator(context.backendHint), + ); } } diff --git a/js/web/test/test-shared.ts b/js/web/test/test-shared.ts index 55beb66e37e6e..605f2eae2e7fe 100644 --- a/js/web/test/test-shared.ts +++ b/js/web/test/test-shared.ts @@ -4,8 +4,8 @@ import * as base64 from 'base64-js'; import * as fs from 'node:fs/promises'; -import {Attribute} from '../lib/onnxjs/attribute'; -import {Graph} from '../lib/onnxjs/graph'; +import { Attribute } from '../lib/onnxjs/attribute'; +import { Graph } from '../lib/onnxjs/graph'; export function base64toBuffer(data: string): Uint8Array { return base64.toByteArray(data); @@ -24,7 +24,7 @@ async function retry(fn: () => Promise, maxRetries = 3, delay = 100): Prom if (retries-- === 0) { throw err; } - await new Promise(resolve => setTimeout(resolve, delay)); + await new Promise((resolve) => setTimeout(resolve, delay)); } // eslint-disable-next-line no-constant-condition } while (true); @@ -54,13 +54,13 @@ export async function readJsonFile(file: string): Promise { * create a single-node graph for unit test purpose */ export function createMockGraph(opType: string, attributes: Attribute): Graph { - const node: Graph.Node = {name: '', opType, inputs: [], outputs: [], attributes}; + const node: Graph.Node = { name: '', opType, inputs: [], outputs: [], attributes }; return { getInputIndices: () => [], getInputNames: () => [], getOutputIndices: () => [], getOutputNames: () => [], getNodes: () => [node], - getValues: () => [] + getValues: () => [], }; } diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index 14b9fd7c005ab..be1e56485ec5a 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; +import { Env, InferenceSession, Tensor } from 'onnxruntime-common'; -import {Attribute} from '../lib/onnxjs/attribute'; -import {Logger} from '../lib/onnxjs/instrument'; +import { Attribute } from '../lib/onnxjs/attribute'; +import { Logger } from '../lib/onnxjs/instrument'; export declare namespace Test { export interface NamedTensor extends Tensor { @@ -53,20 +53,20 @@ export declare namespace Test { * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` * will not be set. 
*/ - export type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; + export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location'; export interface ModelTestCase { name: string; dataFiles: readonly string[]; - inputs?: NamedTensor[]; // value should be populated at runtime - outputs?: NamedTensor[]; // value should be populated at runtime + inputs?: NamedTensor[]; // value should be populated at runtime + outputs?: NamedTensor[]; // value should be populated at runtime } export interface ModelTest { name: string; modelUrl: string; externalData?: InferenceSession.SessionOptions['externalData']; - backend?: string; // value should be populated at build time + backend?: string; // value should be populated at build time ioBinding: IOBindingMode; platformCondition?: PlatformCondition; cases: readonly ModelTestCase[]; @@ -79,8 +79,8 @@ export declare namespace Test { export interface OperatorTestCase { name: string; - inputs: ReadonlyArray; - outputs: ReadonlyArray; + inputs: ReadonlyArray; + outputs: ReadonlyArray; } export interface OperatorTestOpsetImport { @@ -88,14 +88,14 @@ export declare namespace Test { version: number; } - export type InputShapeDefinition = ReadonlyArray; + export type InputShapeDefinition = ReadonlyArray; export interface OperatorTest { name: string; operator: string; - inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray; + inputShapeDefinitions?: 'none' | 'rankOnly' | 'static' | ReadonlyArray; opset?: OperatorTestOpsetImport; - backend?: string; // value should be populated at build time + backend?: string; // value should be populated at build time ioBinding: IOBindingMode; platformCondition?: PlatformCondition; attributes?: readonly AttributeValue[]; @@ -114,7 +114,7 @@ export declare namespace Test { name: string; platformCondition: PlatformCondition; } - export type Test = TestName|TestDescription; + export type Test = TestName | TestDescription; } /** @@ -122,10 +122,10 @@ export declare namespace Test { * A testlist should only be applied when running suite test cases (suite0) */ export interface TestList { - [backend: string]: {[group: string]: readonly TestList.Test[]}; + [backend: string]: { [group: string]: readonly TestList.Test[] }; } - interface EnvOptions extends Partial> { + interface EnvOptions extends Partial> { wasm: Partial; webgl: Partial; webgpu: Partial; @@ -166,7 +166,7 @@ export declare namespace Test { fileCacheUrls?: readonly string[]; - log: ReadonlyArray<{category: string; config: Logger.Config}>; + log: ReadonlyArray<{ category: string; config: Logger.Config }>; profile: boolean; options: Options; } diff --git a/js/web/test/training/e2e/browser-test-wasm.js b/js/web/test/training/e2e/browser-test-wasm.js index fa87389f7ac46..05750ed149303 100644 --- a/js/web/test/training/e2e/browser-test-wasm.js +++ b/js/web/test/training/e2e/browser-test-wasm.js @@ -3,19 +3,19 @@ 'use strict'; -describe('Browser E2E testing for training package', function() { - it('Check that training package encompasses inference', async function() { +describe('Browser E2E testing for training package', function () { + it('Check that training package encompasses inference', async function () { ort.env.wasm.numThreads = 1; - await testInferenceFunction(ort, {executionProviders: ['wasm']}); + await testInferenceFunction(ort, { executionProviders: ['wasm'] }); }); - it('Check training functionality, all options', async function() { + it('Check training functionality, all options', async function () { ort.env.wasm.numThreads = 1; - await 
testTrainingFunctionAll(ort, {executionProviders: ['wasm']}); + await testTrainingFunctionAll(ort, { executionProviders: ['wasm'] }); }); - it('Check training functionality, minimum options', async function() { + it('Check training functionality, minimum options', async function () { ort.env.wasm.numThreads = 1; - await testTrainingFunctionMin(ort, {executionProviders: ['wasm']}); + await testTrainingFunctionMin(ort, { executionProviders: ['wasm'] }); }); }); diff --git a/js/web/test/training/e2e/common.js b/js/web/test/training/e2e/common.js index b6040b63d56b4..0574ae85aabd1 100644 --- a/js/web/test/training/e2e/common.js +++ b/js/web/test/training/e2e/common.js @@ -13,13 +13,13 @@ const trainingSessionAllOptions = { checkpointState: TRAININGDATA_CKPT, trainModel: TRAININGDATA_TRAIN_MODEL, evalModel: TRAININGDATA_EVAL_MODEL, - optimizerModel: TRAININGDATA_OPTIMIZER_MODEL -} + optimizerModel: TRAININGDATA_OPTIMIZER_MODEL, +}; const trainingSessionMinOptions = { checkpointState: TRAININGDATA_CKPT, trainModel: TRAININGDATA_TRAIN_MODEL, -} +}; // ASSERT METHODS @@ -51,7 +51,7 @@ function assertTwoListsUnequal(list1, list2) { // HELPER METHODS FOR TESTS -function generateGaussianRandom(mean=0, scale=1) { +function generateGaussianRandom(mean = 0, scale = 1) { const u = 1 - Math.random(); const v = Math.random(); const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); @@ -106,12 +106,12 @@ function checkEvalModel(trainingSession) { */ function checkNoEvalModel(trainingSession) { try { - assertStrictEquals(trainingSession.evalInputNames, "should have thrown an error upon accessing"); + assertStrictEquals(trainingSession.evalInputNames, 'should have thrown an error upon accessing'); } catch (error) { assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); } try { - assertStrictEquals(trainingSession.evalOutputNames, "should have thrown an error upon accessing"); + assertStrictEquals(trainingSession.evalOutputNames, 'should have thrown an error upon accessing'); } catch (error) { assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); } @@ -124,15 +124,15 @@ function checkNoEvalModel(trainingSession) { * @param {*} feeds * @returns */ -var runTrainStepAndCheck = async function(trainingSession, feeds) { - const results = await trainingSession.runTrainStep(feeds); +var runTrainStepAndCheck = async function (trainingSession, feeds) { + const results = await trainingSession.runTrainStep(feeds); assertStrictEquals(Object.keys(results).length, 1); assertStrictEquals(results['onnx::loss::21273'].data.length, 1); assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); return results; }; -var loadParametersBufferAndCheck = async function(trainingSession, paramsLength, constant, paramsBefore) { +var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) { // make a float32 array that is filled with the constant const newParams = new Float32Array(paramsLength); for (let i = 0; i < paramsLength; i++) { @@ -155,18 +155,20 @@ var loadParametersBufferAndCheck = async function(trainingSession, paramsLength, } return paramsAfterLoad; -} +}; // TESTS -var testInferenceFunction = async function(ort, options) { +var testInferenceFunction = async function (ort, options) { const session = await ort.InferenceSession.create('data/model.onnx', options || {}); const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 
80, 90, 100, 110, 120]); - const fetches = - await session.run({a: new ort.Tensor('float32', dataA, [3, 4]), b: new ort.Tensor('float32', dataB, [4, 3])}); + const fetches = await session.run({ + a: new ort.Tensor('float32', dataA, [3, 4]), + b: new ort.Tensor('float32', dataB, [4, 3]), + }); const c = fetches.c; @@ -183,12 +185,12 @@ var testInferenceFunction = async function(ort, options) { assert(c.data[8] === 3300); }; -var testTrainingFunctionMin = async function(ort, options) { +var testTrainingFunctionMin = async function (ort, options) { const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); checkNoEvalModel(trainingSession); const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); const labels = new ort.Tensor('int32', [2, 1], [2]); - const feeds = {"input-0": input0, "labels": labels}; + const feeds = { 'input-0': input0, labels: labels }; // check getParametersSize const paramsSize = await trainingSession.getParametersSize(); @@ -204,15 +206,15 @@ var testTrainingFunctionMin = async function(ort, options) { await runTrainStepAndCheck(trainingSession, feeds); await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); -} +}; -var testTrainingFunctionAll = async function(ort, options) { +var testTrainingFunctionAll = async function (ort, options) { const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); checkEvalModel(trainingSession); const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); const labels = new ort.Tensor('int32', [2, 1], [2]); - let feeds = {"input-0": input0, "labels": labels}; + let feeds = { 'input-0': input0, labels: labels }; // check getParametersSize const paramsSize = await trainingSession.getParametersSize(); @@ -228,7 +230,7 @@ var testTrainingFunctionAll = async function(ort, options) { const results = await runTrainStepAndCheck(trainingSession, feeds); await trainingSession.runOptimizerStep(feeds); - feeds = {"input-0": input0, "labels": labels}; + feeds = { 'input-0': input0, labels: labels }; // check getContiguousParameters after optimizerStep -- that the parameters have been updated const optimizedParams = await trainingSession.getContiguousParameters(); assertTwoListsUnequal(originalParams.data, optimizedParams.data); @@ -239,7 +241,7 @@ var testTrainingFunctionAll = async function(ort, options) { assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); -} +}; if (typeof module === 'object') { module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; diff --git a/js/web/test/training/e2e/karma.conf.js b/js/web/test/training/e2e/karma.conf.js index 7900fbb27bbe1..74662b67676f7 100644 --- a/js/web/test/training/e2e/karma.conf.js +++ b/js/web/test/training/e2e/karma.conf.js @@ -15,23 +15,23 @@ if (typeof USER_DATA !== 'string') { throw new Error('flag --user-data= is required'); } -module.exports = function(config) { +module.exports = function (config) { const distPrefix = SELF_HOST ? 
'./node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; config.set({ frameworks: ['mocha'], files: [ - {pattern: distPrefix + ORT_MAIN}, - {pattern: './common.js'}, - {pattern: TEST_MAIN}, - {pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true}, - {pattern: './data/*', included: false}, + { pattern: distPrefix + ORT_MAIN }, + { pattern: './common.js' }, + { pattern: TEST_MAIN }, + { pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true }, + { pattern: './data/*', included: false }, ], plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], proxies: { '/model.onnx': '/base/model.onnx', '/data/': '/base/data/', }, - client: {captureConsole: true, mocha: {expose: ['body'], timeout: 60000}}, + client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } }, reporters: ['mocha'], captureTimeout: 120000, reportSlowerThan: 100, @@ -42,13 +42,13 @@ module.exports = function(config) { hostname: 'localhost', browsers: [], customLaunchers: { - Chrome_default: {base: 'ChromeHeadless', chromeDataDir: USER_DATA}, + Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA }, Chrome_no_threads: { base: 'ChromeHeadless', chromeDataDir: USER_DATA, // TODO: no-thread flags }, - Edge_default: {base: 'Edge', edgeDataDir: USER_DATA} - } + Edge_default: { base: 'Edge', edgeDataDir: USER_DATA }, + }, }); }; diff --git a/js/web/test/training/e2e/run.js b/js/web/test/training/e2e/run.js index cc92f7ca58bd5..d12bcc7aa66ed 100644 --- a/js/web/test/training/e2e/run.js +++ b/js/web/test/training/e2e/run.js @@ -5,7 +5,7 @@ const path = require('path'); const fs = require('fs-extra'); -const {spawn} = require('child_process'); +const { spawn } = require('child_process'); const startServer = require('./simple-http-server'); const minimist = require('minimist'); @@ -31,7 +31,7 @@ const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); // always use a new folder as user-data-dir let nextUserDataDirId = 0; function getNextUserDataDir() { - const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()) + const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()); nextUserDataDirId++; fs.emptyDirSync(dir); return dir; @@ -42,10 +42,10 @@ const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; async function main() { // find packed package - const {globbySync} = await import('globby'); + const { globbySync } = await import('globby'); const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); - const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', {cwd: ORT_COMMON_FOLDER}); + const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER }); const PACKAGES_TO_INSTALL = []; @@ -56,7 +56,7 @@ async function main() { } const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); - const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', {cwd: ORT_WEB_FOLDER}); + const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER }); if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { throw new Error('cannot find exactly single package for onnxruntime-web.'); } @@ -68,7 +68,7 @@ async function main() { await runInShell(`npm install`); // npm install with "--cache" to install packed packages with an empty cache folder - await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" 
${PACKAGES_TO_INSTALL.map(i => `"${i}"`).join(' ')}`); + await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`); // prepare training data prepareTrainingDataByCopying(); @@ -77,7 +77,7 @@ async function main() { console.log('Running self-hosted tests'); console.log('==============================================================='); // test cases with self-host (ort hosted in same origin) - await testAllBrowserCases({hostInKarma: true}); + await testAllBrowserCases({ hostInKarma: true }); console.log('==============================================================='); console.log('Running not self-hosted tests'); @@ -85,24 +85,27 @@ async function main() { // test cases without self-host (ort hosted in cross origin) const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081); try { - await testAllBrowserCases({hostInKarma: false}); + await testAllBrowserCases({ hostInKarma: false }); } finally { // close the server after all tests await server.close(); } } -async function testAllBrowserCases({hostInKarma}) { - await runKarma({hostInKarma, main: './browser-test-wasm.js'}); +async function testAllBrowserCases({ hostInKarma }) { + await runKarma({ hostInKarma, main: './browser-test-wasm.js' }); } -async function runKarma({hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js'}) { +async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) { console.log('==============================================================='); console.log(`Running karma with the following binary: ${ortMain}`); console.log('==============================================================='); const selfHostFlag = hostInKarma ? 
'--self-host' : ''; - await runInShell(`npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ - ortMain} --test-main=${main} --user-data=${getNextUserDataDir()}`); + await runInShell( + `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ + ortMain + } --test-main=${main} --user-data=${getNextUserDataDir()}`, + ); } async function runInShell(cmd) { @@ -111,8 +114,8 @@ async function runInShell(cmd) { console.log(' > ' + cmd); console.log('==============================================================='); let complete = false; - const childProcess = spawn(cmd, {shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER}); - childProcess.on('close', function(code) { + const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER }); + childProcess.on('close', function (code) { if (code !== 0) { process.exit(code); } else { @@ -125,8 +128,8 @@ async function runInShell(cmd) { } async function delay(ms) { - return new Promise(function(resolve) { - setTimeout(function() { + return new Promise(function (resolve) { + setTimeout(function () { resolve(); }, ms); }); diff --git a/js/web/test/training/e2e/simple-http-server.js b/js/web/test/training/e2e/simple-http-server.js index d1f8bdd5c2367..ef9cced681cc8 100644 --- a/js/web/test/training/e2e/simple-http-server.js +++ b/js/web/test/training/e2e/simple-http-server.js @@ -32,35 +32,36 @@ const getRequestData = (url, dir) => { return [filepath, mimeType]; }; -module.exports = function(dir, port) { - const server = http.createServer(function(request, response) { - const url = request.url.replace(/\n|\r/g, ''); - console.log(`request ${url}`); +module.exports = function (dir, port) { + const server = http + .createServer(function (request, response) { + const url = request.url.replace(/\n|\r/g, ''); + console.log(`request ${url}`); - const requestData = getRequestData(url, dir); - if (!request || !requestData) { - response.writeHead(404); - response.end('404'); - } else { - const [filePath, contentType] = requestData; - fs.readFile(path.resolve(dir, filePath), function(error, content) { - if (error) { - if (error.code == 'ENOENT') { - response.writeHead(404); - response.end('404'); - } else { - response.writeHead(500); - response.end('500'); - } - } else { - response.setHeader('access-control-allow-origin', '*'); - response.writeHead(200, {'Content-Type': contentType}); - response.end(content, 'utf-8'); - } - }); - } - }) - .listen(port); + const requestData = getRequestData(url, dir); + if (!request || !requestData) { + response.writeHead(404); + response.end('404'); + } else { + const [filePath, contentType] = requestData; + fs.readFile(path.resolve(dir, filePath), function (error, content) { + if (error) { + if (error.code == 'ENOENT') { + response.writeHead(404); + response.end('404'); + } else { + response.writeHead(500); + response.end('500'); + } + } else { + response.setHeader('access-control-allow-origin', '*'); + response.writeHead(200, { 'Content-Type': contentType }); + response.end(content, 'utf-8'); + } + }); + } + }) + .listen(port); console.log(`Server running at http://localhost:${port}/`); return server; }; diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 014fc57f21558..60dd32dfcab5a 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -1,20 +1,21 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT License. -import {Attribute} from '../../../../lib/onnxjs/attribute'; -import {Backend, InferenceHandler, resolveBackend, SessionHandler} from '../../../../lib/onnxjs/backend'; -import {Profiler} from '../../../../lib/onnxjs/instrument'; -import {Tensor} from '../../../../lib/onnxjs/tensor'; -import {PoolConvUtil} from '../../../../lib/onnxjs/util'; -import {TensorResultValidator} from '../../../test-runner'; -import {createMockGraph} from '../../../test-shared'; +import { Attribute } from '../../../../lib/onnxjs/attribute'; +import { Backend, InferenceHandler, resolveBackend, SessionHandler } from '../../../../lib/onnxjs/backend'; +import { Profiler } from '../../../../lib/onnxjs/instrument'; +import { Tensor } from '../../../../lib/onnxjs/tensor'; +import { PoolConvUtil } from '../../../../lib/onnxjs/util'; +import { TensorResultValidator } from '../../../test-runner'; +import { createMockGraph } from '../../../test-shared'; -import {conv2d} from './test-conv-utils'; +import { conv2d } from './test-conv-utils'; function createRandomArray(size: number): Float32Array { const randomTable = [0, 3, 6, 9, 2, 5, 8, 1, 4, 7]; return new Float32Array( - Array.from({length: size}, (_v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01)); + Array.from({ length: size }, (_v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01), + ); } interface TestData { inputShape: number[]; @@ -35,7 +36,7 @@ function getTestData(): TestData[] { autoPad: 'SAME_UPPER', dilations: [1, 1], strides: [1, 1], - group: 1 + group: 1, }, { inputShape: [1, 3, 224, 224], @@ -44,7 +45,7 @@ function getTestData(): TestData[] { pads: [0, 0, 0, 0], dilations: [1, 1], strides: [2, 2], - group: 1 + group: 1, }, { inputShape: [1, 64, 55, 55], @@ -53,7 +54,7 @@ function getTestData(): TestData[] { pads: [0, 0, 0, 0], dilations: [1, 1], strides: [1, 1], - group: 1 + group: 1, }, // { // inputShape: [1, 16, 55, 55], @@ -278,7 +279,7 @@ function getTestData(): TestData[] { pads: [1, 1, 1, 1], dilations: [1, 1], strides: [1, 1], - group: 1 + group: 1, }, { inputShape: [1, 2, 3, 3], @@ -287,7 +288,7 @@ function getTestData(): TestData[] { pads: [0, 0, 0, 0], dilations: [1, 1], strides: [1, 1], - group: 1 + group: 1, }, { inputShape: [1, 3, 224, 224], @@ -296,7 +297,7 @@ function getTestData(): TestData[] { pads: [3, 3, 3, 3], dilations: [1, 1], strides: [2, 2], - group: 1 + group: 1, }, // { // inputShape: [1, 64, 56, 56], @@ -765,7 +766,7 @@ function getTestData(): TestData[] { pads: [1, 1, 1, 1], dilations: [1, 1], strides: [1, 1], - group: 1 + group: 1, }, { inputShape: [1, 512, 7, 7], @@ -775,7 +776,7 @@ function getTestData(): TestData[] { pads: [0, 0, 0, 0], dilations: [1, 1], strides: [1, 1], - group: 1 + group: 1, }, // { // inputShape: [1, 2048, 7, 7], @@ -811,13 +812,19 @@ function getTestData(): TestData[] { } const validator = new TensorResultValidator('webgl'); -let webglBackend: Backend|undefined; -let webglSessionhandler: SessionHandler|undefined; -let webglInferenceHandler: InferenceHandler|undefined; +let webglBackend: Backend | undefined; +let webglSessionhandler: SessionHandler | undefined; +let webglInferenceHandler: InferenceHandler | undefined; function webglConv( - inputTensor: Tensor, kernelTensor: Tensor, biasTensor: Tensor|null, autoPad: string|undefined, dilations: number[], - pads: number[]|undefined, strides: number[]): Tensor { + inputTensor: Tensor, + kernelTensor: Tensor, + biasTensor: Tensor | null, + 
autoPad: string | undefined, + dilations: number[], + pads: number[] | undefined, + strides: number[], +): Tensor { const attributes = new Attribute(undefined); attributes.set('dilations', 'ints', dilations); attributes.set('auto_pad', 'string', autoPad ? autoPad : ''); @@ -827,16 +834,22 @@ function webglConv( } attributes.set('strides', 'ints', strides); const graph = createMockGraph('Conv', attributes); - const op = webglSessionhandler!.resolve(graph.getNodes()[0], [{domain: '', version: 7}], graph); + const op = webglSessionhandler!.resolve(graph.getNodes()[0], [{ domain: '', version: 7 }], graph); const inputs = [inputTensor, kernelTensor]; if (biasTensor) { inputs.push(biasTensor); } - return (op.impl(webglInferenceHandler!, inputs, op.context))[0]; + return op.impl(webglInferenceHandler!, inputs, op.context)[0]; } function cpuConv( - inputTensor: Tensor, kernelTensor: Tensor, biasTensor: Tensor|null, autoPad: string|undefined, dilations: number[], - pads: number[]|undefined, strides: number[]): Tensor { + inputTensor: Tensor, + kernelTensor: Tensor, + biasTensor: Tensor | null, + autoPad: string | undefined, + dilations: number[], + pads: number[] | undefined, + strides: number[], +): Tensor { const attributes = new Attribute(undefined); attributes.set('dilations', 'ints', dilations); attributes.set('auto_pad', 'string', autoPad ? autoPad : ''); @@ -852,7 +865,14 @@ function cpuConv( const adjustedPads = pads ? pads.slice(0) : [0, 0, 0, 0]; const outputDims = PoolConvUtil.computeConvOutputShape( - x.dims, w.dims, strides, dilations, kernelTensor.dims.slice(2), adjustedPads, autoPad); + x.dims, + w.dims, + strides, + dilations, + kernelTensor.dims.slice(2), + adjustedPads, + autoPad, + ); const y = new Tensor(outputDims, x.type); conv2d(y, x, w, b, dilations, 1, adjustedPads, strides); return y; @@ -861,7 +881,7 @@ describe('New Conv tests', () => { before(async () => { const profiler = Profiler.create(); webglBackend = await resolveBackend('webgl'); - webglSessionhandler = webglBackend.createSessionHandler({profiler}); + webglSessionhandler = webglBackend.createSessionHandler({ profiler }); webglInferenceHandler = webglSessionhandler.createInferenceHandler(); }); const testDataSet = getTestData(); @@ -872,9 +892,9 @@ describe('New Conv tests', () => { const kernelData = createRandomArray(testData.kernelShape.reduce((a, b) => a * b)); const biasData = testData.biasShape.length === 1 ? createRandomArray(testData.biasShape[0]) : null; const rgbas = [false]; - rgbas.forEach(rgba => { + rgbas.forEach((rgba) => { describe(`RGBA: ${rgba}`, () => { - before(function() { + before(function () { const patchSize = testData.kernelShape.slice(1).reduce((a, b) => a * b); if (rgba && patchSize % 4 !== 0) { // eslint-disable-next-line no-invalid-this @@ -885,14 +905,27 @@ describe('New Conv tests', () => { // create new Tensors otherwise the session/inference level caching would cause issues const inputTensor = new Tensor(testData.inputShape, 'float32', undefined, undefined, inputData); const kernelTensor = new Tensor(testData.kernelShape, 'float32', undefined, undefined, kernelData); - const biasTensor = - biasData ? new Tensor(testData.biasShape, 'float32', undefined, undefined, biasData) : null; + const biasTensor = biasData + ? 
new Tensor(testData.biasShape, 'float32', undefined, undefined, biasData) + : null; const actual = webglConv( - inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, - testData.strides); + inputTensor, + kernelTensor, + biasTensor, + testData.autoPad, + testData.dilations, + testData.pads, + testData.strides, + ); const expected = cpuConv( - inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, - testData.strides); + inputTensor, + kernelTensor, + biasTensor, + testData.autoPad, + testData.dilations, + testData.pads, + testData.strides, + ); try { validator.checkTensorResult([actual], [expected]); } catch { diff --git a/js/web/test/unittests/backends/webgl/test-conv-utils.ts b/js/web/test/unittests/backends/webgl/test-conv-utils.ts index 32cace1ea9040..778d498efe1c0 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-utils.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-utils.ts @@ -1,15 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Tensor} from '../../../../lib/onnxjs/tensor'; +import { Tensor } from '../../../../lib/onnxjs/tensor'; /* eslint-disable no-bitwise */ // eslint-disable-next-line no-underscore-dangle function matMul2d_( - A: Float32Array|Float64Array, B: Float32Array|Float64Array, C: Float32Array|Float64Array, alpha: number, - beta: number, M: number, N: number, K: number) { - let offsetA = 0, offsetB = 0, offsetC = 0; + A: Float32Array | Float64Array, + B: Float32Array | Float64Array, + C: Float32Array | Float64Array, + alpha: number, + beta: number, + M: number, + N: number, + K: number, +) { + let offsetA = 0, + offsetB = 0, + offsetC = 0; for (let mm = 0; mm < M; mm++) { for (let nn = 0; nn < N; nn++) { let sum = 0; @@ -30,9 +39,18 @@ function matMul2d_( } function matMul2d_tA( - A: Float32Array|Float64Array, B: Float32Array|Float64Array, C: Float32Array|Float64Array, alpha: number, - beta: number, M: number, N: number, K: number) { - let offsetA = 0, offsetB = 0, offsetC = 0; + A: Float32Array | Float64Array, + B: Float32Array | Float64Array, + C: Float32Array | Float64Array, + alpha: number, + beta: number, + M: number, + N: number, + K: number, +) { + let offsetA = 0, + offsetB = 0, + offsetC = 0; for (let mm = 0; mm < M; mm++) { for (let nn = 0; nn < N; nn++) { let sum = 0; @@ -53,9 +71,18 @@ function matMul2d_tA( } function matMul2d_tB( - A: Float32Array|Float64Array, B: Float32Array|Float64Array, C: Float32Array|Float64Array, alpha: number, - beta: number, M: number, N: number, K: number) { - let offsetA = 0, offsetB = 0, offsetC = 0; + A: Float32Array | Float64Array, + B: Float32Array | Float64Array, + C: Float32Array | Float64Array, + alpha: number, + beta: number, + M: number, + N: number, + K: number, +) { + let offsetA = 0, + offsetB = 0, + offsetC = 0; for (let mm = 0; mm < M; mm++) { for (let nn = 0; nn < N; nn++) { let sum = 0; @@ -76,9 +103,18 @@ function matMul2d_tB( } function matMul2d_tAtB( - A: Float32Array|Float64Array, B: Float32Array|Float64Array, C: Float32Array|Float64Array, alpha: number, - beta: number, M: number, N: number, K: number) { - let offsetA = 0, offsetB = 0, offsetC = 0; + A: Float32Array | Float64Array, + B: Float32Array | Float64Array, + C: Float32Array | Float64Array, + alpha: number, + beta: number, + M: number, + N: number, + K: number, +) { + let offsetA = 0, + offsetB = 0, + offsetC = 0; for (let mm = 0; mm < M; mm++) { for (let nn = 0; nn < N; nn++) { let sum = 0; @@ 
-105,8 +141,17 @@ function matMul2d_tAtB( * @param C data of tensor C, whose shape is [M,N] */ export function matMul2d( - A: Float32Array|Float64Array, B: Float32Array|Float64Array, C: Float32Array|Float64Array, transA: boolean, - transB: boolean, alpha: number, beta: number, M: number, N: number, K: number): void { + A: Float32Array | Float64Array, + B: Float32Array | Float64Array, + C: Float32Array | Float64Array, + transA: boolean, + transB: boolean, + alpha: number, + beta: number, + M: number, + N: number, + K: number, +): void { if (transA && transB) { matMul2d_tAtB(A, B, C, alpha, beta, M, N, K); } else if (transA) { @@ -119,9 +164,22 @@ export function matMul2d( } function im2col( - data_im: Float32Array|Float64Array, data_col: Float32Array|Float64Array, channels: number, height: number, - width: number, kernel_h: number, kernel_w: number, dilation_h: number, dilation_w: number, pad_t: number, - pad_l: number, pad_b: number, pad_r: number, stride_h: number, stride_w: number) { + data_im: Float32Array | Float64Array, + data_col: Float32Array | Float64Array, + channels: number, + height: number, + width: number, + kernel_h: number, + kernel_w: number, + dilation_h: number, + dilation_w: number, + pad_t: number, + pad_l: number, + pad_b: number, + pad_r: number, + stride_h: number, + stride_w: number, +) { const output_h = ~~((height + pad_b + pad_t - (dilation_h * (kernel_h - 1) + 1)) / stride_h) + 1; const output_w = ~~((width + pad_l + pad_r - (dilation_w * (kernel_w - 1) + 1)) / stride_w) + 1; @@ -133,16 +191,19 @@ function im2col( const rest = k % (kernel_h * kernel_w); const kh = ~~(rest / kernel_w); const kw = rest % kernel_w; - const dst_offset = nip * (kernel_h * kernel_w * output_h * output_w) + kh * (kernel_w * output_h * output_w) + - kw * (output_h * output_w); + const dst_offset = + nip * (kernel_h * kernel_w * output_h * output_w) + + kh * (kernel_w * output_h * output_w) + + kw * (output_h * output_w); const src_offset = nip * (height * width); for (let y = 0; y < output_h; y++) { const iy = y * stride_h + kh; const ix = kw; if (stride_w === 1) { data_col.set( - data_im.subarray(src_offset + iy * width + ix, src_offset + iy * width + ix + output_w), - dst_offset + y * output_w); + data_im.subarray(src_offset + iy * width + ix, src_offset + iy * width + ix + output_w), + dst_offset + y * output_w, + ); } else { for (let x = 0; x < output_w; x++) { data_col[dst_offset + (y * output_w + x)] = data_im[src_offset + (iy * width + ix + x * stride_w)]; @@ -180,8 +241,15 @@ function im2col( } export function conv2d( - Y: Tensor, X: Tensor, W: Tensor, B: Tensor|undefined, dilations: readonly number[], group: number, - pads: readonly number[], strides: readonly number[]): void { + Y: Tensor, + X: Tensor, + W: Tensor, + B: Tensor | undefined, + dilations: readonly number[], + group: number, + pads: readonly number[], + strides: readonly number[], +): void { const input_num = X.dims[0]; const input_channels = X.dims[1]; const input_height = X.dims[2]; @@ -203,10 +271,10 @@ export function conv2d( const input_image_size = input_height * input_width; const output_image_size = output_height * output_width; const kernel_size = kernel_shape[0] * kernel_shape[1]; - const X_offset = input_channels / group * input_image_size; + const X_offset = (input_channels / group) * input_image_size; const Y_offset = output_size / output_num / group; const W_offset = filter_size / group; - const kernel_dim = input_channels / group * kernel_size; + const kernel_dim = (input_channels / group) * 
kernel_size; const col_buffer_size = kernel_dim * output_image_size; const col_buffer_data = new Float32Array(col_buffer_size); @@ -216,14 +284,35 @@ export function conv2d( let Y_image_offset = 0; for (let group_id = 0; group_id < group; ++group_id) { im2col( - X.floatData.subarray(X_image_offset + group_id * X_offset), col_buffer_data, input_channels / group, - input_height, input_width, kernel_shape[0], kernel_shape[1], dilations[0], dilations[1], pads[0], pads[1], - pads[2], pads[3], strides[0], strides[1]); + X.floatData.subarray(X_image_offset + group_id * X_offset), + col_buffer_data, + input_channels / group, + input_height, + input_width, + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + pads[2], + pads[3], + strides[0], + strides[1], + ); matMul2d( - W.floatData.subarray(group_id * W_offset), col_buffer_data, - Y.floatData.subarray(Y_image_offset + group_id * Y_offset), false, false, 1, 0, filter_num / group, - output_image_size, kernel_dim); + W.floatData.subarray(group_id * W_offset), + col_buffer_data, + Y.floatData.subarray(Y_image_offset + group_id * Y_offset), + false, + false, + 1, + 0, + filter_num / group, + output_image_size, + kernel_dim, + ); } X_image_offset += X_offset * group; diff --git a/js/web/test/unittests/backends/webgl/test-glsl-function-inliner.ts b/js/web/test/unittests/backends/webgl/test-glsl-function-inliner.ts index 518cb52d01da5..bb5f7645af97c 100644 --- a/js/web/test/unittests/backends/webgl/test-glsl-function-inliner.ts +++ b/js/web/test/unittests/backends/webgl/test-glsl-function-inliner.ts @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {expect} from 'chai'; +import { expect } from 'chai'; -import {replaceInlines} from '../../../../lib/onnxjs/backends/webgl/glsl-function-inliner'; -import {Logger} from '../../../../lib/onnxjs/instrument'; +import { replaceInlines } from '../../../../lib/onnxjs/backends/webgl/glsl-function-inliner'; +import { Logger } from '../../../../lib/onnxjs/instrument'; function removeWhiteSpace(s: string): string { return s.replace(/\s+/gm, ' '); diff --git a/js/web/test/unittests/backends/webgl/test-matmul-packed.ts b/js/web/test/unittests/backends/webgl/test-matmul-packed.ts index e5714c8f8cdc1..c67413caf3365 100644 --- a/js/web/test/unittests/backends/webgl/test-matmul-packed.ts +++ b/js/web/test/unittests/backends/webgl/test-matmul-packed.ts @@ -1,16 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {expect} from 'chai'; -import {env} from 'onnxruntime-common'; +import { expect } from 'chai'; +import { env } from 'onnxruntime-common'; -import {Backend, InferenceHandler, resolveBackend, SessionHandler} from '../../../../lib/onnxjs/backend'; -import {WebGLInferenceHandler} from '../../../../lib/onnxjs/backends/webgl/inference-handler'; -import {createPackedMatmulProgramInfoLoader} from '../../../../lib/onnxjs/backends/webgl/ops/matmul-pack'; -import {Profiler} from '../../../../lib/onnxjs/instrument'; -import {Tensor} from '../../../../lib/onnxjs/tensor'; +import { Backend, InferenceHandler, resolveBackend, SessionHandler } from '../../../../lib/onnxjs/backend'; +import { WebGLInferenceHandler } from '../../../../lib/onnxjs/backends/webgl/inference-handler'; +import { createPackedMatmulProgramInfoLoader } from '../../../../lib/onnxjs/backends/webgl/ops/matmul-pack'; +import { Profiler } from '../../../../lib/onnxjs/instrument'; +import { Tensor } from '../../../../lib/onnxjs/tensor'; -import {createAscendingArray} from './test-utils'; +import { createAscendingArray } from './test-utils'; interface TestData { elementCountA: number; @@ -136,15 +136,15 @@ function getTestData(): TestData[] { ]; } -let backend: Backend|undefined; -let sessionhandler: SessionHandler|undefined; -let inferenceHandler: InferenceHandler|undefined; +let backend: Backend | undefined; +let sessionhandler: SessionHandler | undefined; +let inferenceHandler: InferenceHandler | undefined; describe('#UnitTest# - packed matmul - Tensor matmul', () => { before('Initialize Context', async () => { const profiler = Profiler.create(); backend = await resolveBackend('webgl'); - sessionhandler = backend.createSessionHandler({profiler}); + sessionhandler = backend.createSessionHandler({ profiler }); inferenceHandler = sessionhandler.createInferenceHandler(); }); @@ -171,14 +171,15 @@ describe('#UnitTest# - packed matmul - Tensor matmul', () => { const inputDataB = testData.rawInputB ?? createAscendingArray(elementCountB); const inputTensorA = new Tensor(inputTensorShapeA, 'float32', undefined, undefined, inputDataA); const inputTensorB = new Tensor(inputTensorShapeB, 'float32', undefined, undefined, inputDataB); - const biasTensor = testData.biasValue ? - new Tensor([1], 'float32', undefined, undefined, new Float32Array([testData.biasValue])) : - undefined; + const biasTensor = testData.biasValue + ? new Tensor([1], 'float32', undefined, undefined, new Float32Array([testData.biasValue])) + : undefined; const inputs = biasTensor ? 
[inputTensorA, inputTensorB, biasTensor] : [inputTensorA, inputTensorB]; const output = webglInferenceHandler.run( - createPackedMatmulProgramInfoLoader(webglInferenceHandler, inputs, {activation: '', activationCacheKey: ''}), - inputs); + createPackedMatmulProgramInfoLoader(webglInferenceHandler, inputs, { activation: '', activationCacheKey: '' }), + inputs, + ); const result = output.data; webglInferenceHandler.session.textureManager.glContext.checkError(); @@ -200,8 +201,10 @@ describe('#UnitTest# - packed matmul - Tensor matmul', () => { } const batchMultiplier = Math.max(batchMultiplierA, batchMultiplierB); expect(result).to.have.lengthOf( - batchMultiplier * testData.inputShapeA[testData.inputShapeA.length - 2] * - testData.inputShapeB[testData.inputShapeB.length - 1]); + batchMultiplier * + testData.inputShapeA[testData.inputShapeA.length - 2] * + testData.inputShapeB[testData.inputShapeB.length - 1], + ); expect(result).to.deep.equal(expectedOutput); }); } diff --git a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts index 61c21d4b689fb..28821663ffd50 100644 --- a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts +++ b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts @@ -1,18 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {expect} from 'chai'; - -import {Backend, InferenceHandler, resolveBackend, SessionHandler} from '../../../../lib/onnxjs/backend'; -import {WebGLInferenceHandler} from '../../../../lib/onnxjs/backends/webgl/inference-handler'; -import {createPackProgramInfoLoader} from '../../../../lib/onnxjs/backends/webgl/ops/pack'; -import {createUnpackProgramInfoLoader} from '../../../../lib/onnxjs/backends/webgl/ops/unpack'; -import {createTextureLayoutFromShape} from '../../../../lib/onnxjs/backends/webgl/texture-layout'; -import {Profiler} from '../../../../lib/onnxjs/instrument'; -import {Tensor} from '../../../../lib/onnxjs/tensor'; -import {ShapeUtil} from '../../../../lib/onnxjs/util'; - -import {createArrayFromTexture, createAscendingArray, createTextureFromArray, generateExpected, getExpectedElementCount} from './test-utils'; +import { expect } from 'chai'; + +import { Backend, InferenceHandler, resolveBackend, SessionHandler } from '../../../../lib/onnxjs/backend'; +import { WebGLInferenceHandler } from '../../../../lib/onnxjs/backends/webgl/inference-handler'; +import { createPackProgramInfoLoader } from '../../../../lib/onnxjs/backends/webgl/ops/pack'; +import { createUnpackProgramInfoLoader } from '../../../../lib/onnxjs/backends/webgl/ops/unpack'; +import { createTextureLayoutFromShape } from '../../../../lib/onnxjs/backends/webgl/texture-layout'; +import { Profiler } from '../../../../lib/onnxjs/instrument'; +import { Tensor } from '../../../../lib/onnxjs/tensor'; +import { ShapeUtil } from '../../../../lib/onnxjs/util'; + +import { + createArrayFromTexture, + createAscendingArray, + createTextureFromArray, + generateExpected, + getExpectedElementCount, +} from './test-utils'; interface TestData { elementCount: number; @@ -27,51 +33,87 @@ function getTestData(isPacked = true): TestData[] { if (isPacked) { return [ // test scalar - {elementCount: 1, inputShape: [], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1]}, + { elementCount: 1, inputShape: [], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1] }, // test 1D tensor - {elementCount: 1, inputShape: [1], outputShape: [], 
inputTextureShape: [], outputTextureShape: [1, 1]}, - {elementCount: 16, inputShape: [16], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 8]}, - {elementCount: 9, inputShape: [9], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 5]}, + { elementCount: 1, inputShape: [1], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1] }, + { elementCount: 16, inputShape: [16], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 8] }, + { elementCount: 9, inputShape: [9], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 5] }, // test 2D tensor - {elementCount: 1, inputShape: [1, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1]}, - {elementCount: 16, inputShape: [4, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 2]}, - {elementCount: 16, inputShape: [2, 8], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 4]}, - {elementCount: 16, inputShape: [8, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1]}, - {elementCount: 15, inputShape: [3, 5], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 3]}, - {elementCount: 18, inputShape: [3, 6], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 3]}, - {elementCount: 10, inputShape: [2, 5], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 3]}, - {elementCount: 6, inputShape: [1, 6], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 3]}, - {elementCount: 6, inputShape: [6, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [3, 1]}, - {elementCount: 5, inputShape: [5, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [3, 1]}, - {elementCount: 5, inputShape: [1, 5], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 3]}, + { elementCount: 1, inputShape: [1, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1] }, + { elementCount: 16, inputShape: [4, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 2] }, + { elementCount: 16, inputShape: [2, 8], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 4] }, + { elementCount: 16, inputShape: [8, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1] }, + { elementCount: 15, inputShape: [3, 5], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 3] }, + { elementCount: 18, inputShape: [3, 6], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 3] }, + { elementCount: 10, inputShape: [2, 5], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 3] }, + { elementCount: 6, inputShape: [1, 6], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 3] }, + { elementCount: 6, inputShape: [6, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [3, 1] }, + { elementCount: 5, inputShape: [5, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [3, 1] }, + { elementCount: 5, inputShape: [1, 5], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 3] }, // test 3D tensor - {elementCount: 1, inputShape: [1, 1, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1]}, - {elementCount: 16, inputShape: [2, 2, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 2]}, - {elementCount: 24, inputShape: [2, 3, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 2]}, - {elementCount: 30, inputShape: [5, 3, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [10, 1]}, - {elementCount: 9, inputShape: [1, 3, 3], 
outputShape: [], inputTextureShape: [], outputTextureShape: [2, 2]},
-      {elementCount: 8, inputShape: [1, 4, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 1]},
-      {elementCount: 8, inputShape: [4, 2, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1]},
-      {elementCount: 8, inputShape: [4, 1, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1]},
+      { elementCount: 1, inputShape: [1, 1, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1] },
+      { elementCount: 16, inputShape: [2, 2, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 2] },
+      { elementCount: 24, inputShape: [2, 3, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 2] },
+      { elementCount: 30, inputShape: [5, 3, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [10, 1] },
+      { elementCount: 9, inputShape: [1, 3, 3], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 2] },
+      { elementCount: 8, inputShape: [1, 4, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 1] },
+      { elementCount: 8, inputShape: [4, 2, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1] },
+      { elementCount: 8, inputShape: [4, 1, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1] },
       // test 4D tensor
-      {elementCount: 1, inputShape: [1, 1, 1, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1]},
-      {elementCount: 15, inputShape: [1, 1, 3, 5], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 3]},
-      {elementCount: 16, inputShape: [1, 2, 2, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [2, 2]},
-      {elementCount: 32, inputShape: [2, 2, 2, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 2]},
-      {elementCount: 36, inputShape: [2, 2, 3, 3], outputShape: [], inputTextureShape: [], outputTextureShape: [8, 2]},
-      {elementCount: 80, inputShape: [2, 5, 2, 4], outputShape: [], inputTextureShape: [], outputTextureShape: [10, 2]},
-      {elementCount: 12, inputShape: [2, 1, 3, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1]},
-      {elementCount: 8, inputShape: [4, 1, 1, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1]},
+      { elementCount: 1, inputShape: [1, 1, 1, 1], outputShape: [], inputTextureShape: [], outputTextureShape: [1, 1] },
+      {
+        elementCount: 15,
+        inputShape: [1, 1, 3, 5],
+        outputShape: [],
+        inputTextureShape: [],
+        outputTextureShape: [2, 3],
+      },
+      {
+        elementCount: 16,
+        inputShape: [1, 2, 2, 4],
+        outputShape: [],
+        inputTextureShape: [],
+        outputTextureShape: [2, 2],
+      },
+      {
+        elementCount: 32,
+        inputShape: [2, 2, 2, 4],
+        outputShape: [],
+        inputTextureShape: [],
+        outputTextureShape: [4, 2],
+      },
+      {
+        elementCount: 36,
+        inputShape: [2, 2, 3, 3],
+        outputShape: [],
+        inputTextureShape: [],
+        outputTextureShape: [8, 2],
+      },
+      {
+        elementCount: 80,
+        inputShape: [2, 5, 2, 4],
+        outputShape: [],
+        inputTextureShape: [],
+        outputTextureShape: [10, 2],
+      },
+      {
+        elementCount: 12,
+        inputShape: [2, 1, 3, 2],
+        outputShape: [],
+        inputTextureShape: [],
+        outputTextureShape: [4, 1],
+      },
+      { elementCount: 8, inputShape: [4, 1, 1, 2], outputShape: [], inputTextureShape: [], outputTextureShape: [4, 1] },
       {
         elementCount: 3840,
         inputShape: [1, 1, 48, 80],
         outputShape: [],
         inputTextureShape: [],
-        outputTextureShape: [24, 40]
+        outputTextureShape: [24, 40],
       },
       // test 6D tensor
       {
@@ -79,14 +121,14 @@ function getTestData(isPacked = true): TestData[] {
         inputShape: [1, 1, 2, 2, 2, 4],
         outputShape: [],
         inputTextureShape: [],
-        outputTextureShape: [4, 2]
+        outputTextureShape: [4, 2],
       },
       {
         elementCount: 3840,
         inputShape: [1, 1, 2, 24, 2, 40],
         outputShape: [],
         inputTextureShape: [],
-        outputTextureShape: [48, 20]
+        outputTextureShape: [48, 20],
       },
     ];
   } else {
@@ -150,9 +192,8 @@ function getTestData(isPacked = true): TestData[] {
         inputTextureShape: [2, 4],
         outputTextureShape: [6, 4],
         rawData: new Float32Array([
-          1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 0, 0, 11, 12, 0, 0,
-          13, 14, 17, 18, 15, 16, 19, 20, 21, 22, 0, 0, 23, 24, 0, 0
-        ])
+          1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 0, 0, 11, 12, 0, 0, 13, 14, 17, 18, 15, 16, 19, 20, 21, 22, 0, 0, 23, 24, 0, 0,
+        ]),
       },
       // test 4d tensor
       {
@@ -192,15 +233,15 @@ function getTestData(isPacked = true): TestData[] {
   }
 }
 
-let backend: Backend|undefined;
-let sessionhandler: SessionHandler|undefined;
-let inferenceHandler: InferenceHandler|undefined;
+let backend: Backend | undefined;
+let sessionhandler: SessionHandler | undefined;
+let inferenceHandler: InferenceHandler | undefined;
 
 describe('#UnitTest# - pack - Tensor pack', () => {
   before('Initialize Context', async () => {
     const profiler = Profiler.create();
     backend = await resolveBackend('webgl');
-    sessionhandler = backend!.createSessionHandler({profiler});
+    sessionhandler = backend!.createSessionHandler({ profiler });
     inferenceHandler = sessionhandler.createInferenceHandler();
   });
   const testDataSet = getTestData();
@@ -231,14 +272,20 @@ describe('#UnitTest# - pack - Tensor pack', () => {
 
         console.log('Testing unreverted HW input texture');
         // use inputTensorShape to create a texture layout that is unpacked(channel === 1)&& hw unreverted.
-        const inputUnpackedLayout =
-            createTextureLayoutFromShape(webglInferenceHandler.session.layoutStrategy, inputTensorShape);
+        const inputUnpackedLayout = createTextureLayoutFromShape(
+          webglInferenceHandler.session.layoutStrategy,
+          inputTensorShape,
+        );
 
         // create texture data from the layout. The texture data is cached inside inference handler such that
         // when pack kernel is invoked, it will read this texture data from cache instead of creating it from
         // scratch
         webglInferenceHandler.createTextureDataFromLayoutBindTensor(
-            inputUnpackedLayout, inputTensor.type, inputTensor.numberData, inputTensor);
+          inputUnpackedLayout,
+          inputTensor.type,
+          inputTensor.numberData,
+          inputTensor,
+        );
       }
 
       // compile shader code
@@ -247,8 +294,12 @@ describe('#UnitTest# - pack - Tensor pack', () => {
       // run kernal and get output
       const resultTextureData = webglInferenceHandler.executeProgram(programInfo, [inputTensor]);
       const gl = webglInferenceHandler.session.textureManager.glContext.gl;
-      const resultDataBuffer =
-          createArrayFromTexture(gl, resultTextureData.texture, outputTextureShape[1], outputTextureShape[0]);
+      const resultDataBuffer = createArrayFromTexture(
+        gl,
+        resultTextureData.texture,
+        outputTextureShape[1],
+        outputTextureShape[0],
+      );
 
       expect(resultDataBuffer).to.not.equal(null);
 
@@ -265,7 +316,7 @@ describe('#UnitTest# - unpack - Tensor unpack', () => {
   before('Initialize Context', async () => {
     const profiler = Profiler.create();
     backend = await resolveBackend('webgl');
-    sessionhandler = backend!.createSessionHandler({profiler});
+    sessionhandler = backend!.createSessionHandler({ profiler });
     inferenceHandler = sessionhandler.createInferenceHandler();
   });
   const testDataSet = getTestData(false);
@@ -290,8 +341,11 @@ describe('#UnitTest# - unpack - Tensor unpack', () => {
       const gl = webglInferenceHandler.session.textureManager.glContext.gl;
       webglInferenceHandler.session.textureManager.glContext.checkError();
       const webglTexture = createTextureFromArray(
-          webglInferenceHandler.session.textureManager.glContext, testData.rawData ? testData.rawData : inputData,
-          inputTextureShape[0], inputTextureShape[1]);
+        webglInferenceHandler.session.textureManager.glContext,
+        testData.rawData ? testData.rawData : inputData,
+        inputTextureShape[0],
+        inputTextureShape[1],
+      );
       webglInferenceHandler.session.textureManager.glContext.checkError();
       const packedShape = inputTextureShape;
       const textureData = {
@@ -303,7 +357,7 @@ describe('#UnitTest# - unpack - Tensor unpack', () => {
         strides: ShapeUtil.computeStrides(packedShape),
         unpackedShape: outputTensorShape,
         tensor: inputTensor,
-        texture: webglTexture!
+        texture: webglTexture!,
       };
 
       webglInferenceHandler.setTextureData(inputTensor.dataId, textureData, true);
@@ -336,7 +390,7 @@ describe('#UnitTest# - pack-unpack round trip', () => {
   before('Initialize Context', async () => {
     const profiler = Profiler.create();
     backend = await resolveBackend('webgl');
-    sessionhandler = backend!.createSessionHandler({profiler});
+    sessionhandler = backend!.createSessionHandler({ profiler });
     inferenceHandler = sessionhandler.createInferenceHandler();
   });
   const testDataSet = getTestData();
@@ -360,13 +414,14 @@ describe('#UnitTest# - pack-unpack round trip', () => {
       // create unpack kernel
 
      // compile unpack shader code
-      const unpackProgramInfo =
-          createPackProgramInfoLoader(inferenceHandler! as WebGLInferenceHandler, packResultData.tensor);
+      const unpackProgramInfo = createPackProgramInfoLoader(
+        inferenceHandler! as WebGLInferenceHandler,
+        packResultData.tensor,
+      );
 
       // run unpack kernal and get output
       const unpackResultData = webglInferenceHandler.executeProgram(unpackProgramInfo, [inputTensor]);
-
       const resultData = unpackResultData.tensor.data;
 
       expect(resultData).to.not.equal(null);
       expect(resultData).to.have.lengthOf(testData.elementCount);
diff --git a/js/web/test/unittests/backends/webgl/test-reshape-packed.ts b/js/web/test/unittests/backends/webgl/test-reshape-packed.ts
index e848e6686f8a9..b90372db1250a 100644
--- a/js/web/test/unittests/backends/webgl/test-reshape-packed.ts
+++ b/js/web/test/unittests/backends/webgl/test-reshape-packed.ts
@@ -1,15 +1,15 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {expect} from 'chai';
-import {env} from 'onnxruntime-common';
+import { expect } from 'chai';
+import { env } from 'onnxruntime-common';
 
-import {Backend, InferenceHandler, resolveBackend, SessionHandler} from '../../../../lib/onnxjs/backend';
-import {WebGLInferenceHandler} from '../../../../lib/onnxjs/backends/webgl/inference-handler';
-import {Profiler} from '../../../../lib/onnxjs/instrument';
-import {Tensor} from '../../../../lib/onnxjs/tensor';
+import { Backend, InferenceHandler, resolveBackend, SessionHandler } from '../../../../lib/onnxjs/backend';
+import { WebGLInferenceHandler } from '../../../../lib/onnxjs/backends/webgl/inference-handler';
+import { Profiler } from '../../../../lib/onnxjs/instrument';
+import { Tensor } from '../../../../lib/onnxjs/tensor';
 
-import {createAscendingArray} from './test-utils';
+import { createAscendingArray } from './test-utils';
 
 interface TestData {
   elementCount: number;
@@ -102,15 +102,15 @@ function getTestData(): TestData[] {
   ];
 }
 
-let backend: Backend|undefined;
-let sessionhandler: SessionHandler|undefined;
-let inferenceHandler: InferenceHandler|undefined;
+let backend: Backend | undefined;
+let sessionhandler: SessionHandler | undefined;
+let inferenceHandler: InferenceHandler | undefined;
 
 describe('#UnitTest# - reshape - packed', () => {
   before('Initialize Context', async () => {
     const profiler = Profiler.create();
     backend = await resolveBackend('webgl');
-    sessionhandler = backend.createSessionHandler({profiler});
+    sessionhandler = backend.createSessionHandler({ profiler });
     inferenceHandler = sessionhandler.createInferenceHandler();
   });
 
diff --git a/js/web/test/unittests/backends/webgl/test-utils.ts b/js/web/test/unittests/backends/webgl/test-utils.ts
index 092d63cd2ade4..0f26055ef8d5e 100644
--- a/js/web/test/unittests/backends/webgl/test-utils.ts
+++ b/js/web/test/unittests/backends/webgl/test-utils.ts
@@ -1,17 +1,17 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {WebGLContext} from '../../../../lib/onnxjs/backends/webgl/webgl-context';
+import { WebGLContext } from '../../../../lib/onnxjs/backends/webgl/webgl-context';
 
 export function createAscendingArray(size: number): Float32Array {
-  return new Float32Array(Array.from({length: size}, (_v, i) => (i + 1)));
+  return new Float32Array(Array.from({ length: size }, (_v, i) => i + 1));
 }
 
 // Returns an array by injecting 3 zeros after every element in the input array to be used for creating unpacked
 // texture.
 export function generateArrayForUnpackedTexture(input: Float32Array): Float32Array {
   const output = new Float32Array(input.length * 4);
-  for (let i = 0; i < (input.length * 4); i += 4) {
+  for (let i = 0; i < input.length * 4; i += 4) {
     output[i] = input[i / 4];
   }
   return output;
@@ -19,7 +19,11 @@ export function generateArrayForUnpackedTexture(input: Float32Array): Float32Arr
 
 // create a webgl texture and fill it with the array content
 export function createTextureFromArray(
-    glContext: WebGLContext, dataArray: Float32Array, width: number, height: number): WebGLTexture {
+  glContext: WebGLContext,
+  dataArray: Float32Array,
+  width: number,
+  height: number,
+): WebGLTexture {
   const gl = glContext.gl;
 
   // create the texture
@@ -46,12 +50,14 @@ export function createTextureFromArray(
 
 // create a cpu array and download GPU texture data to this array
 export function createArrayFromTexture(
-    gl: WebGLRenderingContext, texture: WebGLTexture, width: number, height: number): Float32Array {
+  gl: WebGLRenderingContext,
+  texture: WebGLTexture,
+  width: number,
+  height: number,
+): Float32Array {
   const resultDataBuffer = new Float32Array(width * height * 4);
   gl.bindTexture(gl.TEXTURE_2D, texture);
-  gl.framebufferTexture2D(
-      gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture,
-      0); // 0, we aren't using MIPMAPs
+  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); // 0, we aren't using MIPMAPs
   gl.readPixels(0, 0, width, height, gl.RGBA, gl.FLOAT, resultDataBuffer);
   return resultDataBuffer;
 }
@@ -130,7 +136,7 @@ export function generateExpected(inputArray: Float32Array, inputShape: number[])
         result[ii++] = 0;
       }
 
-      if ((j + 1) < inputHeight) {
+      if (j + 1 < inputHeight) {
         result[ii++] = inputArray[(j + 1) * inputWidth + i + b * (inputHeight * inputWidth)];
       } else {
         result[ii++] = 0;
diff --git a/js/web/test/unittests/opset.ts b/js/web/test/unittests/opset.ts
index 6a163dfb47817..a4bd0a079cdda 100644
--- a/js/web/test/unittests/opset.ts
+++ b/js/web/test/unittests/opset.ts
@@ -1,16 +1,16 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {expect} from 'chai';
+import { expect } from 'chai';
 
-import {Attribute} from '../../lib/onnxjs/attribute';
-import {WEBGL_OP_RESOLVE_RULES} from '../../lib/onnxjs/backends/webgl/op-resolve-rules';
-import {Graph} from '../../lib/onnxjs/graph';
-import {OpSet, resolveOperator} from '../../lib/onnxjs/opset';
-import {Tensor} from '../../lib/onnxjs/tensor';
+import { Attribute } from '../../lib/onnxjs/attribute';
+import { WEBGL_OP_RESOLVE_RULES } from '../../lib/onnxjs/backends/webgl/op-resolve-rules';
+import { Graph } from '../../lib/onnxjs/graph';
+import { OpSet, resolveOperator } from '../../lib/onnxjs/opset';
+import { Tensor } from '../../lib/onnxjs/tensor';
 
 function createTestGraphNode(name: string, opType: string): Graph.Node {
-  return {name, opType, inputs: [], outputs: [], attributes: new Attribute(null)};
+  return { name, opType, inputs: [], outputs: [], attributes: new Attribute(null) };
 }
 
 function dummyOpImpl(): Tensor[] {
@@ -18,9 +18,10 @@ function dummyOpImpl(): Tensor[] {
 }
 
 function checkConsistency(rules: readonly OpSet.ResolveRule[]) {
-  const VERSION_MIN = 1, VERSION_MAX = 10;
+  const VERSION_MIN = 1,
+    VERSION_MAX = 10;
   const typeRules = new Map();
-  rules.forEach(rule => {
+  rules.forEach((rule) => {
     let ruleSet = typeRules.get(rule[0]);
     if (!ruleSet) {
       ruleSet = [];
@@ -34,7 +35,7 @@ function checkConsistency(rules: readonly OpSet.ResolveRule[]) {
       let match = false;
       for (const r of rules) {
         try {
-          resolveOperator(createTestGraphNode('', type), [{domain: '', version: i}], [r]);
+          resolveOperator(createTestGraphNode('', type), [{ domain: '', version: i }], [r]);
         } catch {
           continue;
         }
@@ -47,7 +48,7 @@ function checkConsistency(rules: readonly OpSet.ResolveRule[]) {
 
 describe('#UnitTest# - resolveOperator', () => {
   const nodeAbs = createTestGraphNode('Abs_1', 'Abs');
-  const opset7 = [{domain: '', version: 7}];
+  const opset7 = [{ domain: '', version: 7 }];
   it('ExpectFail - no rule available', () => {
     expect(() => {
       resolveOperator(nodeAbs, opset7, []);
@@ -55,7 +56,10 @@ describe('#UnitTest# - resolveOperator', () => {
   });
   it('ExpectFail - no matching rule', () => {
     expect(() => {
-      resolveOperator(nodeAbs, opset7, [['And', '', '7', dummyOpImpl], ['Sub', '', '7', dummyOpImpl]]);
+      resolveOperator(nodeAbs, opset7, [
+        ['And', '', '7', dummyOpImpl],
+        ['Sub', '', '7', dummyOpImpl],
+      ]);
     }).to.throw(TypeError);
   });
   it('ExpectFail - version not match (exact match)', () => {
@@ -93,8 +97,9 @@ describe('#UnitTest# - resolveOperator', () => {
 });
 
 describe('#UnitTest# - resolve rules', () => {
-  const webglCheckOnlyRules =
-      WEBGL_OP_RESOLVE_RULES.map(rule => [rule[0], rule[1], rule[2], dummyOpImpl] as OpSet.ResolveRule);
+  const webglCheckOnlyRules = WEBGL_OP_RESOLVE_RULES.map(
+    (rule) => [rule[0], rule[1], rule[2], dummyOpImpl] as OpSet.ResolveRule,
+  );
   it('Consistency check - onnx.ai - webgl', () => {
     checkConsistency(webglCheckOnlyRules);
   });