Merge pull request #35 from Honry/add-perf-info
[Whisper Base] Add two metrics to measure non real-time performance
fdwr authored Aug 29, 2024
2 parents c49c5e6 + 7c610e3 · commit 72f9f52
Showing 2 changed files with 54 additions and 40 deletions.
85 changes: 48 additions & 37 deletions demos/whisper-base/main.js
@@ -64,7 +64,7 @@ let resultShow;
 let latency;
 let audioProcessing;
 let copy;
-let audio_src;
+let audioSrc;
 let outputText;
 let container;
 let audioMotion;
@@ -79,7 +79,7 @@ const SpeechStates = {
 };
 let speechState = SpeechStates.UNINITIALIZED;
 
-let mask_4d = true; // use 4D mask input for decoder models
+let mask4d = true; // use 4D mask input for decoder models
 let streamingNode = null;
 let sourceNode = null;
 let audioChunks = []; // member {isSubChunk: boolean, data: Float32Array}
@@ -105,6 +105,9 @@ let subAudioChunkLength = 0; // length of a sub audio chunk
 let subText = "";
 let speechToText = "";
 
+let timeToFirstToken = 0; // TTFT
+let numTokens = 0; // number of tokens
+
 const blacklistTags = [
   "[inaudible]",
   " [inaudible]",
@@ -148,7 +151,7 @@ function updateConfig() {
       accumulateSubChunks = pair[1].toLowerCase() === 'true';
     }
     if (pair[0] == 'mask_4d') {
-      mask_4d = pair[1].toLowerCase() === 'true';
+      mask4d = pair[1].toLowerCase() === 'true';
     }
   }
 }
@@ -185,9 +188,13 @@ async function process_audio(audio, starttime, idx, pos) {
       // run inference for 30 sec
       const xa = audio.slice(idx, idx + kSteps);
       const ret = await whisper.run(xa);
+      if (idx == 0) {
+        timeToFirstToken = ret.time_to_first_token;
+      }
+      numTokens += ret.num_tokens;
       // append results to outputText
-      outputText.innerText += ret;
-      logUser(ret);
+      outputText.innerText += ret.sentence;
+      logUser(ret.sentence);
       // outputText.scrollTop = outputText.scrollHeight;
 
       await process_audio(audio, starttime, idx + kSteps, pos + kMaxAudioLengthInSec);
@@ -196,19 +203,23 @@ async function process_audio(audio, starttime, idx, pos) {
     }
   } else {
     // done with audio buffer
-    const processing_time = (performance.now() - starttime) / 1000;
+    const processingTime = (performance.now() - starttime) / 1000;
     const total = audio.length / kSampleRate;
+    const tokensPerSecond = (numTokens - 1) / (processingTime - timeToFirstToken / 1000);
+    numTokens = 0;
     resultShow.setAttribute('class', 'show');
     progress.style.width = "100%";
 
     if(getMode()) {
-      latency.innerText = `100.0%, ${(
-        total / processing_time
-      ).toFixed(1)} x realtime`;
+      latency.innerText = `100.0%, ${
+        (total / processingTime).toFixed(1)
+      } x realtime, time to first token: ${
+        timeToFirstToken.toFixed(1)
+      }ms, ${tokensPerSecond.toFixed(1)} tokens/s`;
       log(
         `${
           latency.innerText
-        }, total ${processing_time.toFixed(
+        }, total ${processingTime.toFixed(
           1
         )}s processing time for ${total.toFixed(1)}s audio`
       );
@@ -226,7 +237,7 @@ async function process_audio(audio, starttime, idx, pos) {
 // transcribe audio source
 async function transcribe_file() {
   resultShow.setAttribute('class', 'show');
-  if (audio_src.src == "") {
+  if (audioSrc.src == "") {
     logError("Error · No audio input, please record the audio");
     ready();
     return;
@@ -236,7 +247,7 @@ async function transcribe_file() {
   log("Starting transcription ...");
   audioProcessing.setAttribute('class', 'show');
   try {
-    const buffer = await (await fetch(audio_src.src)).arrayBuffer();
+    const buffer = await (await fetch(audioSrc.src)).arrayBuffer();
     const audioBuffer = await context.decodeAudioData(buffer);
     const offlineContext = new OfflineAudioContext(
       audioBuffer.numberOfChannels,
@@ -265,10 +276,10 @@ async function startRecord() {
   speech.disabled = true;
   stream = null;
   outputText.innerText = '';
-  if (!audio_src.paused) {
-    audio_src.pause();
+  if (!audioSrc.paused) {
+    audioSrc.pause();
   }
-  audio_src.src == "";
+  audioSrc.src == "";
 
   resultShow.setAttribute('class', '');
   if (mediaRecorder === undefined) {
@@ -309,9 +320,9 @@ async function startRecord() {
         1
       )}s audio`
     );
-    audio_src.src = window.URL.createObjectURL(blob);
+    audioSrc.src = window.URL.createObjectURL(blob);
     initAudioMotion();
-    audio_src.play();
+    audioSrc.play();
     await transcribe_file();
   };
   mediaRecorder.start(kIntervalAudio_ms);
@@ -334,10 +345,10 @@ async function startSpeech() {
   fileUpload.disabled = true;
   record.disabled = true;
   speech.disabled = false;
-  if (!audio_src.paused) {
-    audio_src.pause();
+  if (!audioSrc.paused) {
+    audioSrc.pause();
   }
-  audio_src.src == "";
+  audioSrc.src == "";
   resultShow.setAttribute('class', '');
   speechState = SpeechStates.PROCESSING;
   await captureAudioStream();
@@ -510,17 +521,17 @@ async function processAudioBuffer() {
   if (processBufferLength > 0.16) {
     const start = performance.now();
     const ret = await whisper.run(processBuffer);
-    const processing_time = (performance.now() - start) / 1000;
+    const processingTime = (performance.now() - start) / 1000;
     resultShow.setAttribute('class', 'show');
 
     if(getMode()) {
       latency.innerText = `${(
-        processBufferLength / processing_time
+        processBufferLength / processingTime
       ).toFixed(1)} x realtime`;
       log(
         `${
           latency.innerText
-        }, ${processBufferLength}s audio processing time: ${processing_time.toFixed(2)}s`
+        }, ${processBufferLength}s audio processing time: ${processingTime.toFixed(2)}s`
       );
     } else {
       latency.innerText = `realtime`;
@@ -530,20 +541,20 @@ async function processAudioBuffer() {
     }
 
     // ignore slient, inaudible audio output, i.e. '[BLANK_AUDIO]'
-    if (!blacklistTags.includes(ret)) {
+    if (!blacklistTags.includes(ret.sentence)) {
       if (subAudioChunks.length > 0) {
         if (accumulateSubChunks) {
-          subText = ret;
+          subText = ret.sentence;
         } else {
-          subText += ret;
+          subText += ret.sentence;
         }
         outputText.innerText = speechToText + subText;
       } else {
         subText = '';
-        speechToText += ret;
+        speechToText += ret.sentence;
         outputText.innerText = speechToText;
       }
-      logUser(ret);
+      logUser(ret.sentence);
       // outputText.scrollTop = outputText.scrollHeight;
     }
   } else {
@@ -588,7 +599,7 @@ const initAudioMotion = () => {
   audioMotion = new AudioMotionAnalyzer(
     container,
     {
-      source: audio_src
+      source: audioSrc
     }
   );
   audioMotion.setOptions(options);
@@ -643,19 +654,19 @@ const main = async () => {
     record.disabled = true;
     speech.disabled = true;
     subText = "";
-    if (!audio_src.paused) {
-      audio_src.pause();
+    if (!audioSrc.paused) {
+      audioSrc.pause();
     }
-    audio_src.src == "";
+    audioSrc.src == "";
     let target = evt.target || window.event.src,
       files = target.files;
     if(files && files.length > 0) {
-      audio_src.src = URL.createObjectURL(files[0]);
+      audioSrc.src = URL.createObjectURL(files[0]);
       initAudioMotion();
-      audio_src.play();
+      audioSrc.play();
       await transcribe_file();
     } else {
-      audio_src.src = '';
+      audioSrc.src = '';
     }
   };
 
@@ -675,7 +686,7 @@ const main = async () => {
   const whisper_url = location.href.includes("github.io")
     ? "https://huggingface.co/microsoft/whisper-base-webnn/resolve/main/"
     : "./models/";
-  whisper = new Whisper(whisper_url, provider, deviceType, dataType, mask_4d);
+  whisper = new Whisper(whisper_url, provider, deviceType, dataType, mask4d);
   await whisper.create_whisper_processor();
   await whisper.create_whisper_tokenizer();
   await whisper.create_ort_sessions();
@@ -702,7 +713,7 @@ const main = async () => {
 const ui = async () => {
   device = document.getElementById('device');
   badge = document.getElementById('badge');
-  audio_src = document.querySelector("audio");
+  audioSrc = document.querySelector("audio");
   labelFileUpload = document.getElementById("label-file-upload");
   fileUpload = document.getElementById("file-upload");
   record = document.getElementById("record");
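Note on the throughput math above (not part of the commit): tokens/s subtracts the time to first token (TTFT) from the total processing time, so the figure reflects steady-state decoding after the encoder run and the first decoder step. A minimal sketch of that arithmetic, assuming whisper.run() resolves to { sentence, time_to_first_token, num_tokens } with time_to_first_token in milliseconds, as the diff indicates:

async function reportMetrics(whisper, audio) {
  const start = performance.now();
  const ret = await whisper.run(audio);
  const processingTime = (performance.now() - start) / 1000; // seconds
  // (num_tokens - 1) tokens are produced in the steady-state window,
  // i.e. total processing time minus the first token's latency.
  const tokensPerSecond =
    (ret.num_tokens - 1) / (processingTime - ret.time_to_first_token / 1000);
  console.log(`time to first token: ${ret.time_to_first_token.toFixed(1)}ms, ` +
              `${tokensPerSecond.toFixed(1)} tokens/s`);
}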
9 changes: 6 additions & 3 deletions demos/whisper-base/whisper.js
@@ -204,7 +204,7 @@ export class Whisper {
     // -----------------------------------FEATURE EXTRACTION-----------------------------------------
     // const audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac', 16000);
     // const audio = await read_audio(audio_data, sampling_rate);
-    // let start = performance.now();
+    let start = performance.now();
     const { input_features } = await this.processor(audio_data);
     // -----------------------------------ENCODER INFERENCE-----------------------------------------
     // run encoder to get output
@@ -270,7 +270,7 @@ export class Whisper {
       "input_ids": new ort.Tensor("int32", new Int32Array(tokens), [1, 4]),
       "attention_mask": attention_mask,
       "encoder_hidden_states": last_hidden_state,
-    };
+    };
     // console.log(`Non-KV cache decoder input preparation time: ${(performance.now() - start).toFixed(2)}ms`);
     // start = performance.now();
     // run the first inference which generates SA and CA KV cache
@@ -289,6 +289,7 @@ export class Whisper {
 
     // add token to final buffer
     tokens = tokens.concat(new_token);
+    const time_to_first_token = (performance.now() - start); // TTFT
 
     // for 2+ inference, we don't need encoder hidden states as input to OV model
     delete decoder_input.encoder_hidden_states;
@@ -405,7 +406,9 @@ export class Whisper {
     const sentence = await this.tokenizer.decode(tokens, {
       skip_special_tokens: true,
     });
+
+    const num_tokens = tokens.length - 4;
     // log(`Post-processing time: ${(performance.now() - start).toFixed(2)}ms`);
-    return sentence;
+    return { sentence, time_to_first_token, num_tokens };
   }
 }
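For context (not part of the commit): with these changes, run() times the span from feature extraction through the first decoder pass to get TTFT, and counts output tokens as everything past the four seed tokens in the initial input_ids. A rough sketch of that control flow, where the stage helpers are hypothetical stand-ins for the real processor, encoder, and decoder sessions:

async function run(audio_data) {
  const start = performance.now();
  const features = await extractFeatures(audio_data);    // hypothetical: this.processor
  const hidden = await encode(features);                 // hypothetical: encoder session
  let tokens = await decodeFirstToken(hidden);           // first decoder pass, builds KV cache
  const time_to_first_token = performance.now() - start; // TTFT, in ms
  tokens = await decodeRemaining(tokens, hidden);        // KV-cache decode loop
  const sentence = await detokenize(tokens);             // hypothetical: this.tokenizer.decode
  const num_tokens = tokens.length - 4;                  // exclude the 4 seed tokens
  return { sentence, time_to_first_token, num_tokens };
}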
