diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 47e67879e66ce..ee6d26b22b1f6 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -2,11 +2,18 @@ // Licensed under the MIT License. import {resolveBackend} from './backend-impl.js'; -import {TrainingSessionHandler} from './backend.js'; +import {SessionHandler, TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; +import {Tensor} from './tensor.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; type SessionOptions = InferenceSession.SessionOptions; +type FeedsType = InferenceSession.FeedsType; +type FetchesType = InferenceSession.FetchesType; +type ReturnType = InferenceSession.ReturnType; +type RunOptions = InferenceSession.RunOptions; + const noBackendErrMsg: string = 'Training backend could not be resolved. ' + 'Make sure you\'re using the correct configuration & WebAssembly files.'; @@ -42,21 +49,138 @@ export class TrainingSession implements TrainingSessionInterface { } } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + /** + * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from + * the given parameters to SessionHandler.FetchesType and RunOptions. + * + * @param feeds the required input + * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object + * @param arg2 optional RunOptions object. + * @returns + */ + typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): + [SessionHandler.FetchesType, RunOptions] { + const fetches: {[name: string]: OnnxValue|null} = {}; + let options: RunOptions = {}; + // check inputs + if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { + throw new TypeError( + '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.'); + } + + let isFetchesEmpty = true; + // determine which override is being used + if (typeof arg1 === 'object') { + if (arg1 === null) { + throw new TypeError('Unexpected argument[1]: cannot be null.'); + } + if (arg1 instanceof Tensor) { + throw new TypeError('\'fetches\' cannot be a Tensor'); + } + + if (Array.isArray(arg1)) { + if (arg1.length === 0) { + throw new TypeError('\'fetches\' cannot be an empty array.'); + } + isFetchesEmpty = false; + // output names + for (const name of arg1) { + if (typeof name !== 'string') { + throw new TypeError('\'fetches\' must be a string array or an object.'); + } + if (this.outputNames.indexOf(name) === -1) { + throw new RangeError(`'fetches' contains invalid output name: ${name}.`); + } + fetches[name] = null; + } + + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + // decide whether arg1 is fetches or options + // if any output name is present and its value is valid OnnxValue, we consider it fetches + let isFetches = false; + const arg1Keys = Object.getOwnPropertyNames(arg1); + for (const name of this.outputNames) { + if (arg1Keys.indexOf(name) !== -1) { + const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; + if (v === null || v instanceof Tensor) { + isFetches = true; + isFetchesEmpty = false; + fetches[name] = v; + } + } + } + + if (isFetches) { + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + options = arg1 as RunOptions; + } + } + } else if (typeof arg1 !== 'undefined') { + throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.'); + } + + // check if all inputs are in feed + for (const name of this.inputNames) { + if (typeof feeds[name] === 'undefined') { + throw new Error(`input '${name}' is missing in 'feeds'.`); + } + } + + // if no fetches is specified, we use the full output names list + if (isFetchesEmpty) { + for (const name of this.outputNames) { + fetches[name] = null; + } + } + + return [fetches, options]; } - async getContiguousParameters(_trainableOnly: boolean): Promise { + /** + * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler + * and changes it into a map of Tensors. + * + * @param results + * @returns + */ + convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { + const returnValue: {[name: string]: OnnxValue} = {}; + for (const key in results) { + if (Object.hasOwnProperty.call(results, key)) { + const result = results[key]; + if (result instanceof Tensor) { + returnValue[key] = result; + } else { + returnValue[key] = new Tensor(result.type, result.data, result.dims); + } + } + } + return returnValue; + } + + runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; + runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; + async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2); + const results = await this.handler.runTrainStep(feeds, fetches, options); + return this.convertHandlerReturnTypeToMapOfTensors(results); + } + + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } - runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined): - Promise; - runTrainStep( - feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions|undefined): Promise; - async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown): - Promise { + async getContiguousParameters(_trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 5ea7de809a495..7176823c9bf13 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -5,7 +5,7 @@ import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; -import {OnnxjsSessionHandler} from './onnxjs/session-handler'; +import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference'; class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 98e40807aa29c..09dac3a85311c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -4,7 +4,7 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; -import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 5740263583031..78edcc90f55f9 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -5,7 +5,7 @@ import {cpus} from 'node:os'; import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; -import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler'; +import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** * This function initializes all flags for WebAssembly. diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler-inference.ts similarity index 100% rename from js/web/lib/onnxjs/session-handler.ts rename to js/web/lib/onnxjs/session-handler-inference.ts diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts deleted file mode 100644 index 83d133b9a5157..0000000000000 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; - -import {SerializableModeldata} from './proxy-messages'; -import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; - -export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - private sessionId: number; - private checkpointId: number; - - inputNames: string[]; - outputNames: string[]; - - inputEncodedNames: number[]; - outputEncodedNames: number[]; - - async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { - let buffer: Uint8Array; - if (typeof uriOrBuffer === 'string') { - const response = await fetch(uriOrBuffer); - const arrayBuffer = await response.arrayBuffer(); - buffer = new Uint8Array(arrayBuffer); - } else { - buffer = uriOrBuffer; - } - return createSessionAllocate(buffer); - } - - async createTrainingSession( - checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, - evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, - options: InferenceSession.SessionOptions) { - if (!isOrtEnvInitialized()) { - await initRuntime(env); - } - const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); - // 0 is supposed to be the nullptr - let evalModelData: SerializableModeldata = [0, 0]; - let optimizerModelData: SerializableModeldata = [0, 0]; - - if (evalModelUriOrBuffer !== '') { - evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); - } - if (optimizerModelUriOrBuffer !== '') { - optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); - } - - this.checkpointId = createCheckpointHandle(checkpointData); - [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = - createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); - } - - async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint( - this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); - } - - async runTrainStep( - _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, - _options: InferenceSession.RunOptions): Promise { - throw new Error('Method not implemented yet.'); - } -} diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler-inference.ts similarity index 96% rename from js/web/lib/wasm/session-handler.ts rename to js/web/lib/wasm/session-handler-inference.ts index a5017a920f38b..3ca34d957c572 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -10,7 +10,7 @@ import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitializationPromise: Promise|undefined; -const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { +export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { case 'cpu': return [tensor.type, tensor.dims, tensor.data, 'cpu']; @@ -21,7 +21,7 @@ const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMeta } }; -const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { +export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { switch (tensor[3]) { case 'cpu': return new Tensor(tensor[0], tensor[2], tensor[1]); diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts new file mode 100644 index 0000000000000..09d91591128d1 --- /dev/null +++ b/js/web/lib/wasm/session-handler-training.ts @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + /** + * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the + * corresponding name as a number referring to the index in the list of names provided. + * + * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType + * @param names either inputNames or outputNames + * @returns a tuple of a list of values and a list of indices. + */ + convertMapIntoValuesArrayAndIndicesArray( + feeds: {[name: string]: T}, names: string[], mapFunc: (val: T, index: number) => U): [T[], number[], U[]] { + const values: T[] = []; + const indices: number[] = []; + Object.entries(feeds).forEach(kvp => { + const name = kvp[0]; + const tensor = kvp[1]; + const index = names.indexOf(name); + if (index === -1) { + throw new Error(`invalid input '${name}`); + } + values.push(tensor); + indices.push(index); + }); + + const uList = values.map(mapFunc); + return [values, indices, uList]; + } + + /** + * Helper method that converts the TensorMetadata that the wasm-core functions return to the + * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the + * corresponding result. + * + * @param results used to populate the resultMap if there is no value for that outputName already + * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results + * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. + * @returns a map of output names and OnnxValues. + */ + convertTensorMetadataToReturnType( + results: TensorMetadata[], outputArray: Array, outputIndices: number[]): SessionHandler.ReturnType { + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); + } + return resultMap; + } + + async runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( + feeds, this.inputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + + const [outputArray, outputIndices, outputs] = + this.convertMapIntoValuesArrayAndIndicesArray( + fetches, this.outputNames, + (t, i): TensorMetadata|null => + t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } +} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 947242945c665..3aacf8f4d90e0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -240,7 +240,7 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; -const prepareInputOutputTensor = +export const prepareInputOutputTensor = (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): void => { if (!tensor) { diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 4830b5d2b5e80..a35d285346db4 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession} from 'onnxruntime-common'; +import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; +import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; +import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; @@ -146,6 +149,170 @@ export const createTrainingSessionHandle = } }; +/** + * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the + * WASM tensors. + * + * @param trainingSessionId + * @param indices for each tensor, the index of the input or output name that the tensor corresponds with + * @param tensors list of TensorMetaData + * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting + * handles of the allocated tensors on the heap + * @param inputOutputAllocs modified in-place by this method + * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor + */ +const createAndAllocateTensors = + (trainingSessionId: number, indices: number[], tensors: Array, tensorHandles: number[], + inputOutputAllocs: number[], indexAdd: number) => { + const count = indices.length; + + // creates the tensors + for (let i = 0; i < count; i++) { + prepareInputOutputTensor( + tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); + } + + // moves to heap + const wasm = getInstance(); + const valuesOffset = wasm.stackAlloc(count * 4); + let valuesIndex = valuesOffset / 4; + for (let i = 0; i < count; i++) { + wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + } + + return valuesOffset; + }; + +/** + * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information + * associated with the tensor handle. + * + * @param outputValuesOffset + * @param outputCount + * @returns list of TensorMetadata retrieved from the output handles. + */ +const moveOutputToTensorMetadataArr = + (outputValuesOffset: number, outputCount: number, outputTensorHandles: number[], + outputTensors: Array) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; + } + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + wasm._OrtReleaseTensor(tensor); + } + } + + return output; + }; + +export const runTrainStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + + try { + // prepare parameters by moving them to heap + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + + if (wasm._OrtTrainingRunTrainStep) { + const errorCode = wasm._OrtTrainingRunTrainStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + + if (errorCode !== 0) { + checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); + } + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); + } finally { + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => {