Skip to content

Commit

Permalink
lint + format + added error code checking wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
carzh committed Oct 24, 2023
1 parent b44b705 commit c74112e
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 133 deletions.
3 changes: 1 addition & 2 deletions js/web/lib/wasm/session-handler-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessio
import {SerializableModeldata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize,
releaseTrainingSessionAndCheckpoint, runTrainStep, loadParametersBuffer} from './wasm-training-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
private sessionId: number;
Expand Down
266 changes: 135 additions & 131 deletions js/web/lib/wasm/wasm-training-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ const NO_TRAIN_FUNCS_MSG =
functionality, and make sure that all the correct artifacts are built & moved to the correct folder if \
using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.`;

/**
 * Runs the checkLastError function, which will throw an error, if the provided error code matches the specified
 * pattern for an error code.
 * @param errCode number to be evaluated for whether it represents an error
 * @param message message to pass into checkLastError
 * @param checkNeqZero when true, treats not equal to zero as an error.
 *                     When false, treats equal to zero as an error (e.g. for APIs that return a handle,
 *                     where a 0/null handle signals failure).
 */
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
  // Collapse the two mutually-exclusive branches into a single equivalent guard:
  // error iff (errCode !== 0) when checkNeqZero, or (errCode === 0) otherwise.
  if (checkNeqZero ? errCode !== 0 : errCode === 0) {
    checkLastError(message);
  }
};

export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
const wasm = getInstance();

Expand All @@ -29,9 +45,8 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n
throw new Error(NO_TRAIN_FUNCS_MSG);
}

if (checkpointHandle === 0) {
checkLastError('Error occurred when trying to create a CheckpointState.');
}
ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);

return checkpointHandle;
} catch (e) {
if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
Expand All @@ -51,9 +66,8 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n
const dataOffset = wasm.stackAlloc(8);
if (wasm._OrtTrainingGetInputOutputCount) {
const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, false);
if (errorCode !== 0) {
checkLastError('Can\'t get session input/output count.');
}
ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.');

return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
Expand All @@ -72,9 +86,7 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput:
for (let i = 0; i < count; i++) {
if (wasm._OrtTrainingGetInputOutputName) {
const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput, false);
if (name === 0) {
checkLastError('Can\'t get input or output name');
}
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);

namesUTF8Encoded.push(name);
names.push(wasm.UTF8ToString(name));
Expand Down Expand Up @@ -119,9 +131,7 @@ export const createTrainingSessionHandle =
throw new Error(NO_TRAIN_FUNCS_MSG);
}

if (trainingSessionHandle === 0) {
checkLastError('Error occurred when trying to create a TrainingSession.');
}
ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);

[inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
getTrainingModelInputOutputNames(trainingSessionHandle);
Expand Down Expand Up @@ -200,9 +210,8 @@ const moveOutputToTensorMetadataArr = (outputValuesOffset: number, outputCount:
try {
const errorCode = wasm._OrtGetTensorData(
tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
if (errorCode !== 0) {
checkLastError(`Can't access output tensor data on index ${i}.`);
}
ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);

let tensorDataIndex = tensorDataOffset / 4;
const dataType = wasm.HEAPU32[tensorDataIndex++];
dataOffset = wasm.HEAPU32[tensorDataIndex++];
Expand Down Expand Up @@ -278,9 +287,7 @@ export const runTrainStep = async(
const errorCode = wasm._OrtTrainingRunTrainStep(
trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);

if (errorCode !== 0) {
checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
}
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
Expand All @@ -300,133 +307,130 @@ export const runTrainStep = async(
}
};

