From b08de6feb990f52c1e61e42b539e09a86bd92c91 Mon Sep 17 00:00:00 2001
From: Miao Bin
Date: Thu, 28 Nov 2024 18:57:10 +0800
Subject: [PATCH 1/2] Support iobinding for Whisper Base demo

---
 demos/whisper-base/main.js    |  6 ++-
 demos/whisper-base/whisper.js | 86 +++++++++++++++++++++++++----------
 2 files changed, 68 insertions(+), 24 deletions(-)

diff --git a/demos/whisper-base/main.js b/demos/whisper-base/main.js
index a89ae49..f93a64a 100644
--- a/demos/whisper-base/main.js
+++ b/demos/whisper-base/main.js
@@ -88,6 +88,7 @@ const SpeechStates = {
 let speechState = SpeechStates.UNINITIALIZED;
 
 let mask4d = true; // use 4D mask input for decoder models
+let iobinding = true;
 let streamingNode = null;
 let sourceNode = null;
 let audioChunks = []; // member {isSubChunk: boolean, data: Float32Array}
@@ -161,6 +162,9 @@ function updateConfig() {
     if (pair[0] == "mask_4d") {
       mask4d = pair[1].toLowerCase() === "true";
     }
+    if (pair[0] == 'iobinding') {
+      iobinding = pair[1].toLowerCase() === 'true';
+    }
   }
 }
 
@@ -664,7 +668,7 @@ const main = async () => {
   const whisper_url = location.href.includes("github.io")
     ? "https://huggingface.co/microsoft/whisper-base-webnn/resolve/main/"
    : "./models/";
-  whisper = new Whisper(whisper_url, provider, deviceType, dataType, mask4d);
+  whisper = new Whisper(whisper_url, provider, deviceType, dataType, mask4d, iobinding);
   await whisper.create_whisper_processor();
   await whisper.create_whisper_tokenizer();
   await whisper.create_ort_sessions();
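Reviewer note: with the main.js change above, io binding can be toggled from the query string and defaults to on. A minimal sketch of the resulting wiring, assuming the Whisper class from whisper.js is in scope; the argument values below are illustrative, only the parameter order comes from the patch:

// Read the new flag from the URL (the demo parses it manually in updateConfig())
// and pass it through as the new last constructor argument.
const raw = new URLSearchParams(location.search).get("iobinding");
const iobinding = raw === null ? true : raw.toLowerCase() === "true";
const whisper = new Whisper("./models/", "webnn", "gpu", "float16", /* mask_4d */ true, iobinding);
// e.g. index.html?iobinding=false keeps the CPU-side cache_update() path instead.
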
diff --git a/demos/whisper-base/whisper.js b/demos/whisper-base/whisper.js
index 3583059..f5a45cc 100644
--- a/demos/whisper-base/whisper.js
+++ b/demos/whisper-base/whisper.js
@@ -40,12 +40,13 @@ if (
 
 // wrapper around onnxruntime and model
 export class Whisper {
-  constructor(url, provider, deviceType = "gpu", dataType, mask_4d = true) {
+  constructor(url, provider, deviceType = "gpu", dataType, mask_4d = true, iobinding) {
     this.url = url;
     this.provider = provider;
     this.deviceType = deviceType;
     this.dataType = dataType;
     this.mask_4d = mask_4d;
+    this.iobinding = iobinding && deviceType == 'gpu';
     ort.env.wasm.simd = true;
 
     this.models = {
@@ -101,6 +102,19 @@ export class Whisper {
           deviceType: this.deviceType,
         },
       ],
+      preferredOutputLocation: this.iobinding ? (() => {
+        const pairs = {};
+        pairs['last_hidden_state'] = "ml-tensor";
+        for (let i = 0; i < 6; i++) {
+          pairs[`padded_present_key_values.${i}.decoder.key`] = "ml-tensor";
+          pairs[`padded_present_key_values.${i}.decoder.value`] = "ml-tensor";
+          pairs[`present_key_values.${i}.encoder.key`] = "ml-tensor";
+          pairs[`present_key_values.${i}.encoder.value`] = "ml-tensor";
+          pairs[`updated_present_key_values.${i}.decoder.key`] = "ml-tensor";
+          pairs[`updated_present_key_values.${i}.decoder.value`] = "ml-tensor";
+        }
+        return pairs;
+      })() : undefined,
       logSeverityLevel: 0,
     };
 
@@ -111,8 +125,13 @@
     let url = this.url + this.models[name]["url"];
     if (this.dataType == "float16") {
      url = url.replace(".onnx", "_fp16_layernorm_gelu.onnx");
-      if (name.includes("decoder") && this.mask_4d) {
-        url = url.replace(".onnx", "_4dmask.onnx");
+      if (name.includes("decoder")) {
+        if (this.mask_4d) {
+          url = url.replace(".onnx", "_4dmask.onnx");
+        }
+        if (this.iobinding) {
+          url = url.replace(".onnx", "_iobinding.onnx");
+        }
       }
       log(`Loading ${this.models[name]["title"]} · ${this.dataType} · ${this.models[name]["fp16size"]}`);
     } else {
@@ -120,7 +139,8 @@
       log(`Loading ${this.models[name]["title"]} · ${this.dataType} · ${this.models[name]["fp32size"]}`);
     }
 
-    const modelBuffer = await getModelOPFS(`${this.deviceType}_${name}_${this.dataType}`, url, false);
+    const modelBuffer = await getModelOPFS(`${this.iobinding}_${this.deviceType}_${name}_${this.dataType}`,
+      url, false);
     log(`${this.models[name]["title"]} loaded`);
 
     log(`Creating session for ${this.models[name]["title"]}`);
@@ -271,15 +291,24 @@
     }
 
     // modify the self attention kv cache in place
-    cache_update(
-      decoder_input,
-      decoder_output,
-      0,
-      this.max_sequence_length,
-      this.num_init_tokens,
-      this.num_init_tokens,
-      this.dataType,
-    );
+    if (this.iobinding) {
+      for (let i = 0; i < 6; i++) {
+        decoder_input[`past_key_values.${i}.decoder.key`] =
+          decoder_output[`padded_present_key_values.${i}.decoder.key`];
+        decoder_input[`past_key_values.${i}.decoder.value`] =
+          decoder_output[`padded_present_key_values.${i}.decoder.value`];
+      }
+    } else {
+      cache_update(
+        decoder_input,
+        decoder_output,
+        0,
+        this.max_sequence_length,
+        this.num_init_tokens,
+        this.num_init_tokens,
+        this.dataType
+      );
+    }
     const position_ids = new Int32Array(decoder_input["position_ids"].cpuData.buffer);
 
     // run complete inference for every item in dataset
@@ -319,16 +348,27 @@
         position_ids[0],
         this.mask_4d,
       );
-      // modify the kv cache in place
-      cache_update(
-        decoder_input,
-        decoder_cached_output,
-        i,
-        this.max_sequence_length,
-        this.num_init_tokens,
-        position_ids[0],
-        this.dataType,
-      );
+
+      // modify the kv cache in place
+      if (this.iobinding) {
+        for (let i = 0; i < 6; i++) {
+          decoder_input[`past_key_values.${i}.decoder.key`] =
+            decoder_cached_output[`updated_present_key_values.${i}.decoder.key`];
+          decoder_input[`past_key_values.${i}.decoder.value`] =
+            decoder_cached_output[`updated_present_key_values.${i}.decoder.value`];
+        }
+      } else {
+        cache_update(
+          decoder_input,
+          decoder_cached_output,
+          i,
+          this.max_sequence_length,
+          this.num_init_tokens,
+          position_ids[0],
+          this.dataType
+        );
+      }
+
     }
 
     // add token to sentence decode time
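Reviewer note: taken together, the whisper.js changes above keep the decoder KV cache on the device. The session is created with preferredOutputLocation set to "ml-tensor" for the present/updated key-value outputs, and each step's outputs are fed back as the next step's past_key_values inputs instead of being copied through cache_update() on the CPU. A minimal sketch of that loop, assuming an onnxruntime-web InferenceSession for the cached decoder created with the WebNN execution provider and the options shown above; the six-layer count and tensor names follow the whisper-base demo:

// One decode step with io binding: run the cached decoder, then reuse its
// device-resident ("ml-tensor") outputs as the next step's KV-cache inputs,
// so nothing is read back into CPU memory between tokens.
async function runDecoderCachedStep(session, feeds) {
  const outputs = await session.run(feeds);
  for (let i = 0; i < 6; i++) {
    feeds[`past_key_values.${i}.decoder.key`] =
      outputs[`updated_present_key_values.${i}.decoder.key`];
    feeds[`past_key_values.${i}.decoder.value`] =
      outputs[`updated_present_key_values.${i}.decoder.value`];
  }
  return outputs;
}
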
From 820fdcd44e270dad59916519a8c3b33191b9a8eb Mon Sep 17 00:00:00 2001
From: Miao Bin
Date: Thu, 5 Dec 2024 17:14:38 +0800
Subject: [PATCH 2/2] Optimized the code

---
 demos/whisper-base/main.js    | 26 ++++++++++++++------------
 demos/whisper-base/whisper.js | 29 +++++++++++++++++++++--------
 2 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/demos/whisper-base/main.js b/demos/whisper-base/main.js
index f93a64a..d6dda48 100644
--- a/demos/whisper-base/main.js
+++ b/demos/whisper-base/main.js
@@ -88,7 +88,7 @@ const SpeechStates = {
 let speechState = SpeechStates.UNINITIALIZED;
 
 let mask4d = true; // use 4D mask input for decoder models
-let iobinding = true;
+let ioBinding = true;
 let streamingNode = null;
 let sourceNode = null;
 let audioChunks = []; // member {isSubChunk: boolean, data: Float32Array}
@@ -138,32 +138,34 @@ function updateConfig() {
   let vars = query.split("&");
   for (let i = 0; i < vars.length; i++) {
     let pair = vars[i].split("=");
+    const key = pair[0].toLowerCase();
+    const value = pair[1].toLowerCase();
     if (pair[0] == "provider" && providers.includes(pair[1])) {
       provider = pair[1];
     }
-    if (pair[0].toLowerCase() == "devicetype" && deviceTypes.includes(pair[1])) {
+    if (key == "devicetype" && deviceTypes.includes(pair[1])) {
       deviceType = pair[1];
     }
-    if (pair[0].toLowerCase() == "datatype" && dataTypes.includes(pair[1])) {
+    if (key == "datatype" && dataTypes.includes(pair[1])) {
       dataType = pair[1];
     }
-    if (pair[0].toLowerCase() == "maxchunklength") {
+    if (key == "maxchunklength") {
       maxChunkLength = parseFloat(pair[1]);
     }
-    if (pair[0].toLowerCase() == "chunklength") {
+    if (key == "chunklength") {
       chunkLength = parseFloat(pair[1]);
     }
-    if (pair[0].toLowerCase() == "maxaudiolength") {
+    if (key == "maxaudiolength") {
       maxAudioLength = Math.min(parseInt(pair[1]), kMaxAudioLengthInSec);
     }
-    if (pair[0].toLowerCase() == "accumulatesubchunks") {
-      accumulateSubChunks = pair[1].toLowerCase() === "true";
+    if (key == "accumulatesubchunks") {
+      accumulateSubChunks = value === "true";
     }
     if (pair[0] == "mask_4d") {
-      mask4d = pair[1].toLowerCase() === "true";
+      mask4d = value === "true";
     }
-    if (pair[0] == 'iobinding') {
-      iobinding = pair[1].toLowerCase() === 'true';
+    if (pair[0] == 'ioBinding') {
+      ioBinding = value === 'true';
     }
   }
 }
@@ -668,7 +670,7 @@ const main = async () => {
   const whisper_url = location.href.includes("github.io")
     ? "https://huggingface.co/microsoft/whisper-base-webnn/resolve/main/"
     : "./models/";
-  whisper = new Whisper(whisper_url, provider, deviceType, dataType, mask4d, iobinding);
+  whisper = new Whisper(whisper_url, provider, deviceType, dataType, mask4d, ioBinding);
   await whisper.create_whisper_processor();
   await whisper.create_whisper_tokenizer();
   await whisper.create_ort_sessions();
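Reviewer note: the updateConfig() change above lower-cases each key/value pair once and compares against those, instead of calling toLowerCase() inside every condition. A sketch of the same idea as a standalone helper; the function name is illustrative, the demo keeps these checks inline:

// Parse one boolean flag from a query string, case-insensitively; falls back
// to the given default when the flag is absent.
function parseBooleanFlag(search, name, defaultValue) {
  for (const part of search.replace(/^\?/, "").split("&")) {
    const [rawKey, rawValue = ""] = part.split("=");
    if (rawKey.toLowerCase() === name.toLowerCase()) {
      return rawValue.toLowerCase() === "true";
    }
  }
  return defaultValue;
}
// e.g. const ioBinding = parseBooleanFlag(location.search, "iobinding", true);
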
diff --git a/demos/whisper-base/whisper.js b/demos/whisper-base/whisper.js
index f5a45cc..6d72778 100644
--- a/demos/whisper-base/whisper.js
+++ b/demos/whisper-base/whisper.js
@@ -40,13 +40,13 @@ if (
 
 // wrapper around onnxruntime and model
 export class Whisper {
-  constructor(url, provider, deviceType = "gpu", dataType, mask_4d = true, iobinding) {
+  constructor(url, provider, deviceType = "gpu", dataType, mask_4d = true, ioBinding) {
     this.url = url;
     this.provider = provider;
     this.deviceType = deviceType;
     this.dataType = dataType;
     this.mask_4d = mask_4d;
-    this.iobinding = iobinding && deviceType == 'gpu';
+    this.ioBinding = ioBinding && deviceType == 'gpu';
     ort.env.wasm.simd = true;
 
     this.models = {
@@ -102,8 +102,8 @@ export class Whisper {
           deviceType: this.deviceType,
         },
       ],
-      preferredOutputLocation: this.iobinding ? (() => {
-        const pairs = {};
+      preferredOutputLocation: this.ioBinding ? (() => {
+        let pairs = {};
         pairs['last_hidden_state'] = "ml-tensor";
         for (let i = 0; i < 6; i++) {
           pairs[`padded_present_key_values.${i}.decoder.key`] = "ml-tensor";
@@ -129,7 +129,7 @@
         if (this.mask_4d) {
           url = url.replace(".onnx", "_4dmask.onnx");
         }
-        if (this.iobinding) {
+        if (this.ioBinding) {
           url = url.replace(".onnx", "_iobinding.onnx");
         }
       }
@@ -139,7 +139,7 @@
       log(`Loading ${this.models[name]["title"]} · ${this.dataType} · ${this.models[name]["fp32size"]}`);
     }
 
-    const modelBuffer = await getModelOPFS(`${this.iobinding}_${this.deviceType}_${name}_${this.dataType}`,
+    const modelBuffer = await getModelOPFS(`${this.ioBinding}_${this.deviceType}_${name}_${this.dataType}`,
       url, false);
     log(`${this.models[name]["title"]} loaded`);
 
@@ -291,7 +291,7 @@
     }
 
     // modify the self attention kv cache in place
-    if (this.iobinding) {
+    if (this.ioBinding) {
       for (let i = 0; i < 6; i++) {
         decoder_input[`past_key_values.${i}.decoder.key`] =
           decoder_output[`padded_present_key_values.${i}.decoder.key`];
@@ -350,8 +350,12 @@
       );
 
      // modify the kv cache in place
-      if (this.iobinding) {
+      if (this.ioBinding) {
         for (let i = 0; i < 6; i++) {
+          // dispose previous tensors
+          decoder_input[`past_key_values.${i}.decoder.key`].mlTensor.destroy();
+          decoder_input[`past_key_values.${i}.decoder.value`].mlTensor.destroy();
+          // update the kv cache
           decoder_input[`past_key_values.${i}.decoder.key`] =
             decoder_cached_output[`updated_present_key_values.${i}.decoder.key`];
           decoder_input[`past_key_values.${i}.decoder.value`] =
@@ -371,6 +375,15 @@
 
     }
 
+    if (this.ioBinding) {
+      for (let i = 0; i < 6; i++) {
+        decoder_output[`padded_present_key_values.${i}.decoder.key`].mlTensor.destroy();
+        decoder_output[`padded_present_key_values.${i}.decoder.value`].mlTensor.destroy();
+        decoder_input[`past_key_values.${i}.encoder.key`].mlTensor.destroy();
+        decoder_input[`past_key_values.${i}.encoder.value`].mlTensor.destroy();
+      }
+    }
+
     // add token to sentence decode time
     const sentence = await this.tokenizer.decode(tokens, {
       skip_special_tokens: true,
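Reviewer note: the second patch also releases device tensors it no longer needs: each step destroys the previous KV-cache entries before replacing them, and the remaining padded/encoder cache tensors are destroyed once decoding ends. A sketch of the same cleanup as a helper, assuming, as the patch does, that an "ml-tensor" located output exposes the underlying WebNN MLTensor as .mlTensor with a destroy() method; the tensor names and six-layer count follow the demo:

// Free the device memory behind the decoder KV cache after a transcription
// finishes, mirroring the destroy() calls added in the hunks above.
function releaseKvCache(feeds) {
  for (let i = 0; i < 6; i++) {
    for (const kind of ["key", "value"]) {
      const tensor = feeds[`past_key_values.${i}.decoder.${kind}`];
      if (tensor && tensor.location === "ml-tensor") {
        tensor.mlTensor.destroy();
      }
    }
  }
}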