diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index dd04ef3f15997..fd2e8bb74bbf5 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -49,8 +49,9 @@ export interface TrainingSessionHandler extends SessionHandler { feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise; - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; + getParametersSize(trainableOnly: boolean): Promise; + loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; } /** diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index faf597931165a..48fed4224514f 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -150,12 +150,16 @@ export class TrainingSession implements TrainingSessionInterface { return returnValue; } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async getParametersSize(trainableOnly: boolean): Promise { + return this.handler.getParametersSize(trainableOnly); } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + return this.handler.loadParametersBuffer(array, trainableOnly); + } + + async getContiguousParameters(trainableOnly: boolean): Promise { + return this.handler.getContiguousParameters(trainableOnly); } async release(): Promise { diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index 0967d79b33434..40ea16cf05ce4 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -49,13 +50,21 @@ export interface TrainingSession { // #endregion // #region copy parameters + + /** + * Retrieves the size of all parameters for the training state. + * + * @param trainableOnly skips non-trainable parameters when true. + */ + getParametersSize(trainableOnly: boolean): Promise; + /** * Copies from a buffer containing parameters to the TrainingSession parameters. * * @param buffer - buffer containing parameters * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. */ - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; /** * Copies from the TrainingSession parameters to a buffer. @@ -63,7 +72,7 @@ export interface TrainingSession { * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. * @returns A promise that resolves to a buffer of the requested parameters. */ - getContiguousParameters(trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; // #endregion // #region release() diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 060fb1e756ef9..def706f53fc3a 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,8 +102,10 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; - _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; + _OrtTrainingGetInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index e754e0bf64282..af8f6dc0e2dd2 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,20 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, + releaseTrainingSessionAndCheckpoint, runTrainStep, loadParametersBuffer} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } private sessionId: number; private checkpointId: number; @@ -104,6 +99,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return resultMap; } + async getParametersSize(trainableOnly: boolean): Promise { + return getParametersSize(this.sessionId, trainableOnly); + } + + async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + await loadParametersBuffer(this.sessionId, array, trainableOnly); + } + async getContiguousParameters(trainableOnly: boolean): Promise { + const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); + return decodeTensorMetadata(tensorResult); + } + async dispose(): Promise { return releaseTrainingSessionAndCheckpoint( this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index d8bb0fae905f0..75035e4b9f694 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -6,22 +6,25 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. ' + - 'Make sure to use the onnxruntime-training package for training functionality.'; +const NO_TRAIN_FUNCS_MSG = + `Built without training API's enabled. Use the onnxruntime-web/training import for training \ + functionality, and make sure that all the correct artifacts are built & moved to the correct folder if \ + using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.`; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); + const [checkpointDataOffset, checkpointDataLength] = checkpointData; let checkpointHandle = 0; try { if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointData[0], checkpointData[1]); + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -47,7 +50,7 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n try { const dataOffset = wasm.stackAlloc(8); if (wasm._OrtTrainingGetInputOutputCount) { - const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); + const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, false); if (errorCode !== 0) { checkLastError('Can\'t get session input/output count.'); } @@ -68,7 +71,7 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetInputOutputName) { - const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); + const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput, false); if (name === 0) { checkLastError('Can\'t get input or output name'); } @@ -182,69 +185,65 @@ const createAndAllocateTensors = * @param outputCount * @returns */ -const moveOutputToTensorMetadataArr = - (outputValuesOffset: number, outputCount: number) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; +const moveOutputToTensorMetadataArr = (outputValuesOffset: number, outputCount: number) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - const keepOutputTensor = false; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - if (!keepOutputTensor) { - wasm._OrtReleaseTensor(tensor); - } + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + wasm._OrtReleaseTensor(tensor); + } + } - return output; - }; + return output; +}; export const runTrainStep = async( trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], @@ -301,6 +300,134 @@ export const runTrainStep = async( } }; +export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): + number => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + try { + const sizeOffset = wasm.stackAlloc(4); + if (wasm._OrtTrainingGetParametersSize) { + const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + + if (errorCode !== 0) { + checkLastError('Can\'t get parameters size'); + } + + return wasm.HEAP32[sizeOffset / 4]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } + }; + +export const getContiguousParameters = async(trainingSessionId: number, trainableOnly: boolean): + Promise => { + const wasm = getInstance(); + const parametersSize = getParametersSize(trainingSessionId, trainableOnly); + // alloc buffer -- assumes parameters will be of type float32 + const stack = wasm.stackSave(); + let tensor: number = 0; + + const paramsByteLength = 4 * parametersSize; + const paramsOffset = wasm.stackAlloc(paramsByteLength); + const bufferAlloc = wasm.stackAlloc(paramsOffset/4); + wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset); + + // handles the dimensions-related createTensor parameters + const dimsOffset = wasm.stackAlloc(4); + const dimsIndex = dimsOffset / 4; + wasm.HEAP32[dimsIndex] = parametersSize; + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum('float32'), paramsOffset, paramsByteLength, dimsOffset, 1, + dataLocationStringToEnum('cpu')); + if (tensor === 0) { + checkLastError(`Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`); + } + wasm.HEAPU32[bufferAlloc] = tensor; + if (wasm._OrtTrainingCopyParametersToBuffer) { + const errCode = + wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); + if (errCode !== 0) { + checkLastError('Can\'t get contiguous parameters.'); + } + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + const typedArrayConstructor = tensorTypeToTypedArrayConstructor('float32'); + const data = new typedArrayConstructor(parametersSize); + const output: TensorMetadata[] = []; + new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + output.push(['float32', [parametersSize], data, 'cpu']); + if (output.length > 1 || output.length < 1) { + throw new Error( + `something unexpected happened in the getContiguousParameters function. Expected output length of + one, got ${output.length}`); + } else { + return output[0]; + } + } finally { + console.log('test'); + if (tensor !== 0) { + console.log('tensor is not equal to 0'); + wasm._OrtReleaseTensor(tensor); + } + console.log('test after ortReleaseTensor call but before stackRestore call'); + wasm._free(paramsOffset); + wasm._free(dimsOffset); + wasm._free(bufferAlloc); + wasm.stackRestore(stack); + } + }; + +export const loadParametersBuffer = async (trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean): + Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + const bufferCount = buffer.length; + const bufferByteLength = bufferCount * 4; + const bufferOffset = wasm.stackAlloc(bufferByteLength); + wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset); + const dimsOffset = wasm.stackAlloc(4); + wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsLength = 1; + let tensor: number = 0; + const bufferAlloc = wasm.stackAlloc(bufferOffset/4); + + try { + tensor = wasm._OrtCreateTensor(tensorDataTypeStringToEnum('float32'), bufferOffset, bufferByteLength, dimsOffset, dimsLength, dataLocationStringToEnum('cpu')); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${trainingSessionId}`); + } + wasm.HEAPU32[bufferAlloc] = tensor; + + if (wasm._OrtTrainingCopyParametersFromBuffer) { + const errCode = + wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + + if (errCode !== 0) { + checkLastError('Can\'t copy buffer to parameters.'); + } + + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm.stackRestore(stack); + wasm._free(bufferAlloc); + wasm._free(bufferOffset); + wasm._free(dimsOffset); + } +} + export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 2645d6f05222f..f8375c0a77ae3 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -581,30 +581,52 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, size_t* input_count, - size_t* output_count) { - RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); - RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); - return ORT_OK; + size_t* output_count, + bool isEvalModel) { + if (isEvalModel) { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count); + return ORT_OK; + } else { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); + return ORT_OK; + } } char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, size_t index, - bool isInput) { + bool isInput, + bool isEvalModel) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); char* name = nullptr; - if (isInput) { - return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, - allocator, &name) == ORT_OK) - ? name - : nullptr; + if (isEvalModel) { + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } } else { - return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, - allocator, &name) == ORT_OK) - ? name - : nullptr; + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } } } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index ea21eb8a9e8c8..d7bc84c0f00bd 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -433,27 +433,33 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio bool trainable_only); /** - * Gets the input count and output count of the training model associated with the given training handle. + * Gets the input count and output count of the training or eval model associated with the given training handle. * @param traning_handle handle of the traning session * @param input_count [out] a pointer to a size_t variable to accept input_count * @param output_count [out] a pointer to a size_t variable to accept output_count + * @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output + * count of the eval model. * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, size_t* input_count, - size_t* output_count); + size_t* output_count, + bool isEvalModel); /** - * Gets the input or output name at the specified index associated with the training model from the + * Gets the input or output name at the specified index associated with the training or eval model from the * given training session. * @param traning_handle handle of the traning session * @param index the input or output index * @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name. + * @param isEvalModel when false, returns input & output names of the training model. When true, returns input & output + * names of the eval model. * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by */ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, size_t index, - bool isInput); + bool isInput, + bool isEvalModel); /** * @brief Release the specified ORT training session.