Merge pull request #6 from argmaxinc/atila/wer_avg_method
Adopt per-word averaging convention for WER
atiorh authored Apr 2, 2024
2 parents 3ca64f9 + b5c1e50 commit 5cfec57
Showing 7 changed files with 189 additions and 80 deletions.
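The substance of this change is in `scripts/generate_readme.py`: the README tables previously reported WER as the mean of the per-sample WER values stored in each eval JSON, and now pool all references and predictions through the Hugging Face `evaluate` WER metric, so errors are weighted by reference word count. A minimal sketch of the difference between the two conventions, using made-up toy records rather than real eval JSONs:

```python
# Sketch only: contrasts per-sample vs. per-word WER averaging.
# The record layout (per-sample "reference", "prediction", "wer") mirrors what
# generate_readme.py reads from the eval JSONs; the values below are invented.
import evaluate

wer_metric = evaluate.load("wer")

results = [
    {"reference": "hello world",
     "prediction": "hello world", "wer": 0.0},
    {"reference": "a much longer reference sentence with ten words in it",
     "prediction": "a much longer reference sentence with ten words in itt", "wer": 0.1},
]

# Old convention: average the per-sample WERs, so every clip counts equally
# regardless of its length.
per_sample_avg = round(
    sum(sample["wer"] for sample in results) / len(results) * 100., 2)

# New convention (this PR): pool all texts and let the metric weight errors
# by reference word count (per-word averaging).
per_word_avg = round(wer_metric.compute(
    references=[sample["reference"] for sample in results],
    predictions=[sample["prediction"] for sample in results],
) * 100., 2)

print(per_sample_avg, per_word_avg)  # 5.0 vs. 8.33 for this toy data
```

Under the new convention, long clips influence the reported WER in proportion to how many words they contain, which matters most for long-form datasets such as `earnings22`.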
93 changes: 52 additions & 41 deletions README.md

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions scripts/generate_model.py
@@ -13,13 +13,16 @@

from argmaxtools import _sdpa
from argmaxtools.utils import get_logger
from argmaxtools import test_utils
from huggingface_hub import HfApi, hf_hub_download

from tests import test_audio_encoder, test_text_decoder
from whisperkit._constants import COMPRESSION_REPO_ID, MODEL_REPO_ID

logger = get_logger(__name__)

test_utils.TEST_MIN_SPEEDUP_VS_CPU = 0.3


def cli():
f""" Generates Whisper models and publishes them to hf.co/{MODEL_REPO_ID} """
@@ -93,6 +96,14 @@ def cli():

logger.info(f"Generating {args.model_version} files")

# FIXME(atiorh): Remove this once distil-whisper-* models are updated
args.disable_token_timestamps = False
if "distil-whisper" in args.model_version:
logger.info(
"Disabling token-level timestamps due to missing alignment_heads in distil-whisper-* models"
)
args.disable_token_timestamps = True

# Generate WhisperTextDecoder
args.test_seq_len = args.text_decoder_max_sequence_length
args.sdpa_implementation = args.text_decoder_sdpa_implementation
140 changes: 107 additions & 33 deletions scripts/generate_readme.py
@@ -3,6 +3,7 @@
# Copyright (C) 2024 Argmax, Inc. All Rights Reserved.
#
import argparse
import evaluate
import json
import os
from collections import defaultdict
@@ -13,40 +14,47 @@

from whisperkit._constants import EVALS_REPO_ID, MODEL_REPO_ID

wer_metric = evaluate.load("wer")

logger = get_logger(__name__)

QOI_KEY = "QoI (%)"
QOI_KEY = "QoI (↑)"
FILE_SIZE_KEY = "File Size (MB)"
WER_KEY = "WER"
COMMIT_KEY = "Commit Hash"
WER_KEY = "WER (↓)"
COMMIT_KEY = "Code Commit"

HF_HUB_DATASET_CARD_YAML_PREFIX = """
---
pretty_name: "WhisperKit ASR Evaluation Results"
viewer: false
library_name: whisperkit
tags:
- whisper
- whisperkit
- coreml
- asr
- quantized
---
# WhisperKit Evaluation Results\n
# WhisperKit Transcription Quality\n
"""

HF_HUB_METRIC_EXPLANATION = """
### Explanation
We believe that rigorously measuring the quality of inference is necessary for developers and
enterprises to make informed decisions when opting to use optimized or compressed variants of
any machine learning model in production. To contextualize `WhisperKit`, we take the following Whisper
implementations and benchmark them using a consistent evaluation harness:
Server-side:
- `WhisperOpenAIAPI`: [OpenAI's Whisper API](https://platform.openai.com/docs/guides/speech-to-text) ($0.36 per hour of audio as of 02/29/24, 25MB file size limit per request)
- `WhisperOpenAIAPI`: [OpenAI's Whisper API](https://platform.openai.com/docs/guides/speech-to-text)
\n($0.36 per hour of audio as of 02/29/24, 25MB file size limit per request)
On-device:
- `WhisperKit`: Argmax's implementation [[Eval Harness]](https://github.com/argmaxinc/whisperkittools/blob/main/whisperkit/pipelines.py#L100) [[Repo]](https://github.com/argmaxinc/WhisperKit)
- `whisper.cpp`: A C++ implementation from ggerganov [[Eval Harness]](https://github.com/argmaxinc/whisperkittools/blob/main/whisperkit/pipelines.py#L212) [[Repo]](https://github.com/ggerganov/whisper.cpp)
- `WhisperMLX`: A Python implementation from Apple MLX [[Eval Harness]](https://github.com/argmaxinc/whisperkittools/blob/main/whisperkit/pipelines.py#L338) [[Repo]](https://github.com/ml-explore/mlx-examples/blob/main/whisper/whisper/transcribe.py)
\n(All on-device implementations are available for free under MIT license as of 03/19/2024)
`WhisperOpenAIAPI` sets the reference and we assume that it is using the equivalent of [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2)
in float16 precision along with additional undisclosed optimizations from OpenAI. In all measurements, we care primarily about per-example no-regressions (quantified as `qoi` below)
@@ -70,16 +78,21 @@
We anticipate that developers who use Whisper (or similar models) in production will have their own Quality Assurance test sets, and [whisperkittools](https://github.com/argmaxinc/whisperkittools) offers
the tooling necessary to run the same measurements on such custom test sets. Please see [Model Evaluation on Custom Dataset](https://github.com/argmaxinc/whisperkittools) for details.
### Why are there so many Whisper versions?
WhisperKit is an SDK for building speech-to-text features in apps across a wide range of Apple devices. We are working towards abstracting away the model versioning from the developer so WhisperKit
"just works" by deploying the highest-quality model version that a particular device can execute. In the interim, we leave the choice to the developer by providing quality and size trade-offs.
### Datasets
- [librispeech](https://huggingface.co/datasets/argmaxinc/librispeech): ~5 hours of short English audio clips, tests short-form transcription quality
- [earnings22](https://huggingface.co/datasets/argmaxinc/earnings22): ~120 hours of English audio clips from earnings calls with various accents, tests long-form transcription quality
### Reproducing Results
Results in this page are generated by our cluster of Apple Silicon Macs. We use them as self-hosted runners on
Github Actions as our CI infrastructure. Due to [security concerns](https://docs.github.com/en/actions/security-guides/security-hardening-for-github-actions#hardening-for-self-hosted-runners),
Benchmark results on this page were automatically generated by [whisperkittools](https://github.com/argmaxinc/whisperkittools) using our cluster of Apple Silicon Macs as self-hosted runners on
Github Actions. We periodically recompute these benchmarks as part of our CI pipeline. Due to [security concerns](https://docs.github.com/en/actions/security-guides/security-hardening-for-github-actions#hardening-for-self-hosted-runners),
we are unable to open up the cluster to the public. However, any Apple Silicon Mac (even with 8GB RAM) can be used to
run identical [evaluation jobs](#evaluation) locally. For reference, our M2 Ultra devices complete a `librispeech` + `openai/whisper-large-v3`
evaluation in under 1 hour regardless of the Whisper implementation. Older Apple Silicon Macs should take less than 1 day to complete the same evaluation.
evaluation in under 1 hour regardless of the Whisper implementation. The oldest Apple Silicon Macs should take less than 1 day to complete the same evaluation.
""" # noqa: E501

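The explanation above quantifies per-example no-regressions as `qoi`. The actual QoI computation lives elsewhere in whisperkittools and is not part of this diff; the snippet below is only a hypothetical illustration of the idea, assuming each eval record carries a per-sample `wer` plus some shared key (a made-up `"file"` field here) for pairing reference and optimized results on the same clip:

```python
# Hypothetical sketch, not the whisperkittools implementation: percentage of
# examples where the optimized model does not regress vs. the reference.
# Assumes records shaped like {"file": ..., "wer": ...}; the "file" key is
# invented here purely for pairing.
def no_regression_percent(reference_results, optimized_results):
    reference_wer = {r["file"]: r["wer"] for r in reference_results}
    paired = [
        (reference_wer[o["file"]], o["wer"])
        for o in optimized_results
        if o["file"] in reference_wer
    ]
    kept = sum(1 for ref_wer, opt_wer in paired if opt_wer <= ref_wer)
    return round(kept / len(paired) * 100., 2)
```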
@@ -97,19 +110,33 @@

# TODO(atiorh): Read remote git file size
REFERENCE_MODEL_FILE_SIZES = {
"WhisperKit/openai_whisper-large-v2": 3100, # MB
"WhisperKit/openai_whisper-large-v2_turbo": 3100, # MB
"WhisperKit/openai_whisper-large-v3": 3100, # MB
"WhisperKit/openai_whisper-large-v3_turbo": 3100, # MB
"WhisperKit/openai_whisper-small": 483, # MB
"WhisperKit/openai_whisper-small.en": 483, # MB
"WhisperKit/openai_whisper-base": 145, # MB
"WhisperKit/openai_whisper-base.en": 145, # MB
"WhisperKit/openai_whisper-tiny": 66, # MB
"WhisperKit/openai_whisper-tiny.en": 66, # MB
"whisper.cpp/openai_whisper-large-v2-q5_0": 1080, # MB
"whisper.cpp/openai_whisper-large-v3-q5_0": 1080, # MB
"WhisperOpenAIAPI/openai_whisper-large-v2": 3100, # MB
"WhisperKit/openai_whisper-large-v2": 3100, # MB
"WhisperKit/openai_whisper-large-v2_turbo": 3100, # MB
"WhisperKit/openai_whisper-large-v3": 3100, # MB
"WhisperKit/openai_whisper-large-v3_turbo": 3100, # MB
"WhisperKit/openai_whisper-small": 483, # MB
"WhisperKit/openai_whisper-small.en": 483, # MB
"WhisperKit/openai_whisper-base": 145, # MB
"WhisperKit/openai_whisper-base.en": 145, # MB
"WhisperKit/openai_whisper-tiny": 66, # MB
"WhisperKit/openai_whisper-tiny.en": 66, # MB
"whisper.cpp/openai_whisper-large-v2-q5_0": 1080, # MB
"whisper.cpp/openai_whisper-large-v3-q5_0": 1080, # MB
"whisper.cpp/openai_whisper-large-v3": 3100, # MB
"whisper.cpp/openai_whisper-large-v2": 3100, # MB
"WhisperOpenAIAPI/openai_whisper-large-v2": 3100, # MB
"WhisperKit/distil-whisper_distil-large-v3": 1510, # MB
"WhisperKit/distil-whisper_distil-large-v3_turbo": 1510, # MB
}

DATASET_CAPTIONS = {
"librispeech": "Short-form Audio (<30s/clip) - 5 hours of English audiobook clips",
"earnings22": "Long-Form Audio (>1hr/clip) - 120 hours of earnings call recordings in English with various accents",
}

REPO_URLS = {
"whisper.cpp": "https://github.com/ggerganov/whisper.cpp",
"WhisperKit": "https://github.com/argmaxinc/WhisperKit"
}


@@ -141,7 +168,7 @@ def cli():
readme = ""

for dataset_name in args.dataset_names:
readme += f"\n## Dataset: `{dataset_name}`\n"
readme += f"\n## Dataset: `{dataset_name}`\n{DATASET_CAPTIONS[dataset_name]}\n"
"-------------------------------------------------"

# Quality-of-Inference (QoI) certifications for Whisper models
@@ -151,30 +178,58 @@
results_dict[WER_KEY] = defaultdict(float)
results_dict[QOI_KEY] = defaultdict(float)
results_dict[FILE_SIZE_KEY] = defaultdict(int)
results_dict[COMMIT_KEY] = defaultdict(str)

# Fetch the reference eval results
reference_code_repo, reference_model = parse_name(reference)

reference_eval, reference_link = get_latest_eval(
reference_code_repo, dataset_name, reference_model)
reference_key = f"[{reference}]({reference_link})"

reference_key = reference.rsplit('/')[
-1].replace('openai_whisper-', '').replace('distil-whisper_', '')
if reference_code_repo == "WhisperKit":
reference_key = \
f"[{reference_key}]" \
f"({get_model_link(reference_model)}) "
else:
reference_key = reference_key + f" ({reference_code_repo})"

# Fill reference model version values
results_dict[QOI_KEY][reference_key] = 100. # By definition of QoI
results_dict[FILE_SIZE_KEY][reference_key] = \
REFERENCE_MODEL_FILE_SIZES[reference]

# Sample average WER for reference model
results_dict[WER_KEY][reference_key] = round(
sum([sample["wer"] for sample in reference_eval["results"]]) /
len(reference_eval["results"]) * 100., 2)
results_dict[WER_KEY][reference_key] = \
f"[{compute_average_wer(reference_eval['results'])}]({reference_link})"

# Add commit hash for reference results
commit_hash = reference_eval["metadata"]["inference_context"]["code_spec"]["code_commit_hash"]
if commit_hash is not None:
results_dict[COMMIT_KEY][reference_key] = \
f"[Link]({REPO_URLS[reference_code_repo]}/commit/{commit_hash[:7]})"
else:
results_dict[COMMIT_KEY][reference_key] = "N/A"

# Fill optimized model version values
for optimized in optimized_csv.split(","):
optimized_code_repo, optimized_model = parse_name(optimized)
optimized_eval, optimized_link = get_latest_eval(
optimized_code_repo, dataset_name, optimized_model)
optimized_key = f"[{optimized}]({optimized_link})"
try:
optimized_eval, optimized_link = get_latest_eval(
optimized_code_repo, dataset_name, optimized_model)
except Exception as e:
logger.warning(f"Could not fetch eval JSON for {optimized}: {e}")
continue

optimized_key = optimized.rsplit('/')[
-1].replace('openai_whisper-', '').replace('distil-whisper_', '')
if optimized_code_repo == "WhisperKit":
optimized_key = \
f"[{optimized_key}]" \
f"({get_model_link(optimized_model)}) "
else:
optimized_key = optimized_key + f" ({optimized_code_repo})"

# Verify fetched evals are comparable
logger.info(f"Compare {optimized_link} vs {reference_link}")
@@ -184,9 +239,16 @@ def cli():
optimized_eval["results"]
)
results_dict[QOI_KEY][optimized_key] = qoi["no_regression"]
results_dict[WER_KEY][optimized_key] = round(
sum([sample["wer"] for sample in optimized_eval["results"]]) /
len(optimized_eval["results"]) * 100., 2)
results_dict[WER_KEY][optimized_key] = \
f"[{compute_average_wer(optimized_eval['results'])}]({optimized_link})"

# Add commit hash for optimized results
commit_hash = optimized_eval["metadata"]["inference_context"]["code_spec"]["code_commit_hash"]
if commit_hash is not None:
results_dict[COMMIT_KEY][optimized_key] = \
f"[Link]({REPO_URLS[optimized_code_repo]}/commit/{commit_hash[:7]})"
else:
results_dict[COMMIT_KEY][optimized_key] = "N/A"

# TODO(atiorh): Read remote git file size
if optimized in REFERENCE_MODEL_FILE_SIZES:
@@ -270,10 +332,11 @@ def parse_name(result, default_code_repo="WhisperKit"):
return code_repo, model


def get_latest_eval(code_repo, dataset_name, model_version, local_dir="/tmp"):
def get_latest_eval(code_repo, dataset_name, model_version, local_dir="external"):
f""" Fetch the latest eval from hf.co/datasets/{EVALS_REPO_ID}
for given code repo, model version and dataset
"""
os.makedirs(local_dir, exist_ok=True)
repo_rel_dir = os.path.join(code_repo, model_version, dataset_name)
_ = snapshot_download(
repo_id=EVALS_REPO_ID,
@@ -319,3 +382,14 @@ def verify_apples_to_apples(reference_eval, optimized_eval):
logger.warning(
"Reference and optimized evals weren't generated with the same "
"whisperkittools commit")


def compute_average_wer(results):
return round(wer_metric.compute(
references=[result["reference"] for result in results],
predictions=[result["prediction"] for result in results],
) * 100., 2)


def get_model_link(model_version):
return f"https://hf.co/{MODEL_REPO_ID}/tree/main/{model_version}"
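Taken together, the helpers added in this file can be exercised roughly as follows to reproduce a single WER cell of the generated README. This is a usage sketch rather than part of the script's CLI path; `"librispeech"` and `"openai_whisper-tiny"` are example arguments.

```python
# Usage sketch for the helpers above; dataset and model names are examples.
eval_json, eval_link = get_latest_eval(
    "WhisperKit", "librispeech", "openai_whisper-tiny")
avg_wer = compute_average_wer(eval_json["results"])  # per-word averaged WER in %
model_link = get_model_link("openai_whisper-tiny")
print(f"[{avg_wer}]({eval_link}) | model files: {model_link}")
```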
3 changes: 2 additions & 1 deletion tests/test_evaluate.py
@@ -14,7 +14,7 @@
from argmaxtools.utils import get_logger
from huggingface_hub import HfApi

from whisperkit._constants import EVALS_REPO_ID
from whisperkit._constants import EVALS_REPO_ID, MODEL_REPO_ID
from whisperkit.evaluate.datasets import EVAL_DATASETS
from whisperkit.evaluate.evaluate import evaluate
from whisperkit.pipelines import get_pipeline_cls
@@ -75,6 +75,7 @@ def setUpClass(cls) -> None:
"model_version": TEST_MODEL_VERSION,
"whisperkittools_commit_hash": wkt_commit_hash,
"inference_context": cls.inference_context.spec_dict(),
"model_repo_id": MODEL_REPO_ID
}
}

16 changes: 14 additions & 2 deletions tests/test_text_decoder.py
@@ -39,6 +39,7 @@
"logits", "key_cache_updates", "value_cache_updates", "alignment_heads_weights"]
TEST_CONTEXT_PREFILL_OUTPUT_NAMES = ["key_cache_prefill", "value_cache_prefill"]
TEST_DEC_KV_SEQ_LEN = None
TEST_TOKEN_TIMESTAMPS = True


class TestWhisperTextDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):
@@ -48,6 +49,9 @@ def setUpClass(cls):
cls.test_cache_dir = TEST_CACHE_DIR
cls.model_name = "TextDecoder"

if not TEST_TOKEN_TIMESTAMPS:
cls.test_output_names.pop(cls.test_output_names.index("alignment_heads_weights"))

# Original model
orig_torch_model = (
modeling_whisper.WhisperForConditionalGeneration.from_pretrained(
@@ -68,7 +72,9 @@ def setUpClass(cls):
cls.test_torch_model.to(TEST_DEV).to(TEST_TORCH_DTYPE).eval()
)
cls.gen_cfg = orig_torch_model.generation_config
cls.test_torch_model.configure_for_token_timestamps(cls.gen_cfg)

if TEST_TOKEN_TIMESTAMPS:
cls.test_torch_model.configure_for_token_timestamps(cls.gen_cfg)

# Elaboration: I/O and architecture config
cfg = cls.orig_torch_model.config
Expand Down Expand Up @@ -347,6 +353,9 @@ class TestWhisperTextDecoderPalettizer(
def setUpClass(cls):
cls.model_name = "TextDecoder"
cls.output_names = TEST_OUTPUT_NAMES
if not TEST_TOKEN_TIMESTAMPS:
cls.output_names.pop(cls.output_names.index("alignment_heads_weights"))

cls.palettizer = palettize.WhisperTextDecoderPalettizer(
model_version=TEST_WHISPER_VERSION,
cache_dir=os.path.join(
@@ -370,9 +379,11 @@ def place(t):


def main(args):
global TEST_WHISPER_VERSION, TEST_CACHE_DIR, TEST_DEC_KV_SEQ_LEN
global TEST_WHISPER_VERSION, TEST_CACHE_DIR, TEST_DEC_KV_SEQ_LEN, TEST_TOKEN_TIMESTAMPS

TEST_WHISPER_VERSION = args.test_model_version
TEST_TOKEN_TIMESTAMPS = not args.disable_token_timestamps

logger.info(f"Testing {TEST_WHISPER_VERSION}")

text_decoder.SDPA_IMPL = getattr(_sdpa, args.sdpa_implementation)
@@ -422,6 +433,7 @@ def main(args):
parser.add_argument("--palettizer-tests", action="store_true")
parser.add_argument("--disable-default-tests", action="store_true")
parser.add_argument("--context-prefill-tests", action="store_true")
parser.add_argument("--disable-token-timestamps", action="store_true")
parser.add_argument(
"--sdpa-implementation", default="Cat", choices=tuple(_sdpa.__all__)
)
2 changes: 1 addition & 1 deletion tests/test_word_timestamps.py
@@ -1,6 +1,6 @@
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2023 Argmax, Inc. All Rights Reserved.
# Copyright (C) 2024 Argmax, Inc. All Rights Reserved.
#

import json
4 changes: 2 additions & 2 deletions whisperkit/evaluate/evaluate.py
@@ -90,8 +90,8 @@ def evaluate(whisper_pipeline: Union[pipelines.WhisperPipeline, pipelines.Whispe
int(bool(_num_fallbacks)) for _num_fallbacks in num_fallbacks
]) / len(num_fallbacks)
fallback_str = "-------------------------------------------------------"
fallback_str += f"\nTotal fallbacks: {total_fallbacks}"
fallback_str += "\nSamples with fallback: "
fallback_str += f"\n Total fallbacks: {total_fallbacks}"
fallback_str += "\n Samples with fallback: "
fallback_str += f"{samples_with_fallback_percent * 100.:.3g}%"

# Failed example bookkeeping
