Skip to content

Commit

Permalink
Refactored model downloading (#761)
Browse files Browse the repository at this point in the history
  • Loading branch information
raivisdejus authored May 28, 2024
1 parent 7820952 commit 731efd7
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 30 deletions.
130 changes: 100 additions & 30 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
import hashlib
import logging
import os
import time
import threading
import shutil
import subprocess
import sys
import tempfile
import warnings
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List

import requests
from PyQt6.QtCore import QObject, pyqtSignal, QRunnable
from platformdirs import user_cache_dir
from tqdm.auto import tqdm

import faster_whisper
import whisper
Expand Down Expand Up @@ -86,11 +87,11 @@ def is_manually_downloadable(self):


HUGGING_FACE_MODEL_ALLOW_PATTERNS = [
"model.safetensors", # largest by size first
"added_tokens.json",
"config.json",
"generation_config.json",
"merges.txt",
"model.safetensors",
"normalizer.json",
"preprocessor_config.json",
"special_tokens_map.json",
Expand Down Expand Up @@ -198,10 +199,6 @@ def get_local_model_path(self) -> Optional[str]:
}


def get_hugging_face_file_url(author: str, repository_name: str, filename: str):
    """Build a download URL for *filename* in the *author*/*repository_name* repo.

    The URL resolves against a pinned commit hash so that downloads are
    reproducible even if the repository's main branch moves.
    """
    # BUG FIX: the original f-string dropped the `filename` parameter, so every
    # call produced the same broken URL regardless of the requested file.
    return f"https://huggingface.co/{author}/{repository_name}/resolve/bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c/{filename}"


def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
    """Return the on-disk cache path for the Whisper.cpp ggml model of *size*."""
    cache_dir = user_cache_dir("Buzz")
    file_name = f"ggml-model-whisper-{size.value}.bin"
    return os.path.join(cache_dir, file_name)
Expand All @@ -215,8 +212,88 @@ def get_whisper_file_path(size: WhisperModelSize) -> str:
return os.path.join(root_dir, os.path.basename(url))


class HuggingfaceDownloadMonitor:
    """Reports download progress for a large Hugging Face Hub file.

    huggingface_hub writes in-flight downloads to ``tmp*`` files under the
    ``huggingface/hub`` cache root. A background thread polls those files'
    sizes and emits ``(downloaded_bytes, total_bytes)`` tuples on *progress*.
    """

    # Seconds between polls of the tmp download files.
    POLL_INTERVAL_SECONDS = 2

    def __init__(self, model_root: str, progress: "pyqtSignal(tuple)", total_file_size: int):
        self.model_root = model_root
        self.progress = progress
        self.total_file_size = total_file_size
        self.tmp_download_root = self.get_tmp_download_root(model_root)
        self.stop_event = threading.Event()
        self.monitor_thread: Optional[threading.Thread] = None

    @staticmethod
    def get_tmp_download_root(model_root):
        """Return the ``huggingface/hub`` cache root that contains *model_root*.

        Raises:
            ValueError: if *model_root* is not inside a ``huggingface/hub`` tree.
        """
        normalized_model_root = os.path.normpath(model_root)
        normalized_hub_path = os.path.normpath("huggingface/hub/")
        index = normalized_model_root.find(normalized_hub_path)
        if index == -1:
            raise ValueError(f"Invalid model_root, '{normalized_hub_path}' not found")
        return normalized_model_root[:index + len(normalized_hub_path)]

    def clean_tmp_files(self):
        """Remove leftover tmp files so stale sizes don't skew progress."""
        for filename in os.listdir(self.tmp_download_root):
            if filename.startswith("tmp"):
                try:
                    os.remove(os.path.join(self.tmp_download_root, filename))
                except OSError:
                    # Another process may have completed/removed it already.
                    pass

    def monitor_file_size(self):
        """Worker loop: emit (current, total) until stop_event is set."""
        while not self.stop_event.is_set():
            for filename in os.listdir(self.tmp_download_root):
                if filename.startswith("tmp"):
                    try:
                        file_size = os.path.getsize(
                            os.path.join(self.tmp_download_root, filename)
                        )
                    except OSError:
                        # The tmp file is renamed to its final name when the
                        # download finishes; skip instead of crashing the thread.
                        continue
                    self.progress.emit((file_size, self.total_file_size))
            # wait() instead of time.sleep() so stop_monitoring() is not
            # blocked for up to a full poll interval while joining.
            self.stop_event.wait(self.POLL_INTERVAL_SECONDS)

    def start_monitoring(self):
        """Clear stale tmp files and start the polling thread."""
        self.clean_tmp_files()
        # Daemon thread: never keep the interpreter alive on shutdown.
        self.monitor_thread = threading.Thread(
            target=self.monitor_file_size, daemon=True
        )
        self.monitor_thread.start()

    def stop_monitoring(self):
        """Emit a final 100% progress update and join the polling thread."""
        self.progress.emit((self.total_file_size, self.total_file_size))

        if self.monitor_thread is not None:
            self.stop_event.set()
            self.monitor_thread.join()


def get_file_size(url):
    """Return the size in bytes of the resource at *url* via an HTTP HEAD request."""
    head_response = requests.head(url, allow_redirects=True)
    head_response.raise_for_status()
    content_length = head_response.headers['Content-Length']
    return int(content_length)


def download_from_huggingface(
    repo_id: str,
    allow_patterns: List[str],
    progress: pyqtSignal(tuple),
):
    """Download a Hugging Face model, reporting progress on the largest file.

    *allow_patterns* is expected to list the largest file first. The smaller
    files are fetched up front, then the largest file is downloaded while a
    HuggingfaceDownloadMonitor polls its temporary file and emits
    ``(downloaded, total)`` tuples on *progress*. Returns the local model root.
    """
    progress.emit((1, 100))

    # Fetch everything except the largest file first.
    small_file_patterns = allow_patterns[1:]
    model_root = huggingface_hub.snapshot_download(
        repo_id,
        allow_patterns=small_file_patterns,
    )

    progress.emit((1, 100))

    # Look up the largest file's size so progress can be reported as a ratio.
    largest_pattern = allow_patterns[0]
    largest_file_url = huggingface_hub.hf_hub_url(repo_id, largest_pattern)
    total_file_size = get_file_size(largest_file_url)

    monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size)
    monitor.start_monitoring()

    # Now download only the largest file while the monitor reports progress.
    huggingface_hub.snapshot_download(
        repo_id,
        allow_patterns=allow_patterns[:1],
    )

    monitor.stop_monitoring()

    return model_root


