Skip to content

Commit

Permalink
getContiguousParameters & loadParametersBuffer impl
Browse files Browse the repository at this point in the history
wrote untested getContiguousParameters method

updated getInputOutputCount and getInputOutputNames signature, added more informative error message

updated parameter names according to suggestions

semi working getContiguousParameters impl

working getContiguousParams, started writing loadParametersBuffer

working version of loadParametersBuffer
  • Loading branch information
carzh committed Oct 24, 2023
1 parent 3ca1c27 commit b44b705
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 99 deletions.
5 changes: 3 additions & 2 deletions js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ export interface TrainingSessionHandler extends SessionHandler {
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}

/**
Expand Down
12 changes: 8 additions & 4 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,16 @@ export class TrainingSession implements TrainingSessionInterface {
return returnValue;
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
async getParametersSize(trainableOnly: boolean): Promise<number> {
return this.handler.getParametersSize(trainableOnly);
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void> {
return this.handler.loadParametersBuffer(array, trainableOnly);
}

async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
return this.handler.getContiguousParameters(trainableOnly);
}

async release(): Promise<void> {
Expand Down
13 changes: 11 additions & 2 deletions js/common/lib/training-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js';

/* eslint-disable @typescript-eslint/no-redeclare */
Expand Down Expand Up @@ -49,21 +50,29 @@ export interface TrainingSession {
// #endregion

// #region copy parameters

/**
* Retrieves the size of all parameters for the training state.
*
* @param trainableOnly skips non-trainable parameters when true.
*/
getParametersSize(trainableOnly: boolean): Promise<number>;

/**
* Copies from a buffer containing parameters to the TrainingSession parameters.
*
* @param buffer - buffer containing parameters
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise.
*/
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;

/**
* Copies from the TrainingSession parameters to a buffer.
*
* @param trainableOnly - True if trainable parameters only to be copied, false othrwise.
* @returns A promise that resolves to a buffer of the requested parameters.
*/
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
// #endregion

// #region release()
Expand Down
6 changes: 4 additions & 2 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;

_OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number;
_OrtTrainingGetInputOutputCount?
(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
_OrtTrainingGetInputOutputName?
(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;

_OrtTrainingReleaseSession?(trainingHandle: number): void;
// #endregion
Expand Down
23 changes: 15 additions & 8 deletions js/web/lib/wasm/session-handler-training.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';
import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableModeldata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize,
releaseTrainingSessionAndCheckpoint, runTrainStep, loadParametersBuffer} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}
private sessionId: number;
private checkpointId: number;

Expand Down Expand Up @@ -104,6 +99,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
return resultMap;
}

async getParametersSize(trainableOnly: boolean): Promise<number> {
return getParametersSize(this.sessionId, trainableOnly);
}

async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void> {
await loadParametersBuffer(this.sessionId, array, trainableOnly);
}
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
return decodeTensorMetadata(tensorResult);
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
Expand Down
Loading

0 comments on commit b44b705

Please sign in to comment.