Skip to content

Commit

Permalink
light refactoring
Browse files Browse the repository at this point in the history
renamed session-handler for inference files

lint + format
  • Loading branch information
carzh committed Oct 18, 2023
1 parent 46a9677 commit 3ca1c27
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 119 deletions.
23 changes: 11 additions & 12 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';
import { OnnxValue } from './onnx-value.js';
import { Tensor } from './tensor.js';

type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = InferenceSession.FeedsType;
Expand Down Expand Up @@ -49,17 +49,8 @@ export class TrainingSession implements TrainingSessionInterface {
}
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}

runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(
feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
Expand Down Expand Up @@ -159,6 +150,14 @@ export class TrainingSession implements TrainingSessionInterface {
return returnValue;
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}

async release(): Promise<void> {
return this.handler.dispose();
}
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-onnxjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {Session} from './onnxjs/session';
import {OnnxjsSessionHandler} from './onnxjs/session-handler';
import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference';

class OnnxjsBackend implements Backend {
// eslint-disable-next-line @typescript-eslint/no-empty-function
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {cpus} from 'node:os';
import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference';

/**
* This function initializes all flags for WebAssembly.
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession, SessionHandler, TrainingSessionHandler, Tensor} from 'onnxruntime-common';
import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableModeldata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, runTrainStep,
releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl';
import { encodeTensorMetadata, decodeTensorMetadata } from './session-handler';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
Expand Down Expand Up @@ -99,7 +98,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);

const resultMap: SessionHandler.ReturnType = {};
for (let i = 0; i < results. length; i++) {
for (let i = 0; i < results.length; i++) {
resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
}
return resultMap;
Expand All @@ -109,5 +108,4 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
}

}
222 changes: 124 additions & 98 deletions js/web/lib/wasm/wasm-training-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import {InferenceSession, Tensor} from 'onnxruntime-common';

import { prepareInputOutputTensor } from './wasm-core-impl';
import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import { tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor } from './wasm-common';
import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {prepareInputOutputTensor} from './wasm-core-impl';
import {getInstance} from './wasm-factory';
import {checkLastError} from './wasm-utils';
import { setRunOptions } from './run-options';

const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. ' +
'Make sure to use the onnxruntime-training package for training functionality.';
Expand Down Expand Up @@ -143,10 +143,113 @@ export const createTrainingSessionHandle =
}
};

/**
* Prepares input and output tensors by creating the tensors in the WASM side then moving them to the heap
* @param trainingSessionId
* @param indices for each tensor, the index of the input or output name that the tensor corresponds with
* @param tensors list of TensorMetaData
* @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting
* handles of the allocated tensors on the heap
* @param inputOutputAllocs modified in-place by this method
* @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
*/
const createAndAllocateTensors =
(trainingSessionId: number, indices: number[], tensors: Array<TensorMetadata|null>, tensorHandles: number[],
inputOutputAllocs: number[], indexAdd: number) => {
const wasm = getInstance();

const count = indices.length;
const valuesOffset = wasm.stackAlloc(count * 4);

// creates the tensors
for (let i = 0; i < count; i++) {
prepareInputOutputTensor(
tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
}

// moves to heap
let valuesIndex = valuesOffset / 4;
for (let i = 0; i < count; i++) {
wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
}

return valuesOffset;
};

/**
* Move output tensors from the heap to an array
* @param outputValuesOffset
* @param outputCount
* @returns
*/
const moveOutputToTensorMetadataArr =
(outputValuesOffset: number, outputCount: number) => {
const wasm = getInstance();
const output: TensorMetadata[] = [];

for (let i = 0; i < outputCount; i++) {
const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];

const beforeGetTensorDataStack = wasm.stackSave();
// stack allocate 4 pointer value
const tensorDataOffset = wasm.stackAlloc(4 * 4);

const keepOutputTensor = false;
let type: Tensor.Type|undefined, dataOffset = 0;
try {
const errorCode = wasm._OrtGetTensorData(
tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
if (errorCode !== 0) {
checkLastError(`Can't access output tensor data on index ${i}.`);
}
let tensorDataIndex = tensorDataOffset / 4;
const dataType = wasm.HEAPU32[tensorDataIndex++];
dataOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsLength = wasm.HEAPU32[tensorDataIndex++];
const dims = [];
for (let i = 0; i < dimsLength; i++) {
dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
}
wasm._OrtFree(dimsOffset);

const size = dims.reduce((a, b) => a * b, 1);
type = tensorDataTypeEnumToString(dataType);

if (type === 'string') {
const stringData: string[] = [];
let dataIndex = dataOffset / 4;
for (let i = 0; i < size; i++) {
const offset = wasm.HEAPU32[dataIndex++];
const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
}
output.push([type, dims, stringData, 'cpu']);
} else {
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
const data = new typedArrayConstructor(size);
new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
.set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength));
output.push([type, dims, data, 'cpu']);
}
} finally {
wasm.stackRestore(beforeGetTensorDataStack);
if (type === 'string' && dataOffset) {
wasm._free(dataOffset);
}
if (!keepOutputTensor) {
wasm._OrtReleaseTensor(tensor);
}
}
}

return output;
};

export const runTrainStep = async(
trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
const wasm = getInstance();
trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
const wasm = getInstance();

const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
Expand All @@ -159,108 +262,31 @@ export const runTrainStep = async(
const inputOutputAllocs: number[] = [];

const beforeRunStack = wasm.stackSave();
const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
const outputValuesOffset = wasm.stackAlloc(outputCount * 4);

try {
// prepare parameters by moving them to heap
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

// TODO:
// move all input and output processing -> wasm heap to one helper method????
// can abstract out the similarities between input and output
// create input tensors
for (let i = 0; i < inputCount; i++) {
prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, trainingSessionId, inputIndices[i]);
}

// create output tensors
for (let i = 0; i < outputCount; i++) {
prepareInputOutputTensor(
outputTensors[i], outputTensorHandles, inputOutputAllocs, trainingSessionId, inputCount + outputIndices[i]);
}

let inputValuesIndex = inputValuesOffset / 4;
let outputValuesIndex = outputValuesOffset / 4;
for (let i = 0; i < inputCount; i++) {
wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i];
}
for (let i = 0; i < outputCount; i++) {
wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i];
}

let errorCode: number;
// handle inputs -- you don't want anything added to the index
const inputValuesOffset = createAndAllocateTensors(
trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
// handle outputs
// you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
const outputValuesOffset = createAndAllocateTensors(
trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);

if (wasm._OrtTrainingRunTrainStep) {
errorCode = await wasm._OrtTrainingRunTrainStep(trainingSessionId, inputValuesOffset, inputCount,
outputValuesOffset, outputCount, runOptionsHandle);
}
else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
const errorCode = wasm._OrtTrainingRunTrainStep(
trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);

if (errorCode !== 0) {
checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
}

const output: TensorMetadata[] = [];

for (let i = 0; i < outputCount; i++) {
const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];

const beforeGetTensorDataStack = wasm.stackSave();
// stack allocate 4 pointer value
const tensorDataOffset = wasm.stackAlloc(4 * 4);

let keepOutputTensor = false;
let type: Tensor.Type|undefined, dataOffset = 0;
try {
const errorCode = wasm._OrtGetTensorData(
tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
if (errorCode !== 0) {
checkLastError(`Can't access output tensor data on index ${i}.`);
}
let tensorDataIndex = tensorDataOffset / 4;
const dataType = wasm.HEAPU32[tensorDataIndex++];
dataOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsLength = wasm.HEAPU32[tensorDataIndex++];
const dims = [];
for (let i = 0; i < dimsLength; i++) {
dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
}
wasm._OrtFree(dimsOffset);

const size = dims.reduce((a, b) => a * b, 1);
type = tensorDataTypeEnumToString(dataType);

if (type === 'string') {
const stringData: string[] = [];
let dataIndex = dataOffset / 4;
for (let i = 0; i < size; i++) {
const offset = wasm.HEAPU32[dataIndex++];
const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
}
output.push([type, dims, stringData, 'cpu']);
} else {
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
const data = new typedArrayConstructor(size);
new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
.set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength));
output.push([type, dims, data, 'cpu']);
}
} finally {
wasm.stackRestore(beforeGetTensorDataStack);
if (type === 'string' && dataOffset) {
wasm._free(dataOffset);
}
if (!keepOutputTensor) {
wasm._OrtReleaseTensor(tensor);
}
if (errorCode !== 0) {
checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
}
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}

return output;
return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount);
} finally {
wasm.stackRestore(beforeRunStack);

Expand Down

0 comments on commit 3ca1c27

Please sign in to comment.