def download_faster_whisper_model(
size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None
size: str, local_files_only=False, progress: pyqtSignal(tuple) = None
):
if size not in faster_whisper.utils._MODELS:
raise ValueError(
Expand All @@ -227,17 +304,23 @@ def download_faster_whisper_model(
repo_id = "guillaumekln/faster-whisper-%s" % size

allow_patterns = [
"model.bin", # largest by size first
"config.json",
"model.bin",
"tokenizer.json",
"vocabulary.txt",
]

return huggingface_hub.snapshot_download(
if local_files_only:
return huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns,
local_files_only=True,
)

return download_from_huggingface(
repo_id,
allow_patterns=allow_patterns,
local_files_only=local_files_only,
tqdm_class=tqdm_class,
progress=progress,
)


Expand All @@ -257,9 +340,8 @@ def __init__(self, model: TranscriptionModel):
def run(self) -> None:
if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.value
url = get_hugging_face_file_url(
author="ggerganov",
repository_name="whisper.cpp",
url = huggingface_hub.hf_hub_url(
repo_id="ggerganov/whisper.cpp",
filename=f"ggml-{model_name}.bin",
)
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
Expand All @@ -276,31 +358,19 @@ def run(self) -> None:
url=url, file_path=file_path, expected_sha256=expected_sha256
)

progress = self.signals.progress

# gross abuse of power...
class _tqdm(tqdm):
def update(self, n: float | None = ...) -> bool | None:
progress.emit((n, self.total))
return super().update(n)

def close(self):
progress.emit((self.n, self.total))
return super().close()

if self.model.model_type == ModelType.FASTER_WHISPER:
model_path = download_faster_whisper_model(
size=self.model.whisper_model_size.to_faster_whisper_model_size(),
tqdm_class=_tqdm,
progress=self.signals.progress,
)
self.signals.finished.emit(model_path)
return

if self.model.model_type == ModelType.HUGGING_FACE:
model_path = huggingface_hub.snapshot_download(
model_path = download_from_huggingface(
self.model.hugging_face_model_id,
allow_patterns=HUGGING_FACE_MODEL_ALLOW_PATTERNS,
tqdm_class=_tqdm
progress=self.signals.progress,
)
self.signals.finished.emit(model_path)
return
Expand Down
25 changes: 25 additions & 0 deletions tests/model_loader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os
import pytest

from buzz.model_loader import ModelDownloader,TranscriptionModel, ModelType, WhisperModelSize


class TestModelLoader:
    """Integration test: downloads a small Hugging Face model end to end."""

    @pytest.mark.parametrize(
        "model",
        [
            TranscriptionModel(
                model_type=ModelType.HUGGING_FACE,
                hugging_face_model_id="RaivisDejus/whisper-tiny-lv",
            ),
        ],
    )
    def test_download_model(self, model: TranscriptionModel):
        # Run the downloader synchronously, then verify the model landed on disk.
        downloader = ModelDownloader(model=model)
        downloader.run()

        model_path = model.get_local_model_path()

        assert model_path is not None, "Model path is None"
        assert os.path.isdir(model_path), "Model path is not a directory"
        assert len(os.listdir(model_path)) > 0, "Model directory is empty"

0 comments on commit 731efd7

Please sign in to comment.