export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean):
number => {
const wasm = getInstance();
const stack = wasm.stackSave();
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
const wasm = getInstance();
const stack = wasm.stackSave();

try {
const sizeOffset = wasm.stackAlloc(4);
if (wasm._OrtTrainingGetParametersSize) {
const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
try {
const sizeOffset = wasm.stackAlloc(4);
if (wasm._OrtTrainingGetParametersSize) {
const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size');

if (errorCode !== 0) {
checkLastError('Can\'t get parameters size');
}
return wasm.HEAP32[sizeOffset / 4];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
wasm.stackRestore(stack);
}
};

return wasm.HEAP32[sizeOffset / 4];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
wasm.stackRestore(stack);
}
};
export const getContiguousParameters =
async(trainingSessionId: number, trainableOnly: boolean): Promise<TensorMetadata> => {
const wasm = getInstance();
const stack = wasm.stackSave();

export const getContiguousParameters = async(trainingSessionId: number, trainableOnly: boolean):
Promise<TensorMetadata> => {
const wasm = getInstance();
const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
// alloc buffer -- assumes parameters will be of type float32
const stack = wasm.stackSave();
let tensor: number = 0;

const paramsByteLength = 4 * parametersSize;
const paramsOffset = wasm.stackAlloc(paramsByteLength);
const bufferAlloc = wasm.stackAlloc(paramsOffset/4);
wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset);

// handles the dimensions-related createTensor parameters
const dimsOffset = wasm.stackAlloc(4);
const dimsIndex = dimsOffset / 4;
wasm.HEAP32[dimsIndex] = parametersSize;
try {
tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum('float32'), paramsOffset, paramsByteLength, dimsOffset, 1,
dataLocationStringToEnum('cpu'));
if (tensor === 0) {
checkLastError(`Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`);
}
wasm.HEAPU32[bufferAlloc] = tensor;
if (wasm._OrtTrainingCopyParametersToBuffer) {
const errCode =
wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
if (errCode !== 0) {
checkLastError('Can\'t get contiguous parameters.');
}
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
const tensorTypeAsString = 'float32';
const locationAsString = 'cpu';

const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
let tensor = 0;

const typedArrayConstructor = tensorTypeToTypedArrayConstructor('float32');
const data = new typedArrayConstructor(parametersSize);
const output: TensorMetadata[] = [];
new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength));
output.push(['float32', [parametersSize], data, 'cpu']);
if (output.length > 1 || output.length < 1) {
throw new Error(
`something unexpected happened in the getContiguousParameters function. Expected output length of
const paramsByteLength = 4 * parametersSize;
const paramsOffset = wasm.stackAlloc(paramsByteLength);
wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset);

const tensorOffset = wasm.stackAlloc(paramsOffset / 4);

// handles the dimensions-related createTensor parameters
const dims = [parametersSize];

const dimsOffset = wasm.stackAlloc(4);
const dimsIndex = dimsOffset / 4;
wasm.HEAP32[dimsIndex] = parametersSize;

try {
tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length,
dataLocationStringToEnum(locationAsString));
ifErrCodeCheckLastError(
tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false);

wasm.HEAPU32[tensorOffset] = tensor;
if (wasm._OrtTrainingCopyParametersToBuffer) {
const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.');

} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}

const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
const data = new typedArrayConstructor(parametersSize);
const output: TensorMetadata[] = [];
new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
.set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength));
output.push([tensorTypeAsString, dims, data, locationAsString]);
if (output.length > 1 || output.length < 1) {
throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
one, got ${output.length}`);
} else {
return output[0];
}
} finally {
console.log('test');
if (tensor !== 0) {
console.log('tensor is not equal to 0');
wasm._OrtReleaseTensor(tensor);
}
console.log('test after ortReleaseTensor call but before stackRestore call');
wasm._free(paramsOffset);
wasm._free(dimsOffset);
wasm._free(bufferAlloc);
wasm.stackRestore(stack);
}
};
} else {
return output[0];
}
} finally {
if (tensor !== 0) {
wasm._OrtReleaseTensor(tensor);
}
wasm._free(paramsOffset);
wasm._free(dimsOffset);
wasm._free(tensorOffset);
wasm.stackRestore(stack);
}
};

export const loadParametersBuffer = async (trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean):
Promise<void> => {
const wasm = getInstance();
const stack = wasm.stackSave();
const bufferCount = buffer.length;
const bufferByteLength = bufferCount * 4;
const bufferOffset = wasm.stackAlloc(bufferByteLength);
wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset);
const dimsOffset = wasm.stackAlloc(4);
wasm.HEAP32[dimsOffset / 4] = bufferCount;
const dimsLength = 1;
let tensor: number = 0;
const bufferAlloc = wasm.stackAlloc(bufferOffset/4);
export const loadParametersBuffer =
async(trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean): Promise<void> => {
const wasm = getInstance();
const stack = wasm.stackSave();

try {
tensor = wasm._OrtCreateTensor(tensorDataTypeStringToEnum('float32'), bufferOffset, bufferByteLength, dimsOffset, dimsLength, dataLocationStringToEnum('cpu'));
if (tensor === 0) {
checkLastError(`Can't create tensor for input/output. session=${trainingSessionId}`);
}
wasm.HEAPU32[bufferAlloc] = tensor;
const tensorTypeAsString = 'float32';
const locationAsString = 'cpu';

if (wasm._OrtTrainingCopyParametersFromBuffer) {
const errCode =
wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
const bufferCount = buffer.length;
const bufferByteLength = bufferCount * 4;
const bufferOffset = wasm.stackAlloc(bufferByteLength);
wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset);
const dimsOffset = wasm.stackAlloc(4);
wasm.HEAP32[dimsOffset / 4] = bufferCount;
const dimsLength = 1;
let tensor = 0;
const bufferAlloc = wasm.stackAlloc(bufferOffset / 4);

if (errCode !== 0) {
checkLastError('Can\'t copy buffer to parameters.');
}
try {
tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength,
dataLocationStringToEnum(locationAsString));
ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);

} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
wasm.HEAPU32[bufferAlloc] = tensor;

} finally {
if (tensor !== 0) {
if (wasm._OrtTrainingCopyParametersFromBuffer) {
const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
if (tensor !== 0) {
wasm._OrtReleaseTensor(tensor);
}
wasm.stackRestore(stack);
wasm._free(bufferAlloc);
wasm._free(bufferOffset);
wasm._free(dimsOffset);
}
}
wasm.stackRestore(stack);
wasm._free(bufferAlloc);
wasm._free(bufferOffset);
wasm._free(dimsOffset);
}
};

export const releaseTrainingSessionAndCheckpoint =
(checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
Expand Down

0 comments on commit c74112e

Please sign in to comment.