From fa9f54576acdd5a9cf07e34d4e46ddea9595ded3 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 22 Nov 2023 14:25:27 -0800 Subject: [PATCH] added suggestion --- js/common/lib/backend.ts | 2 +- js/web/lib/wasm/wasm-training-core-impl.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index fd2e8bb74bbf5..67d283b694955 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -50,7 +50,7 @@ export interface TrainingSessionHandler extends SessionHandler { options: InferenceSession.RunOptions): Promise; getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; getContiguousParameters(trainableOnly: boolean): Promise; } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 251bc612ab085..c0a4235113148 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -406,8 +406,8 @@ export const loadParametersBuffer = const locationAsString = 'cpu'; // allocates & copies JavaScript buffer to WASM heap - const bufferCount = getParametersSize(trainingSessionId, trainableOnly); const bufferByteLength = buffer.length; + const bufferCount = bufferByteLength / 4; const bufferOffset = wasm._malloc(bufferByteLength); wasm.HEAPU8.set(buffer, bufferOffset);