From 3234487385054a369dc7061d29030a3ae233dc31 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:44:09 -0800 Subject: [PATCH] [js] remove more unused training types (#22753) ### Description remove more unused training types --- js/.eslintrc.js | 13 -- js/common/lib/backend.ts | 36 ---- js/common/lib/env.ts | 2 - js/common/lib/index.ts | 1 - js/common/lib/training-session-impl.ts | 273 ------------------------- js/common/lib/training-session.ts | 206 ------------------- 6 files changed, 531 deletions(-) delete mode 100644 js/common/lib/training-session-impl.ts delete mode 100644 js/common/lib/training-session.ts diff --git a/js/.eslintrc.js b/js/.eslintrc.js index bd1e9061355f5..462e417df1d66 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -198,19 +198,6 @@ module.exports = { '_OrtReleaseTensor', '_OrtRun', '_OrtRunWithBinding', - '_OrtTrainingCopyParametersFromBuffer', - '_OrtTrainingCopyParametersToBuffer', - '_OrtTrainingCreateSession', - '_OrtTrainingEvalStep', - '_OrtTrainingGetModelInputOutputCount', - '_OrtTrainingGetModelInputOutputName', - '_OrtTrainingGetParametersSize', - '_OrtTrainingLazyResetGrad', - '_OrtTrainingLoadCheckpoint', - '_OrtTrainingOptimizerStep', - '_OrtTrainingReleaseCheckpoint', - '_OrtTrainingReleaseSession', - '_OrtTrainingRunTrainStep', ], }, ], diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index e27e67622aa82..e63f9c6c9147f 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -3,7 +3,6 @@ import { InferenceSession } from './inference-session.js'; import { OnnxValue } from './onnx-value.js'; -import { TrainingSession } from './training-session.js'; /** * @ignore @@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler { ): Promise; } -/** - * Represent a handler instance of a training inference session. - * - * @ignore - */ -export interface TrainingSessionHandler extends SessionHandler { - readonly evalInputNames: readonly string[]; - readonly evalOutputNames: readonly string[]; - - lazyResetGrad(): Promise; - runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise; - runOptimizerStep(options: InferenceSession.RunOptions): Promise; - runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise; - - getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; -} - /** * Represent a backend that provides implementation of model inferencing. * @@ -84,14 +56,6 @@ export interface Backend { uriOrBuffer: string | Uint8Array, options?: InferenceSession.SessionOptions, ): Promise; - - createTrainingSessionHandler?( - checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, - trainModelUriOrBuffer: TrainingSession.UriOrBuffer, - evalModelUriOrBuffer: TrainingSession.UriOrBuffer, - optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, - options: InferenceSession.SessionOptions, - ): Promise; } export { registerBackend } from './backend-impl.js'; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 642a897a90d26..adb6a440cf22a 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -14,7 +14,6 @@ export declare namespace Env { * If not modified, the filename of the .wasm file is: * - `ort-wasm-simd-threaded.wasm` for default build * - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN) - * - `ort-training-wasm-simd-threaded.wasm` for training build */ wasm?: URL | string; /** @@ -25,7 +24,6 @@ export declare namespace Env { * If not modified, the filename of the .mjs file is: * - `ort-wasm-simd-threaded.mjs` for default build * - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN) - * - `ort-training-wasm-simd-threaded.mjs` for training build */ mjs?: URL | string; } diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 3ed56b3c2e812..d75e6a477258d 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -26,4 +26,3 @@ export * from './tensor-factory.js'; export * from './trace.js'; export * from './onnx-model.js'; export * from './onnx-value.js'; -export * from './training-session.js'; diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts deleted file mode 100644 index 21dbe5fe51bb9..0000000000000 --- a/js/common/lib/training-session-impl.ts +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { resolveBackendAndExecutionProviders } from './backend-impl.js'; -import { SessionHandler, TrainingSessionHandler } from './backend.js'; -import { InferenceSession as InferenceSession } from './inference-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { Tensor } from './tensor.js'; -import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js'; - -type SessionOptions = InferenceSession.SessionOptions; -type FeedsType = InferenceSession.FeedsType; -type FetchesType = InferenceSession.FetchesType; -type ReturnType = InferenceSession.ReturnType; -type RunOptions = InferenceSession.RunOptions; - -const noBackendErrMsg: string = - 'Training backend could not be resolved. ' + "Make sure you're using the correct configuration & WebAssembly files."; - -export class TrainingSession implements TrainingSessionInterface { - private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) { - this.handler = handler; - this.hasOptimizerModel = hasOptimizerModel; - this.hasEvalModel = hasEvalModel; - } - private handler: TrainingSessionHandler; - private hasOptimizerModel: boolean; - private hasEvalModel: boolean; - - get trainingInputNames(): readonly string[] { - return this.handler.inputNames; - } - get trainingOutputNames(): readonly string[] { - return this.handler.outputNames; - } - - get evalInputNames(): readonly string[] { - if (this.hasEvalModel) { - return this.handler.evalInputNames; - } else { - throw new Error('This training session has no evalModel loaded.'); - } - } - get evalOutputNames(): readonly string[] { - if (this.hasEvalModel) { - return this.handler.evalOutputNames; - } else { - throw new Error('This training session has no evalModel loaded.'); - } - } - - static async create( - trainingOptions: TrainingSessionCreateOptions, - sessionOptions?: SessionOptions, - ): Promise { - const evalModel: string | Uint8Array = trainingOptions.evalModel || ''; - const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || ''; - const options: SessionOptions = sessionOptions || {}; - - // resolve backend, update session options with validated EPs, and create session handler - const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); - if (backend.createTrainingSessionHandler) { - const handler = await backend.createTrainingSessionHandler( - trainingOptions.checkpointState, - trainingOptions.trainModel, - evalModel, - optimizerModel, - optionsWithValidatedEPs, - ); - return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); - } else { - throw new Error(noBackendErrMsg); - } - } - - /** - * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from - * the given parameters to SessionHandler.FetchesType and RunOptions. - * - * @param inputNames the feeds object is checked that they contain all input names in the provided list of input - * names. - * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output - * names. - * @param feeds the required input - * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object - * @param arg2 optional RunOptions object. - * @returns - */ - typeNarrowingForRunStep( - inputNames: readonly string[], - outputNames: readonly string[], - feeds: FeedsType, - arg1?: FetchesType | RunOptions, - arg2?: RunOptions, - ): [SessionHandler.FetchesType, RunOptions] { - const fetches: { [name: string]: OnnxValue | null } = {}; - let options: RunOptions = {}; - // check inputs - if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { - throw new TypeError( - "'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.", - ); - } - - let isFetchesEmpty = true; - // determine which override is being used - if (typeof arg1 === 'object') { - if (arg1 === null) { - throw new TypeError('Unexpected argument[1]: cannot be null.'); - } - if (arg1 instanceof Tensor) { - throw new TypeError("'fetches' cannot be a Tensor"); - } - - if (Array.isArray(arg1)) { - if (arg1.length === 0) { - throw new TypeError("'fetches' cannot be an empty array."); - } - isFetchesEmpty = false; - // output names - for (const name of arg1) { - if (typeof name !== 'string') { - throw new TypeError("'fetches' must be a string array or an object."); - } - if (outputNames.indexOf(name) === -1) { - throw new RangeError(`'fetches' contains invalid output name: ${name}.`); - } - fetches[name] = null; - } - - if (typeof arg2 === 'object' && arg2 !== null) { - options = arg2; - } else if (typeof arg2 !== 'undefined') { - throw new TypeError("'options' must be an object."); - } - } else { - // decide whether arg1 is fetches or options - // if any output name is present and its value is valid OnnxValue, we consider it fetches - let isFetches = false; - const arg1Keys = Object.getOwnPropertyNames(arg1); - for (const name of outputNames) { - if (arg1Keys.indexOf(name) !== -1) { - const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; - if (v === null || v instanceof Tensor) { - isFetches = true; - isFetchesEmpty = false; - fetches[name] = v; - } - } - } - - if (isFetches) { - if (typeof arg2 === 'object' && arg2 !== null) { - options = arg2; - } else if (typeof arg2 !== 'undefined') { - throw new TypeError("'options' must be an object."); - } - } else { - options = arg1 as RunOptions; - } - } - } else if (typeof arg1 !== 'undefined') { - throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'."); - } - - // check if all inputs are in feed - for (const name of inputNames) { - if (typeof feeds[name] === 'undefined') { - throw new Error(`input '${name}' is missing in 'feeds'.`); - } - } - - // if no fetches is specified, we use the full output names list - if (isFetchesEmpty) { - for (const name of outputNames) { - fetches[name] = null; - } - } - - return [fetches, options]; - } - - /** - * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler - * and changes it into a map of Tensors. - * - * @param results - * @returns - */ - convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType { - const returnValue: { [name: string]: OnnxValue } = {}; - for (const key in results) { - if (Object.hasOwnProperty.call(results, key)) { - const result = results[key]; - if (result instanceof Tensor) { - returnValue[key] = result; - } else { - returnValue[key] = new Tensor(result.type, result.data, result.dims); - } - } - } - return returnValue; - } - - async lazyResetGrad(): Promise { - await this.handler.lazyResetGrad(); - } - - runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; - runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; - async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { - const [fetches, options] = this.typeNarrowingForRunStep( - this.trainingInputNames, - this.trainingOutputNames, - feeds, - arg1, - arg2, - ); - const results = await this.handler.runTrainStep(feeds, fetches, options); - return this.convertHandlerReturnTypeToMapOfTensors(results); - } - - async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise { - if (this.hasOptimizerModel) { - await this.handler.runOptimizerStep(options || {}); - } else { - throw new Error('This TrainingSession has no OptimizerModel loaded.'); - } - } - - runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise; - runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise; - async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { - if (this.hasEvalModel) { - const [fetches, options] = this.typeNarrowingForRunStep( - this.evalInputNames, - this.evalOutputNames, - feeds, - arg1, - arg2, - ); - const results = await this.handler.runEvalStep(feeds, fetches, options); - return this.convertHandlerReturnTypeToMapOfTensors(results); - } else { - throw new Error('This TrainingSession has no EvalModel loaded.'); - } - } - - async getParametersSize(trainableOnly = true): Promise { - return this.handler.getParametersSize(trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise { - const paramsSize = await this.getParametersSize(trainableOnly); - // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number - // of parameters - if (array.length !== 4 * paramsSize) { - throw new Error( - 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + - 'the model. Please use getParametersSize method to check.', - ); - } - return this.handler.loadParametersBuffer(array, trainableOnly); - } - - async getContiguousParameters(trainableOnly = true): Promise { - return this.handler.getContiguousParameters(trainableOnly); - } - - async release(): Promise { - return this.handler.dispose(); - } -} diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts deleted file mode 100644 index 45dcafc46deb5..0000000000000 --- a/js/common/lib/training-session.ts +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession } from './inference-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js'; - -/* eslint-disable @typescript-eslint/no-redeclare */ - -export declare namespace TrainingSession { - /** - * Either URI file path (string) or Uint8Array containing model or checkpoint information. - */ - type UriOrBuffer = string | Uint8Array; -} - -/** - * Represent a runtime instance of an ONNX training session, - * which contains a model that can be trained, and, optionally, - * an eval and optimizer model. - */ -export interface TrainingSession { - // #region run() - - /** - * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of - * runOptimizerStep. - */ - lazyResetGrad(): Promise; - - /** - * Run TrainStep asynchronously with the given feeds and options. - * - * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for - detail. - * @param options - Optional. A set of options that controls the behavior of model training. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. - */ - runTrainStep( - feeds: InferenceSession.FeedsType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Run a single train step with the given inputs and options. - * - * @param feeds - Representation of the model input. - * @param fetches - Representation of the model output. - * detail. - * @param options - Optional. A set of options that controls the behavior of model training. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runTrainStep( - feeds: InferenceSession.FeedsType, - fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model. - * - * @param options - Optional. A set of options that controls the behavior of model optimizing. - */ - runOptimizerStep(options?: InferenceSession.RunOptions): Promise; - - /** - * Run a single eval step with the given inputs and options using the eval model. - * - * @param feeds - Representation of the model input. - * @param options - Optional. A set of options that controls the behavior of model eval step. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runEvalStep( - feeds: InferenceSession.FeedsType, - options?: InferenceSession.RunOptions, - ): Promise; - - /** - * Run a single eval step with the given inputs and options using the eval model. - * - * @param feeds - Representation of the model input. - * @param fetches - Representation of the model output. - * detail. - * @param options - Optional. A set of options that controls the behavior of model eval step. - * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding - values. - */ - runEvalStep( - feeds: InferenceSession.FeedsType, - fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions, - ): Promise; - - // #endregion - - // #region copy parameters - - /** - * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of - * the parameters) elements of all the parameters in the training state. - * - * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true. - */ - getParametersSize(trainableOnly: boolean): Promise; - - /** - * Copies parameter values from the given buffer to the training state. Currently, only supporting models with - * parameters of type Float32. - * - * @param buffer - A Uint8Array representation of Float32 parameters. - * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. - */ - loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; - - /** - * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. - * Currently, only supporting models with parameters of type Float32. - * - * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters - * for which requires_grad is set to true. Default value is true. - * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters. - */ - getContiguousParameters(trainableOnly: boolean): Promise; - // #endregion - - // #region release() - - /** - * Release the inference session and the underlying resources. - */ - release(): Promise; - // #endregion - - // #region metadata - - /** - * Get input names of the loaded training model. - */ - readonly trainingInputNames: readonly string[]; - - /** - * Get output names of the loaded training model. - */ - readonly trainingOutputNames: readonly string[]; - - /** - * Get input names of the loaded eval model. Is an empty array if no eval model is loaded. - */ - readonly evalInputNames: readonly string[]; - - /** - * Get output names of the loaded eval model. Is an empty array if no eval model is loaded. - */ - readonly evalOutputNames: readonly string[]; - - // #endregion -} - -/** - * Represents the optional parameters that can be passed into the TrainingSessionFactory. - */ -export interface TrainingSessionCreateOptions { - /** - * URI or buffer for a .ckpt file that contains the checkpoint for the training model. - */ - checkpointState: TrainingSession.UriOrBuffer; - /** - * URI or buffer for the .onnx training file. - */ - trainModel: TrainingSession.UriOrBuffer; - /** - * Optional. URI or buffer for the .onnx optimizer model file. - */ - optimizerModel?: TrainingSession.UriOrBuffer; - /** - * Optional. URI or buffer for the .onnx eval model file. - */ - evalModel?: TrainingSession.UriOrBuffer; -} - -/** - * Defines method overload possibilities for creating a TrainingSession. - */ -export interface TrainingSessionFactory { - // #region create() - - /** - * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions - * - * @param trainingOptions specify models and checkpoints to load into the Training Session - * @param sessionOptions specify configuration for training session behavior - * - * @returns Promise that resolves to a TrainingSession object - */ - create( - trainingOptions: TrainingSessionCreateOptions, - sessionOptions?: InferenceSession.SessionOptions, - ): Promise; - - // #endregion -} - -// eslint-disable-next-line @typescript-eslint/naming-convention -export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;