Skip to content

Commit

Permalink
[js/web/training] Added parameters methods (#18250)
Browse files Browse the repository at this point in the history
### Description
* Implemented: `getParametersSize`, `getContiguousParameters`
(equivalent to copyParametersToBuffer), and `loadParametersBuffer`
(equivalent to copyParametersFromBuffer)
* as part of these changes, getParametersSize was added to the
TrainingSession interface so that users know what size buffer to create
for loadParametersBuffer
* The parameters methods in the interface were modified to take in a
Float32Array instead


### Motivation and Context
* part of the work for implementing web bindings for training
* enables federated learning in the web
* previous  PR: #18006

---------

Co-authored-by: Ashwini Khade <[email protected]>
  • Loading branch information
carzh and askhade authored Nov 27, 2023
1 parent a2fd8a6 commit dd355e3
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 40 deletions.
3 changes: 2 additions & 1 deletion js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ export interface TrainingSessionHandler extends SessionHandler {
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}

/**
Expand Down
20 changes: 16 additions & 4 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,24 @@ export class TrainingSession implements TrainingSessionInterface {
return this.convertHandlerReturnTypeToMapOfTensors(results);
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
async getParametersSize(trainableOnly = true): Promise<number> {
return this.handler.getParametersSize(trainableOnly);
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise<void> {
const paramsSize = await this.getParametersSize(trainableOnly);
// checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number
// of parameters
if (array.length !== 4 * paramsSize) {
throw new Error(
'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
'the model. Please use getParametersSize method to check.');
}
return this.handler.loadParametersBuffer(array, trainableOnly);
}

async getContiguousParameters(trainableOnly = true): Promise<OnnxValue> {
return this.handler.getContiguousParameters(trainableOnly);
}

async release(): Promise<void> {
Expand Down
27 changes: 20 additions & 7 deletions js/common/lib/training-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js';

/* eslint-disable @typescript-eslint/no-redeclare */
Expand Down Expand Up @@ -49,21 +50,33 @@ export interface TrainingSession {
// #endregion

// #region copy parameters

/**
* Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
* the parameters) elements of all the parameters in the training state.
*
* @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
*/
getParametersSize(trainableOnly: boolean): Promise<number>;

/**
* Copies from a buffer containing parameters to the TrainingSession parameters.
* Copies parameter values from the given array to the training state. Currently, only supporting models with
* parameters of type Float32.
*
* @param buffer - buffer containing parameters
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise.
* @param buffer - Float32 buffer containing parameters converted to a Uint8Array.
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
*/
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;

/**
* Copies from the TrainingSession parameters to a buffer.
* Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
* Currently, only supporting models with parameters of type Float32.
*
* @param trainableOnly - True if trainable parameters only to be copied, false othrwise.
* @returns A promise that resolves to a buffer of the requested parameters.
* @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
* for which requires_grad is set to true. Default value is true.
* @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
*/
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
// #endregion

// #region release()
Expand Down
22 changes: 14 additions & 8 deletions js/web/lib/wasm/session-handler-training.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';
import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}
private sessionId: number;
private checkpointId: number;

Expand Down Expand Up @@ -124,6 +118,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}

async getParametersSize(trainableOnly: boolean): Promise<number> {
return getParametersSize(this.sessionId, trainableOnly);
}

async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
await loadParametersBuffer(this.sessionId, array, trainableOnly);
}
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
return decodeTensorMetadata(tensorResult);
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
Expand Down
166 changes: 146 additions & 20 deletions js/web/lib/wasm/wasm-training-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {InferenceSession, Tensor} from 'onnxruntime-common';
import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {prepareInputOutputTensor} from './wasm-core-impl';
import {getInstance} from './wasm-factory';
import {checkLastError} from './wasm-utils';
Expand All @@ -16,6 +16,22 @@ const NO_TRAIN_FUNCS_MSG =
'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';

/**
* Runs the checkLastError function which will throw an error, if the provided error code matches the specified
* pattern for an error code.
* @param errCode number to evaluated for if it's an error
* @param message message to pass into checkLastError
* @param checkNeqZero when true, treats not equal to zero as an error.
* When false, treats equal to zero as an error.
*/
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
if (checkNeqZero && errCode !== 0) {
checkLastError(message);
} else if (!checkNeqZero && errCode === 0) {
checkLastError(message);
}
};

export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
const wasm = getInstance();

Expand All @@ -29,9 +45,7 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n
throw new Error(NO_TRAIN_FUNCS_MSG);
}

