From 291a5352b27ded5714e5748b381f2efb88f28fb9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:56:22 -0700 Subject: [PATCH] [js/web] remove training release (#22103) ### Description Remove training from onnxruntime-web Following up of #22082 --- js/web/lib/backend-wasm-inference.ts | 5 - js/web/lib/backend-wasm-training.ts | 29 - js/web/lib/backend-wasm.ts | 2 + js/web/lib/index.ts | 4 +- js/web/lib/wasm/session-handler-training.ts | 198 ------ js/web/lib/wasm/wasm-core-impl.ts | 9 +- js/web/lib/wasm/wasm-training-core-impl.ts | 631 ------------------ js/web/lib/wasm/wasm-types.ts | 76 +-- js/web/lib/wasm/wasm-utils-import.ts | 16 +- js/web/package.json | 7 - js/web/script/build.ts | 13 +- js/web/script/pull-prebuilt-wasm-artifacts.ts | 2 - js/web/test/training/e2e/browser-test-wasm.js | 21 - js/web/test/training/e2e/common.js | 248 ------- js/web/test/training/e2e/data/model.onnx | 16 - js/web/test/training/e2e/karma.conf.js | 54 -- js/web/test/training/e2e/package.json | 14 - js/web/test/training/e2e/run.js | 143 ---- .../test/training/e2e/simple-http-server.js | 67 -- js/web/types.d.ts | 4 - 20 files changed, 15 insertions(+), 1544 deletions(-) delete mode 100644 js/web/lib/backend-wasm-inference.ts delete mode 100644 js/web/lib/backend-wasm-training.ts delete mode 100644 js/web/lib/wasm/session-handler-training.ts delete mode 100644 js/web/lib/wasm/wasm-training-core-impl.ts delete mode 100644 js/web/test/training/e2e/browser-test-wasm.js delete mode 100644 js/web/test/training/e2e/common.js delete mode 100644 js/web/test/training/e2e/data/model.onnx delete mode 100644 js/web/test/training/e2e/karma.conf.js delete mode 100644 js/web/test/training/e2e/package.json delete mode 100644 js/web/test/training/e2e/run.js delete mode 100644 js/web/test/training/e2e/simple-http-server.js diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts deleted file mode 100644 index 7dfe7ee05a1d3..0000000000000 --- a/js/web/lib/backend-wasm-inference.ts +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts deleted file mode 100644 index 7332b3f97eba0..0000000000000 --- a/js/web/lib/backend-wasm-training.ts +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common'; - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training'; - -class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); - await handler.createTrainingSession( - checkpointStateUriOrBuffer, - trainModelUriOrBuffer, - evalModelUriOrBuffer, - optimizerModelUriOrBuffer, - options, - ); - return Promise.resolve(handler); - } -} - -export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 7bef538b26063..766937dc4c4cf 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } + +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 321394466b365..776c0d026bc97 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.DISABLE_TRAINING - ? require('./backend-wasm-inference').wasmBackend - : require('./backend-wasm-training').wasmBackend; + const wasmBackend = require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { registerBackend('webgpu', wasmBackend, 5); registerBackend('webnn', wasmBackend, 5); diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts deleted file mode 100644 index 8bbfb9cf06668..0000000000000 --- a/js/web/lib/wasm/session-handler-training.ts +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference'; -import { copyFromExternalBuffer } from './wasm-core-impl'; -import { - createCheckpointHandle, - createTrainingSessionHandle, - getContiguousParameters, - getModelInputOutputNames, - getParametersSize, - lazyResetGrad, - loadParametersBuffer, - releaseTrainingSessionAndCheckpoint, - runEvalStep, - runOptimizerStep, - runTrainStep, -} from './wasm-training-core-impl'; - -export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - private sessionId: number; - private checkpointId: number; - - inputNames: string[]; - outputNames: string[]; - - evalInputNames: string[] = []; - evalOutputNames: string[] = []; - - async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise { - let buffer: Uint8Array; - if (typeof uriOrBuffer === 'string') { - const response = await fetch(uriOrBuffer); - const arrayBuffer = await response.arrayBuffer(); - buffer = new Uint8Array(arrayBuffer); - } else { - buffer = uriOrBuffer; - } - return copyFromExternalBuffer(buffer); - } - - async createTrainingSession( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ) { - const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); - // 0 is supposed to be the nullptr - let evalModelData: SerializableInternalBuffer = [0, 0]; - let optimizerModelData: SerializableInternalBuffer = [0, 0]; - - if (evalModelUriOrBuffer !== '') { - evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); - } - if (optimizerModelUriOrBuffer !== '') { - optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); - } - - this.checkpointId = createCheckpointHandle(checkpointData); - this.sessionId = createTrainingSessionHandle( - this.checkpointId, - trainModelData, - evalModelData, - optimizerModelData, - options, - ); - [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); - if (evalModelUriOrBuffer !== '') { - [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); - } - } - - /** - * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the - * corresponding name as a number referring to the index in the list of names provided. - * - * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType - * @param names either inputNames or outputNames - * @returns a tuple of a list of values and a list of indices. 
- */ - convertMapIntoValuesArrayAndIndicesArray( - feeds: { [name: string]: T }, - names: string[], - mapFunc: (val: T, index: number) => U, - ): [T[], number[], U[]] { - const values: T[] = []; - const indices: number[] = []; - Object.entries(feeds).forEach((kvp) => { - const name = kvp[0]; - const tensor = kvp[1]; - const index = names.indexOf(name); - if (index === -1) { - throw new Error(`invalid input '${name}`); - } - values.push(tensor); - indices.push(index); - }); - - const uList = values.map(mapFunc); - return [values, indices, uList]; - } - - /** - * Helper method that converts the TensorMetadata that the wasm-core functions return to the - * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the - * corresponding result. - * - * @param results used to populate the resultMap if there is no value for that outputName already - * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results - * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. - * @returns a map of output names and OnnxValues. - */ - convertTensorMetadataToReturnType( - results: TensorMetadata[], - outputArray: Array, - outputIndices: number[], - ): SessionHandler.ReturnType { - const resultMap: SessionHandler.ReturnType = {}; - for (let i = 0; i < results.length; i++) { - resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); - } - return resultMap; - } - - async lazyResetGrad(): Promise { - await lazyResetGrad(this.sessionId); - } - - async runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.inputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.outputNames, (t, i): TensorMetadata | null => - t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async runOptimizerStep(options: InferenceSession.RunOptions): Promise { - await runOptimizerStep(this.sessionId, options); - } - - async runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.evalInputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.evalOutputNames, (t, i): TensorMetadata | null => - t ? 
encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async getParametersSize(trainableOnly: boolean): Promise { - return getParametersSize(this.sessionId, trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { - await loadParametersBuffer(this.sessionId, array, trainableOnly); - } - async getContiguousParameters(trainableOnly: boolean): Promise { - const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); - return decodeTensorMetadata(tensorResult); - } - - async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); - } -} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6c4e28df62f23..ed001cfa90f59 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -41,8 +41,8 @@ import { loadFile } from './wasm-utils-load-file'; * Refer to web/lib/index.ts for the backend registration. * * 2. WebAssembly artifact initialization. - * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or - * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings: + * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` is + * called). In this step, onnxruntime-web does the followings: * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled. * - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated * JavaScript code to initialize the WebAssembly runtime. @@ -57,9 +57,8 @@ import { loadFile } from './wasm-utils-load-file'; * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step. * * 4. Session initialization. - * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3 - * steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the - * followings: + * This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (they only called once), + * this step will be done for each session. In this step, onnxruntime-web does the followings: * If the parameter is a URL: * - download the model data from the URL. * - copy the model data to the WASM heap. (proxy: 'copy-from') diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts deleted file mode 100644 index 22cd6ec30732c..0000000000000 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession, Tensor } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { setRunOptions } from './run-options'; -import { setSessionOptions } from './session-options'; -import { - dataLocationStringToEnum, - tensorDataTypeEnumToString, - tensorDataTypeStringToEnum, - tensorTypeToTypedArrayConstructor, -} from './wasm-common'; -import { prepareInputOutputTensor } from './wasm-core-impl'; -import { getInstance } from './wasm-factory'; -import { checkLastError } from './wasm-utils'; - -const NO_TRAIN_FUNCS_MSG = - "Built without training API's enabled. Use the onnxruntime-web/training import for training " + - 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + - 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; - -/** - * Runs the checkLastError function which will throw an error, if the provided error code matches the specified - * pattern for an error code. - * @param errCode number to evaluated for if it's an error - * @param message message to pass into checkLastError - * @param checkNeqZero when true, treats not equal to zero as an error. - * When false, treats equal to zero as an error. - */ -const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { - if (checkNeqZero && errCode !== 0) { - checkLastError(message); - } else if (!checkNeqZero && errCode === 0) { - checkLastError(message); - } -}; - -export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => { - const wasm = getInstance(); - - const [checkpointDataOffset, checkpointDataLength] = checkpointData; - let checkpointHandle = 0; - - try { - if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); - return checkpointHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { - wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); - } - throw e; - } finally { - // free buffer from wasm heap - wasm._OrtFree(checkpointData[0]); - } -}; - -const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetModelInputOutputCount) { - const errorCode = wasm._OrtTrainingGetModelInputOutputCount( - trainingSessionId, - dataOffset, - dataOffset + 4, - isEvalModel, - ); - ifErrCodeCheckLastError(errorCode, "Can't get session input/output count."); - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -const getModelInputOutputNamesLoop = ( - trainingSessionId: number, - count: number, - isInput: boolean, - isEvalModel: boolean, -): string[] => { - const names = []; - const wasm = getInstance(); - - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetModelInputOutputName) { - const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - - 
names.push(wasm.UTF8ToString(name)); - wasm._free(name); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return names; -}; - -export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { - let inputNames: string[] = []; - let outputNames: string[] = []; - - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - - inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); - outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - - return [inputNames, outputNames]; -}; - -export const createTrainingSessionHandle = ( - checkpointHandle: number, - trainModelData: SerializableInternalBuffer, - evalModelData: SerializableInternalBuffer, - optimizerModelData: SerializableInternalBuffer, - options: InferenceSession.SessionOptions, -): number => { - const wasm = getInstance(); - - let trainingSessionHandle = 0; - let sessionOptionsHandle = 0; - let allocs: number[] = []; - - try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); - if (wasm._OrtTrainingCreateSession) { - trainingSessionHandle = wasm._OrtTrainingCreateSession( - sessionOptionsHandle, - checkpointHandle, - trainModelData[0], - trainModelData[1], - evalModelData[0], - evalModelData[1], - optimizerModelData[0], - optimizerModelData[1], - ); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - return trainingSessionHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { - wasm._OrtTrainingReleaseSession(trainingSessionHandle); - } - throw e; - } finally { - wasm._free(trainModelData[0]); - wasm._free(evalModelData[0]); - wasm._free(optimizerModelData[0]); - - if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); - } - allocs.forEach((alloc) => wasm._free(alloc)); - } -}; - -/** - * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the - * WASM tensors. - * - * @param trainingSessionId - * @param indices for each tensor, the index of the input or output name that the tensor corresponds with - * @param tensors list of TensorMetaData - * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting - * handles of the allocated tensors on the heap - * @param inputOutputAllocs modified in-place by this method - * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor - */ -const createAndAllocateTensors = ( - trainingSessionId: number, - indices: number[], - tensors: Array, - tensorHandles: number[], - inputOutputAllocs: number[], - indexAdd: number, -) => { - const count = indices.length; - - // creates the tensors - for (let i = 0; i < count; i++) { - prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); - } - - // moves to heap - const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; - for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; - } - - return valuesOffset; -}; - -/** - * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information - * associated with the tensor handle. 
- * - * @param outputValuesOffset - * @param outputCount - * @returns list of TensorMetadata retrieved from the output handles. - */ -const moveOutputToTensorMetadataArr = ( - outputValuesOffset: number, - outputCount: number, - outputTensorHandles: number[], - outputTensors: Array, -) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - if (tensor === outputTensorHandles[i]) { - // output tensor is pre-allocated. no need to copy data. - output.push(outputTensors[i]!); - continue; - } - - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let type: Tensor.Type | undefined, - dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, - tensorDataOffset, - tensorDataOffset + 4, - tensorDataOffset + 8, - tensorDataOffset + 12, - ); - ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength), - ); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - wasm._OrtReleaseTensor(tensor); - } - } - - return output; -}; - -export const lazyResetGrad = async (trainingSessionId: number): Promise => { - const wasm = getInstance(); - - if (wasm._OrtTrainingLazyResetGrad) { - const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId); - ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } -}; - -export const runTrainStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't 
want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingRunTrainStep) { - const errorCode = wasm._OrtTrainingRunTrainStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runOptimizerStep = async ( - trainingSessionId: number, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - try { - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - if (wasm._OrtTrainingOptimizerStep) { - const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); - ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runEvalStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingEvalStep) { - const errorCode = wasm._OrtTrainingEvalStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - - 
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - try { - const sizeOffset = wasm.stackAlloc(4); - if (wasm._OrtTrainingGetParametersSize) { - const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); - ifErrCodeCheckLastError(errorCode, "Can't get parameters size"); - - return wasm.HEAP32[sizeOffset / 4]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -export const getContiguousParameters = async ( - trainingSessionId: number, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - const parametersSize = getParametersSize(trainingSessionId, trainableOnly); - let tensor = 0; - - // allocates a buffer of the correct size on the WASM heap - const paramsByteLength = 4 * parametersSize; - const paramsOffset = wasm._malloc(paramsByteLength); - - // handles the dimensions-related createTensor parameters - const dims = [parametersSize]; - - const dimsOffset = wasm.stackAlloc(4); - const dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; - - try { - // wraps allocated array in a tensor - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - paramsOffset, - paramsByteLength, - dimsOffset, - dims.length, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError( - tensor, - `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, - false, - ); - - if (wasm._OrtTrainingCopyParametersToBuffer) { - const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); - const data = new typedArrayConstructor(parametersSize); - const output: TensorMetadata[] = []; - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength), - ); - output.push([tensorTypeAsString, dims, data, locationAsString]); - if (output.length !== 1) { - throw new Error(`something unexpected happened in the getContiguousParameters function. 
Expected output length of - one, got ${output.length}`); - } else { - return output[0]; - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm._free(paramsOffset); - wasm._free(dimsOffset); - wasm.stackRestore(stack); - } -}; - -export const loadParametersBuffer = async ( - trainingSessionId: number, - buffer: Uint8Array, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - // allocates & copies JavaScript buffer to WASM heap - const bufferByteLength = buffer.length; - const bufferCount = bufferByteLength / 4; - const bufferOffset = wasm._malloc(bufferByteLength); - wasm.HEAPU8.set(buffer, bufferOffset); - - // allocates and handles moving dimensions information to WASM memory - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; - const dimsLength = 1; - let tensor = 0; - - try { - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - bufferOffset, - bufferByteLength, - dimsOffset, - dimsLength, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); - - if (wasm._OrtTrainingCopyParametersFromBuffer) { - const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm.stackRestore(stack); - wasm._free(bufferOffset); - wasm._free(dimsOffset); - } -}; - -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { - const wasm = getInstance(); - - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } -}; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70b6cceab0eef..828cd3cfd94fa 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -213,84 +213,10 @@ export interface OrtInferenceAPIs { _OrtEndProfiling(sessionHandle: number): number; } -export interface OrtTrainingAPIs { - _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - - _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - - _OrtTrainingCreateSession( - sessionOptionsHandle: number, - checkpointHandle: number, - trainOffset: number, - trainLength: number, - evalOffset: number, - evalLength: number, - optimizerOffset: number, - optimizerLength: number, - ): number; - - _OrtTrainingLazyResetGrad(trainingHandle: number): number; - - _OrtTrainingRunTrainStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - - _OrtTrainingEvalStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - - _OrtTrainingCopyParametersToBuffer( - trainingHandle: number, - parametersBuffer: number, 
- parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingCopyParametersFromBuffer( - trainingHandle: number, - parametersBuffer: number, - parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingGetModelInputOutputCount( - trainingHandle: number, - inputCount: number, - outputCount: number, - isEvalModel: boolean, - ): number; - _OrtTrainingGetModelInputOutputName( - trainingHandle: number, - index: number, - isInput: boolean, - isEvalModel: boolean, - ): number; - - _OrtTrainingReleaseSession(trainingHandle: number): void; -} - /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule - extends EmscriptenModule, - OrtInferenceAPIs, - Partial, - Partial { +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index 008b9b41b1592..bd9e0ce083ef0 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -135,11 +135,9 @@ const embeddedWasmModule: EmscriptenModuleFactory | undefined = BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT ? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires require( - !BUILD_DEFS.DISABLE_TRAINING - ? '../../dist/ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' - : '../../dist/ort-wasm-simd-threaded.mjs', + !BUILD_DEFS.DISABLE_JSEP + ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', ).default : undefined; @@ -163,11 +161,9 @@ export const importWasmModule = async ( if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) { return [undefined, embeddedWasmModule!]; } else { - const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING - ? 'ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? 'ort-wasm-simd-threaded.jsep.mjs' - : 'ort-wasm-simd-threaded.mjs'; + const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP + ? 'ort-wasm-simd-threaded.jsep.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. diff --git a/js/web/package.json b/js/web/package.json index 94dd047915b05..d770499adada4 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -23,7 +23,6 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", - "test:training:e2e": "node ./test/training/e2e/run", "prebuild": "tsc -p . 
--noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", @@ -101,12 +100,6 @@ "import": "./dist/ort.webgpu.bundle.min.mjs", "require": "./dist/ort.webgpu.min.js", "types": "./types.d.ts" - }, - "./training": { - "node": null, - "import": "./dist/ort.training.wasm.min.mjs", - "require": "./dist/ort.training.wasm.min.js", - "types": "./types.d.ts" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6d1b3bdb65068..408f9e00a5cbd 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -56,7 +56,6 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_JSEP': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', - 'BUILD_DEFS.DISABLE_TRAINING': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'false', 'BUILD_DEFS.IS_ESM': 'false', @@ -253,7 +252,7 @@ async function buildBundle(options: esbuild.BuildOptions) { * * The distribution code is split into multiple files: * - [output-name][.min].[m]js - * - ort[-training]-wasm-simd-threaded[.jsep].mjs + * - ort-wasm-simd-threaded[.jsep].mjs */ async function buildOrt({ isProduction = false, @@ -630,16 +629,6 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, }); - // ort.training.wasm[.min].[m]js - await addAllWebBuildTasks({ - outputName: 'ort.training.wasm', - define: { - ...DEFAULT_DEFINE, - 'BUILD_DEFS.DISABLE_TRAINING': 'false', - 'BUILD_DEFS.DISABLE_JSEP': 'true', - 'BUILD_DEFS.DISABLE_WEBGL': 'true', - }, - }); } if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') { diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index b1b2fa26b2351..5b8b0d27c88db 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -149,11 +149,9 @@ downloadJson( void jszip.loadAsync(buffer).then((zip) => { extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName); }); }); }, diff --git a/js/web/test/training/e2e/browser-test-wasm.js b/js/web/test/training/e2e/browser-test-wasm.js deleted file mode 100644 index 05750ed149303..0000000000000 --- a/js/web/test/training/e2e/browser-test-wasm.js +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -'use strict'; - -describe('Browser E2E testing for training package', function () { - it('Check that training package encompasses inference', async function () { - ort.env.wasm.numThreads = 1; - await testInferenceFunction(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, all options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionAll(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, minimum options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionMin(ort, { executionProviders: ['wasm'] }); - }); -}); diff --git a/js/web/test/training/e2e/common.js b/js/web/test/training/e2e/common.js deleted file mode 100644 index 0574ae85aabd1..0000000000000 --- a/js/web/test/training/e2e/common.js +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const DATA_FOLDER = 'data/'; -const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; -const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; -const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; -const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; - -const trainingSessionAllOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, - evalModel: TRAININGDATA_EVAL_MODEL, - optimizerModel: TRAININGDATA_OPTIMIZER_MODEL, -}; - -const trainingSessionMinOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, -}; - -// ASSERT METHODS - -function assert(cond) { - if (!cond) throw new Error(); -} - -function assertStrictEquals(actual, expected) { - if (actual !== expected) { - let strRep = actual; - if (typeof actual === 'object') { - strRep = JSON.stringify(actual); - } - throw new Error(`expected: ${expected}; got: ${strRep}`); - } -} - -function assertTwoListsUnequal(list1, list2) { - if (list1.length !== list2.length) { - return; - } - for (let i = 0; i < list1.length; i++) { - if (list1[i] !== list2[i]) { - return; - } - } - throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); -} - -// HELPER METHODS FOR TESTS - -function generateGaussianRandom(mean = 0, scale = 1) { - const u = 1 - Math.random(); - const v = Math.random(); - const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); - return z * scale + mean; -} - -function generateGaussianFloatArray(length) { - const array = new Float32Array(length); - - for (let i = 0; i < length; i++) { - array[i] = generateGaussianRandom(); - } - - return array; -} - -/** - * creates the TrainingSession and verifies that the input and output names of the training model loaded into the - * training session are correct. 
- * @param {} ort - * @param {*} createOptions - * @param {*} options - * @returns - */ -async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { - const trainingSession = await ort.TrainingSession.create(createOptions, options); - - assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); - assertStrictEquals(trainingSession.trainingInputNames.length, 2); - assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.trainingOutputNames.length, 1); - return trainingSession; -} - -/** - * verifies that the eval input and output names associated with the eval model loaded into the given training session - * are correct. - */ -function checkEvalModel(trainingSession) { - assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); - assertStrictEquals(trainingSession.evalInputNames.length, 2); - assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.evalOutputNames.length, 1); -} - -/** - * Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if - * accessed - * @param {} trainingSession - */ -function checkNoEvalModel(trainingSession) { - try { - assertStrictEquals(trainingSession.evalInputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } - try { - assertStrictEquals(trainingSession.evalOutputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } -} - -/** - * runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length - * of 1 for the loss. 
- * @param {} trainingSession - * @param {*} feeds - * @returns - */ -var runTrainStepAndCheck = async function (trainingSession, feeds) { - const results = await trainingSession.runTrainStep(feeds); - assertStrictEquals(Object.keys(results).length, 1); - assertStrictEquals(results['onnx::loss::21273'].data.length, 1); - assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); - return results; -}; - -var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) { - // make a float32 array that is filled with the constant - const newParams = new Float32Array(paramsLength); - for (let i = 0; i < paramsLength; i++) { - newParams[i] = constant; - } - - const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); - - await trainingSession.loadParametersBuffer(newParamsUint8); - const paramsAfterLoad = await trainingSession.getContiguousParameters(); - - // check that the parameters have changed - assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); - assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); - - // check that the parameters have changed to what they should be - for (let i = 0; i < paramsLength; i++) { - // round to the same number of digits (4 decimal places) - assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); - } - - return paramsAfterLoad; -}; - -// TESTS - -var testInferenceFunction = async function (ort, options) { - const session = await ort.InferenceSession.create('data/model.onnx', options || {}); - - const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); - - const fetches = await session.run({ - a: new ort.Tensor('float32', dataA, [3, 4]), - b: new ort.Tensor('float32', dataB, [4, 3]), - }); - - const c = fetches.c; - - assert(c instanceof ort.Tensor); - assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); - assert(c.data[0] === 700); - assert(c.data[1] === 800); - assert(c.data[2] === 900); - assert(c.data[3] === 1580); - assert(c.data[4] === 1840); - assert(c.data[5] === 2100); - assert(c.data[6] === 2460); - assert(c.data[7] === 2880); - assert(c.data[8] === 3300); -}; - -var testTrainingFunctionMin = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); - checkNoEvalModel(trainingSession); - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - const feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - await runTrainStepAndCheck(trainingSession, feeds); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); -}; - -var testTrainingFunctionAll = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); - 
checkEvalModel(trainingSession); - - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - let feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - const results = await runTrainStepAndCheck(trainingSession, feeds); - - await trainingSession.runOptimizerStep(feeds); - feeds = { 'input-0': input0, labels: labels }; - // check getContiguousParameters after optimizerStep -- that the parameters have been updated - const optimizedParams = await trainingSession.getContiguousParameters(); - assertTwoListsUnequal(originalParams.data, optimizedParams.data); - - const results2 = await runTrainStepAndCheck(trainingSession, feeds); - - // check that loss decreased after optimizer step and training again - assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); -}; - -if (typeof module === 'object') { - module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; -} diff --git a/js/web/test/training/e2e/data/model.onnx b/js/web/test/training/e2e/data/model.onnx deleted file mode 100644 index 088124bd48624..0000000000000 --- a/js/web/test/training/e2e/data/model.onnx +++ /dev/null @@ -1,16 +0,0 @@ - backend-test:b - -a -bc"MatMultest_matmul_2dZ -a -  - -Z -b -  - -b -c -  - -B \ No newline at end of file diff --git a/js/web/test/training/e2e/karma.conf.js b/js/web/test/training/e2e/karma.conf.js deleted file mode 100644 index 74662b67676f7..0000000000000 --- a/js/web/test/training/e2e/karma.conf.js +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const args = require('minimist')(process.argv.slice(2)); -const SELF_HOST = !!args['self-host']; -const ORT_MAIN = args['ort-main']; -const TEST_MAIN = args['test-main']; -if (typeof TEST_MAIN !== 'string') { - throw new Error('flag --test-main= is required'); -} -const USER_DATA = args['user-data']; -if (typeof USER_DATA !== 'string') { - throw new Error('flag --user-data= is required'); -} - -module.exports = function (config) { - const distPrefix = SELF_HOST ? 
'./node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; - config.set({ - frameworks: ['mocha'], - files: [ - { pattern: distPrefix + ORT_MAIN }, - { pattern: './common.js' }, - { pattern: TEST_MAIN }, - { pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true }, - { pattern: './data/*', included: false }, - ], - plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], - proxies: { - '/model.onnx': '/base/model.onnx', - '/data/': '/base/data/', - }, - client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } }, - reporters: ['mocha'], - captureTimeout: 120000, - reportSlowerThan: 100, - browserDisconnectTimeout: 600000, - browserNoActivityTimeout: 300000, - browserDisconnectTolerance: 0, - browserSocketTimeout: 60000, - hostname: 'localhost', - browsers: [], - customLaunchers: { - Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA }, - Chrome_no_threads: { - base: 'ChromeHeadless', - chromeDataDir: USER_DATA, - // TODO: no-thread flags - }, - Edge_default: { base: 'Edge', edgeDataDir: USER_DATA }, - }, - }); -}; diff --git a/js/web/test/training/e2e/package.json b/js/web/test/training/e2e/package.json deleted file mode 100644 index 5f11a27de6dfc..0000000000000 --- a/js/web/test/training/e2e/package.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "devDependencies": { - "@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", - "fs-extra": "^11.1.0", - "globby": "^13.1.3", - "karma": "^6.4.1", - "karma-chrome-launcher": "^3.1.1", - "karma-mocha": "^2.0.1", - "karma-mocha-reporter": "^2.2.5", - "light-server": "^2.9.1", - "minimist": "^1.2.7", - "mocha": "^10.2.0" - } -} diff --git a/js/web/test/training/e2e/run.js b/js/web/test/training/e2e/run.js deleted file mode 100644 index d12bcc7aa66ed..0000000000000 --- a/js/web/test/training/e2e/run.js +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const path = require('path'); -const fs = require('fs-extra'); -const { spawn } = require('child_process'); -const startServer = require('./simple-http-server'); -const minimist = require('minimist'); - -// copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file -// exists in its parent folder. 
-// here we use /build/js/e2e-training/ for the test - -const TEST_E2E_SRC_FOLDER = __dirname; -const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..'); -const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training'); -const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache'); -const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data'); -fs.emptyDirSync(TEST_E2E_RUN_FOLDER); -fs.emptyDirSync(NPM_CACHE_FOLDER); -fs.emptyDirSync(CHROME_USER_DATA_FOLDER); -fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); - -// training data to copy -const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..'); -const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api'); -const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); - -// always use a new folder as user-data-dir -let nextUserDataDirId = 0; -function getNextUserDataDir() { - const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()); - nextUserDataDirId++; - fs.emptyDirSync(dir); - return dir; -} - -// commandline arguments -const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; - -async function main() { - // find packed package - const { globbySync } = await import('globby'); - - const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); - const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER }); - - const PACKAGES_TO_INSTALL = []; - - if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) { - PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0])); - } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) { - throw new Error('multiple packages found for onnxruntime-common.'); - } - - const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); - const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER }); - if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { - throw new Error('cannot find exactly single package for onnxruntime-web.'); - } - PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0])); - - // we start here: - - // install dev dependencies - await runInShell(`npm install`); - - // npm install with "--cache" to install packed packages with an empty cache folder - await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`); - - // prepare training data - prepareTrainingDataByCopying(); - - console.log('==============================================================='); - console.log('Running self-hosted tests'); - console.log('==============================================================='); - // test cases with self-host (ort hosted in same origin) - await testAllBrowserCases({ hostInKarma: true }); - - console.log('==============================================================='); - console.log('Running not self-hosted tests'); - console.log('==============================================================='); - // test cases without self-host (ort hosted in cross origin) - const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081); - try { - await testAllBrowserCases({ hostInKarma: false }); - } finally { - // close the server after all tests - await server.close(); - } -} - -async function testAllBrowserCases({ hostInKarma }) { - await runKarma({ hostInKarma, main: 
'./browser-test-wasm.js' }); -} - -async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) { - console.log('==============================================================='); - console.log(`Running karma with the following binary: ${ortMain}`); - console.log('==============================================================='); - const selfHostFlag = hostInKarma ? '--self-host' : ''; - await runInShell( - `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ - ortMain - } --test-main=${main} --user-data=${getNextUserDataDir()}`, - ); -} - -async function runInShell(cmd) { - console.log('==============================================================='); - console.log(' Running command in shell:'); - console.log(' > ' + cmd); - console.log('==============================================================='); - let complete = false; - const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER }); - childProcess.on('close', function (code) { - if (code !== 0) { - process.exit(code); - } else { - complete = true; - } - }); - while (!complete) { - await delay(100); - } -} - -async function delay(ms) { - return new Promise(function (resolve) { - setTimeout(function () { - resolve(); - }, ms); - }); -} - -function prepareTrainingDataByCopying() { - fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST); - console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`); -} - -main(); diff --git a/js/web/test/training/e2e/simple-http-server.js b/js/web/test/training/e2e/simple-http-server.js deleted file mode 100644 index ef9cced681cc8..0000000000000 --- a/js/web/test/training/e2e/simple-http-server.js +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -// this is a simple HTTP server that enables CORS. 
-// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework - -const http = require('http'); -const fs = require('fs'); -const path = require('path'); - -const getRequestData = (url, dir) => { - const pathname = new URL(url, 'http://localhost').pathname; - - let filepath; - let mimeType; - if (pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/')) { - filepath = path.resolve(dir, pathname.substring(1)); - } else { - return null; - } - - if (filepath.endsWith('.wasm')) { - mimeType = 'application/wasm'; - } else if (filepath.endsWith('.js') || filepath.endsWith('.mjs')) { - mimeType = 'text/javascript'; - } else { - return null; - } - - return [filepath, mimeType]; -}; - -module.exports = function (dir, port) { - const server = http - .createServer(function (request, response) { - const url = request.url.replace(/\n|\r/g, ''); - console.log(`request ${url}`); - - const requestData = getRequestData(url, dir); - if (!request || !requestData) { - response.writeHead(404); - response.end('404'); - } else { - const [filePath, contentType] = requestData; - fs.readFile(path.resolve(dir, filePath), function (error, content) { - if (error) { - if (error.code == 'ENOENT') { - response.writeHead(404); - response.end('404'); - } else { - response.writeHead(500); - response.end('500'); - } - } else { - response.setHeader('access-control-allow-origin', '*'); - response.writeHead(200, { 'Content-Type': contentType }); - response.end(content, 'utf-8'); - } - }); - } - }) - .listen(port); - console.log(`Server running at http://localhost:${port}/`); - return server; -}; diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 735b6a89a2a86..b82248c0c83b8 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -20,7 +20,3 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } - -declare module 'onnxruntime-web/training' { - export * from 'onnxruntime-web'; -}
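
For readers skimming the diff, the net effect on the public wiring is easier to see outside patch form. The sketch below condenses the post-change state of `js/web/lib/index.ts` and the wasm module selection from `js/web/lib/wasm/wasm-utils-import.ts` as shown in the hunks above. It is an illustrative sketch only: the `BUILD_DEFS` declaration is added here for self-containment (in the real build it is injected by esbuild via `js/web/script/build.ts`), and registrations outside the quoted hunks are elided.

```ts
// Sketch condensed from the hunks above (not the full source files).
import { registerBackend } from 'onnxruntime-common';

// Assumption for self-containment: BUILD_DEFS is normally injected at build time.
declare const BUILD_DEFS: { DISABLE_WASM: boolean; DISABLE_JSEP: boolean };

if (!BUILD_DEFS.DISABLE_WASM) {
  // backend-wasm-inference.ts and backend-wasm-training.ts are removed;
  // backend-wasm.ts now exports the single wasmBackend instance directly.
  // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
  const wasmBackend = require('./backend-wasm').wasmBackend;
  if (!BUILD_DEFS.DISABLE_JSEP) {
    registerBackend('webgpu', wasmBackend, 5);
    registerBackend('webnn', wasmBackend, 5);
  }
  // Registration of the remaining wasm execution providers continues as before
  // (outside the hunk quoted above).
}

// WebAssembly artifact selection loses its training branch; only the JSEP flag remains.
export const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP
  ? 'ort-wasm-simd-threaded.jsep.mjs'
  : 'ort-wasm-simd-threaded.mjs';
```

Since the patch also removes the `./training` entry in `js/web/package.json`, the `ort.training.wasm.*` bundles, and the `ort-training-wasm-simd-threaded.*` artifacts, imports of `onnxruntime-web/training` will no longer resolve against this package.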