Skip to content

Commit

Permalink
added package.json modification for training + minimized training art…
Browse files Browse the repository at this point in the history
…ifact
  • Loading branch information
carzh committed Aug 31, 2023
1 parent 94dfa16 commit d3c8f7e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 13 deletions.
4 changes: 3 additions & 1 deletion js/web/lib/backend-wasm-with-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@

import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
/* eslint-disable @typescript-eslint/no-unused-vars */
async createTrainingSessionHandler(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
}
/* eslint-enable @typescript-eslint/no-unused-vars */
}

export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
10 changes: 1 addition & 9 deletions js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Backend, env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common';
import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common';
import {cpus} from 'os';

import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
Expand Down Expand Up @@ -48,14 +48,6 @@ export class OnnxruntimeWebAssemblyBackend implements Backend {
await handler.loadModel(pathOrBuffer, options);
return Promise.resolve(handler);
}
/* eslint-disable @typescript-eslint/no-unused-vars */
async createTrainingSessionHandler(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Training not supported on Nodejs');
}
/* eslint-enable @typescript-eslint/no-unused-vars */
}

export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
2 changes: 1 addition & 1 deletion js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// So we import code inside the if-clause to allow terser remove the code safely.

export * from 'onnxruntime-common';
import {registerBackend, env, Backend} from 'onnxruntime-common';
import {registerBackend, env} from 'onnxruntime-common';
import {version} from './version';

if (!BUILD_DEFS.DISABLE_WEBGL) {
Expand Down
4 changes: 4 additions & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@
"./webgpu": {
"types": "./types.d.ts",
"default": "./dist/ort.webgpu.min.js"
},
"./training": {
"types": "./types.d.ts",
"default": "./dist/ort-web.training.wasm.min.js"
}
},
"types": "./types.d.ts",
Expand Down
26 changes: 26 additions & 0 deletions js/web/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const ROOT_FOLDER = path.join(__dirname, '..');
const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding');
const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js');
const TRAINING_WASM_BINDING_SIMD_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.js');
const TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.min.js');
const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js');
const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js');
const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js');
Expand Down Expand Up @@ -93,6 +94,31 @@ if (WASM) {
*/
`;

npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"...');
try {
const terser = spawnSync(
'npx',
[
'terser', TRAINING_WASM_BINDING_SIMD_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false',
'--mangle', 'reserved=[_scriptDir]', '--module'
],
{shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER});
if (terser.status !== 0) {
console.error(terser.error);
process.exit(terser.status === null ? undefined : terser.status);
}

fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH, terser.stdout);
fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`);

validateFile(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH);
validateFile(TRAINING_WASM_BINDING_SIMD_JS_PATH);
} catch (e) {
npmlog.error('Build', `Failed to run terser on ort-training-wasm-simd.js. ERR: ${e}`);
throw e;
}
npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"... DONE');

npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"...');
try {
const terser = spawnSync(
Expand Down
5 changes: 3 additions & 2 deletions js/web/webpack.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ function buildConfig({filename, format, target, mode, devtool, build_defs}) {
config.resolve.alias['./binding/ort-wasm-threaded.js'] = './binding/ort-wasm-threaded.min.js';
config.resolve.alias['./binding/ort-wasm-threaded-simd.jsep.js'] = './binding/ort-wasm-threaded-simd.jsep.min.js';
config.resolve.alias['./binding/ort-wasm-threaded.worker.js'] = './binding/ort-wasm-threaded.min.worker.js';
config.resolve.alias['./binding/ort-training-wasm-simd.js'] = './binding/ort-training-wasm-simd.min.js';

const options = defaultTerserPluginOptions(target);
options.terserOptions.format.preamble = COPYRIGHT_BANNER;
Expand Down Expand Up @@ -291,8 +292,8 @@ module.exports = () => {
// ort-web.es5.min.js
buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}),

// ort.wasm.min.js
buildOrtConfig({
// ort-web.training.wasm.min.js
buildOrtWebConfig({
suffix: '.training.wasm.min',
build_defs: {
...DEFAULT_BUILD_DEFS,
Expand Down

0 comments on commit d3c8f7e

Please sign in to comment.