if (checkpointHandle === 0) {
checkLastError('Error occurred when trying to create a CheckpointState.');
}
ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
return checkpointHandle;
} catch (e) {
if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
Expand All @@ -52,9 +66,7 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea
if (wasm._OrtTrainingGetModelInputOutputCount) {
const errorCode =
wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel);
if (errorCode !== 0) {
checkLastError('Can\'t get session input/output count.');
}
ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.');
return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
Expand All @@ -74,9 +86,7 @@ const getModelInputOutputNamesLoop =
for (let i = 0; i < count; i++) {
if (wasm._OrtTrainingGetModelInputOutputName) {
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
if (name === 0) {
checkLastError('Can\'t get input or output name');
}
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);

namesUTF8Encoded.push(name);
names.push(wasm.UTF8ToString(name));
Expand Down Expand Up @@ -122,9 +132,7 @@ export const createTrainingSessionHandle =
throw new Error(NO_TRAIN_FUNCS_MSG);
}

if (trainingSessionHandle === 0) {
checkLastError('Error occurred when trying to create a TrainingSession.');
}
ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);

[inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
getTrainingModelInputOutputNames(trainingSessionHandle);
Expand Down Expand Up @@ -213,9 +221,8 @@ const moveOutputToTensorMetadataArr =
try {
const errorCode = wasm._OrtGetTensorData(
tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
if (errorCode !== 0) {
checkLastError(`Can't access output tensor data on index ${i}.`);
}
ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);

let tensorDataIndex = tensorDataOffset / 4;
const dataType = wasm.HEAPU32[tensorDataIndex++];
dataOffset = wasm.HEAPU32[tensorDataIndex++];
Expand Down Expand Up @@ -290,10 +297,7 @@ export const runTrainStep = async(
if (wasm._OrtTrainingRunTrainStep) {
const errorCode = wasm._OrtTrainingRunTrainStep(
trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);

if (errorCode !== 0) {
checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
}
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
Expand All @@ -313,6 +317,128 @@ export const runTrainStep = async(
}
};

export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
const wasm = getInstance();
const stack = wasm.stackSave();

try {
const sizeOffset = wasm.stackAlloc(4);
if (wasm._OrtTrainingGetParametersSize) {
const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size');

return wasm.HEAP32[sizeOffset / 4];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
wasm.stackRestore(stack);
}
};

export const getContiguousParameters =
async(trainingSessionId: number, trainableOnly: boolean): Promise<TensorMetadata> => {
const wasm = getInstance();
const stack = wasm.stackSave();

const tensorTypeAsString = 'float32';
const locationAsString = 'cpu';

const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
let tensor = 0;

// allocates a buffer of the correct size on the WASM heap
const paramsByteLength = 4 * parametersSize;
const paramsOffset = wasm._malloc(paramsByteLength);

// handles the dimensions-related createTensor parameters
const dims = [parametersSize];

const dimsOffset = wasm.stackAlloc(4);
const dimsIndex = dimsOffset / 4;
wasm.HEAP32[dimsIndex] = parametersSize;

try {
// wraps allocated array in a tensor
tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length,
dataLocationStringToEnum(locationAsString));
ifErrCodeCheckLastError(
tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false);

if (wasm._OrtTrainingCopyParametersToBuffer) {
const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.');

} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}

// copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
const data = new typedArrayConstructor(parametersSize);
const output: TensorMetadata[] = [];
new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
.set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength));
output.push([tensorTypeAsString, dims, data, locationAsString]);
if (output.length !== 1) {
throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
one, got ${output.length}`);
} else {
return output[0];
}
} finally {
if (tensor !== 0) {
wasm._OrtReleaseTensor(tensor);
}
wasm._free(paramsOffset);
wasm._free(dimsOffset);
wasm.stackRestore(stack);
}
};

export const loadParametersBuffer =
async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise<void> => {
const wasm = getInstance();
const stack = wasm.stackSave();

const tensorTypeAsString = 'float32';
const locationAsString = 'cpu';

// allocates & copies JavaScript buffer to WASM heap
const bufferByteLength = buffer.length;
const bufferCount = bufferByteLength / 4;
const bufferOffset = wasm._malloc(bufferByteLength);
wasm.HEAPU8.set(buffer, bufferOffset);

// allocates and handles moving dimensions information to WASM memory
const dimsOffset = wasm.stackAlloc(4);
wasm.HEAP32[dimsOffset / 4] = bufferCount;
const dimsLength = 1;
let tensor = 0;

try {
tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength,
dataLocationStringToEnum(locationAsString));
ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);

if (wasm._OrtTrainingCopyParametersFromBuffer) {
const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
if (tensor !== 0) {
wasm._OrtReleaseTensor(tensor);
}
wasm.stackRestore(stack);
wasm._free(bufferOffset);
wasm._free(dimsOffset);
}
};

export const releaseTrainingSessionAndCheckpoint =
(checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
void => {
Expand Down

0 comments on commit dd355e3

Please sign in to comment.