diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index e2d962677f30c..faf597931165a 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -4,9 +4,9 @@ import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; +import {Tensor} from './tensor.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { Tensor } from './tensor.js'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = InferenceSession.FeedsType; @@ -49,17 +49,8 @@ export class TrainingSession implements TrainingSessionInterface { } } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; - runTrainStep( - feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; + runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { const fetches: {[name: string]: OnnxValue|null} = {}; let options: RunOptions = {}; @@ -159,6 +150,14 @@ export class TrainingSession implements TrainingSessionInterface { return returnValue; } + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async release(): Promise { return this.handler.dispose(); } diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 5ea7de809a495..7176823c9bf13 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -5,7 +5,7 @@ import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; -import {OnnxjsSessionHandler} from './onnxjs/session-handler'; +import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference'; class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 98e40807aa29c..09dac3a85311c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -4,7 +4,7 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; -import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 5740263583031..78edcc90f55f9 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -5,7 +5,7 @@ import {cpus} from 'node:os'; import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; -import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler'; +import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** * This function initializes all flags for WebAssembly. diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler-inference.ts similarity index 100% rename from js/web/lib/onnxjs/session-handler.ts rename to js/web/lib/onnxjs/session-handler-inference.ts diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler-inference.ts similarity index 100% rename from js/web/lib/wasm/session-handler.ts rename to js/web/lib/wasm/session-handler-inference.ts diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-training.ts similarity index 91% rename from js/web/lib/wasm/session-handler-for-training.ts rename to js/web/lib/wasm/session-handler-training.ts index 9aeca4a28dce4..e754e0bf64282 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, SessionHandler, TrainingSessionHandler, Tensor} from 'onnxruntime-common'; +import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata} from './proxy-messages'; +import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, runTrainStep, - releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; -import { encodeTensorMetadata, decodeTensorMetadata } from './session-handler'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { @@ -99,7 +98,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); const resultMap: SessionHandler.ReturnType = {}; - for (let i = 0; i < results. length; i++) { + for (let i = 0; i < results.length; i++) { resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); } return resultMap; @@ -109,5 +108,4 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return releaseTrainingSessionAndCheckpoint( this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); } - } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 7593bca81ffce..d8bb0fae905f0 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,13 +3,13 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; -import { prepareInputOutputTensor } from './wasm-core-impl'; import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import { tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor } from './wasm-common'; +import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -import { setRunOptions } from './run-options'; const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. ' + 'Make sure to use the onnxruntime-training package for training functionality.'; @@ -143,10 +143,113 @@ export const createTrainingSessionHandle = } }; +/** + * Prepares input and output tensors by creating the tensors in the WASM side then moving them to the heap + * @param trainingSessionId + * @param indices for each tensor, the index of the input or output name that the tensor corresponds with + * @param tensors list of TensorMetaData + * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting + * handles of the allocated tensors on the heap + * @param inputOutputAllocs modified in-place by this method + * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor + */ +const createAndAllocateTensors = + (trainingSessionId: number, indices: number[], tensors: Array, tensorHandles: number[], + inputOutputAllocs: number[], indexAdd: number) => { + const wasm = getInstance(); + + const count = indices.length; + const valuesOffset = wasm.stackAlloc(count * 4); + + // creates the tensors + for (let i = 0; i < count; i++) { + prepareInputOutputTensor( + tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); + } + + // moves to heap + let valuesIndex = valuesOffset / 4; + for (let i = 0; i < count; i++) { + wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + } + + return valuesOffset; + }; + +/** + * Move output tensors from the heap to an array + * @param outputValuesOffset + * @param outputCount + * @returns + */ +const moveOutputToTensorMetadataArr = + (outputValuesOffset: number, outputCount: number) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + const keepOutputTensor = false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { + wasm._OrtReleaseTensor(tensor); + } + } + } + + return output; + }; + export const runTrainStep = async( - trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], - outputTensors: Array, options: InferenceSession.RunOptions): Promise => { - const wasm = getInstance(); + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -159,108 +262,31 @@ export const runTrainStep = async( const inputOutputAllocs: number[] = []; const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); try { + // prepare parameters by moving them to heap [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - // TODO: - // move all input and output processing -> wasm heap to one helper method???? - // can abstract out the similarities between input and output - // create input tensors - for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, trainingSessionId, inputIndices[i]); - } - - // create output tensors - for (let i = 0; i < outputCount; i++) { - prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, trainingSessionId, inputCount + outputIndices[i]); - } - - let inputValuesIndex = inputValuesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; - } - for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; - } - - let errorCode: number; + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); if (wasm._OrtTrainingRunTrainStep) { - errorCode = await wasm._OrtTrainingRunTrainStep(trainingSessionId, inputValuesOffset, inputCount, - outputValuesOffset, outputCount, runOptionsHandle); - } - else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } + const errorCode = wasm._OrtTrainingRunTrainStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); - if (errorCode !== 0) { - checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } - - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let keepOutputTensor = false; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - if (!keepOutputTensor) { - wasm._OrtReleaseTensor(tensor); - } + if (errorCode !== 0) { + checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); } + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); } - return output; + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount); } finally { wasm.stackRestore(beforeRunStack);