diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 431212259..64cbd434e 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -20,6 +20,8 @@ import whisper import huggingface_hub +from buzz.locale import _ + # Catch exception from whisper.dll not getting loaded. # TODO: Remove flag and try-except when issue with loading # the DLL in some envs is fixed. @@ -46,6 +48,7 @@ class WhisperModelSize(str, enum.Enum): LARGE = "large" LARGEV2 = "large-v2" LARGEV3 = "large-v3" + CUSTOM = "custom" def to_faster_whisper_model_size(self) -> str: if self == WhisperModelSize.LARGE: @@ -112,9 +115,15 @@ def is_manually_downloadable(self): @dataclass() class TranscriptionModel: - model_type: ModelType = ModelType.WHISPER - whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY - hugging_face_model_id: Optional[str] = "openai/whisper-tiny" + def __init__( + self, + model_type: ModelType = ModelType.WHISPER, + whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY, + hugging_face_model_id: Optional[str] = "" + ): + self.model_type = model_type + self.whisper_model_size = whisper_model_size + self.hugging_face_model_id = hugging_face_model_id def __str__(self): match self.model_type: @@ -135,10 +144,16 @@ def is_deletable(self): return ( self.model_type == ModelType.WHISPER or self.model_type == ModelType.WHISPER_CPP + or self.model_type == ModelType.FASTER_WHISPER ) and self.get_local_model_path() is not None def open_file_location(self): model_path = self.get_local_model_path() + + if (self.model_type == ModelType.HUGGING_FACE + or self.model_type == ModelType.FASTER_WHISPER): + model_path = os.path.dirname(model_path) + if model_path is None: return self.open_path(path=os.path.dirname(model_path)) @@ -160,6 +175,17 @@ def open_path(path: str): def delete_local_file(self): model_path = self.get_local_model_path() + + if (self.model_type == ModelType.HUGGING_FACE + or self.model_type == ModelType.FASTER_WHISPER): + model_path = os.path.dirname(os.path.dirname(model_path)) + + logging.debug("Deleting model directory: %s", model_path) + + shutil.rmtree(model_path, ignore_errors=True) + return + + logging.debug("Deleting model file: %s", model_path) os.remove(model_path) def get_local_model_path(self) -> Optional[str]: @@ -178,7 +204,7 @@ def get_local_model_path(self) -> Optional[str]: if self.model_type == ModelType.FASTER_WHISPER: try: return download_faster_whisper_model( - size=self.whisper_model_size.value, local_files_only=True + model=self, local_files_only=True ) except (ValueError, FileNotFoundError): return None @@ -208,6 +234,7 @@ def get_local_model_path(self) -> Optional[str]: "large-v1": "7d99f41a10525d0206bddadd86760181fa920438b6b33237e3118ff6c83bb53d", "large-v2": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487", "large-v3": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2", + "custom": None, } @@ -217,6 +244,10 @@ def get_whisper_cpp_file_path(size: WhisperModelSize) -> str: def get_whisper_file_path(size: WhisperModelSize) -> str: root_dir = os.path.join(model_root_dir, "whisper") + + if size == WhisperModelSize.CUSTOM: + return os.path.join(root_dir, "custom") + url = whisper._MODELS[size.value] return os.path.join(root_dir, os.path.basename(url)) @@ -286,13 +317,17 @@ def download_from_huggingface( allow_patterns: List[str], progress: pyqtSignal(tuple), ): - progress.emit((1, 100)) + progress.emit((0, 100)) - model_root = huggingface_hub.snapshot_download( - repo_id, - allow_patterns=allow_patterns[1:], # all, but largest - cache_dir=model_root_dir - ) + try: + model_root = huggingface_hub.snapshot_download( + repo_id, + allow_patterns=allow_patterns[1:], # all, but largest + cache_dir=model_root_dir + ) + except Exception as exc: + logging.exception(exc) + return "" progress.emit((1, 100)) @@ -302,11 +337,16 @@ def download_from_huggingface( model_download_monitor = HuggingfaceDownloadMonitor(model_root, progress, total_file_size) model_download_monitor.start_monitoring() - huggingface_hub.snapshot_download( - repo_id, - allow_patterns=allow_patterns[:1], # largest - cache_dir=model_root_dir - ) + try: + huggingface_hub.snapshot_download( + repo_id, + allow_patterns=allow_patterns[:1], # largest + cache_dir=model_root_dir + ) + except Exception as exc: + logging.exception(exc) + model_download_monitor.stop_monitoring() + return "" model_download_monitor.stop_monitoring() @@ -314,17 +354,23 @@ def download_from_huggingface( def download_faster_whisper_model( - size: str, local_files_only=False, progress: pyqtSignal(tuple) = None + model: TranscriptionModel, local_files_only=False, progress: pyqtSignal(tuple) = None ): - if size not in faster_whisper.utils._MODELS: + size = model.whisper_model_size.to_faster_whisper_model_size() + custom_repo_id = model.hugging_face_model_id + + if size != WhisperModelSize.CUSTOM and size not in faster_whisper.utils._MODELS: raise ValueError( "Invalid model size '%s', expected one of: %s" % (size, ", ".join(faster_whisper.utils._MODELS)) ) - logging.debug("Downloading Faster Whisper model: %s", size) + if size == WhisperModelSize.CUSTOM and custom_repo_id == "": + raise ValueError("Custom model id is not provided") - if size == WhisperModelSize.LARGEV3: + if size == WhisperModelSize.CUSTOM: + repo_id = custom_repo_id + elif size == WhisperModelSize.LARGEV3: repo_id = "Systran/faster-whisper-large-v3" else: repo_id = "guillaumekln/faster-whisper-%s" % size @@ -358,20 +404,28 @@ class Signals(QObject): progress = pyqtSignal(tuple) # (current, total) error = pyqtSignal(str) - def __init__(self, model: TranscriptionModel): + def __init__(self, model: TranscriptionModel, custom_model_url: Optional[str] = None): super().__init__() self.signals = self.Signals() self.model = model self.stopped = False + self.custom_model_url = custom_model_url def run(self) -> None: + logging.debug("Downloading model: %s, %s", self.model, self.model.hugging_face_model_id) + if self.model.model_type == ModelType.WHISPER_CPP: model_name = self.model.whisper_model_size.to_whisper_cpp_model_size() - url = huggingface_hub.hf_hub_url( - repo_id="ggerganov/whisper.cpp", - filename=f"ggml-{model_name}.bin", - ) + + if self.custom_model_url: + url = self.custom_model_url + else: + url = huggingface_hub.hf_hub_url( + repo_id="ggerganov/whisper.cpp", + filename=f"ggml-{model_name}.bin", + ) + file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size) expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name] return self.download_model_to_path( @@ -388,9 +442,13 @@ def run(self) -> None: if self.model.model_type == ModelType.FASTER_WHISPER: model_path = download_faster_whisper_model( - size=self.model.whisper_model_size.to_faster_whisper_model_size(), + model=self.model, progress=self.signals.progress, ) + + if model_path == "": + self.signals.error.emit(_("Error")) + self.signals.finished.emit(model_path) return @@ -417,7 +475,7 @@ def download_model_to_path( if downloaded: self.signals.finished.emit(file_path) except requests.RequestException: - self.signals.error.emit("A connection error occurred") + self.signals.error.emit(_("A connection error occurred")) logging.exception("") except Exception as exc: self.signals.error.emit(str(exc)) diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index 62fb6e26e..0b95675b6 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -38,6 +38,8 @@ class Key(enum.Enum): DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name" CUSTOM_OPENAI_BASE_URL = "transcriber/custom-openai-base-url" + CUSTOM_FASTER_WHISPER_ID = "transcriber/custom-faster-whisper-id" + HUGGINGFACE_MODEL_ID = "transcriber/huggingface-model-id" SHORTCUTS = "shortcuts" @@ -50,6 +52,36 @@ class Key(enum.Enum): def set_value(self, key: Key, value: typing.Any) -> None: self.settings.setValue(key.value, value) + def save_custom_model_id(self, model) -> None: + from buzz.model_loader import ModelType + match model.model_type: + case ModelType.FASTER_WHISPER: + self.set_value( + Settings.Key.CUSTOM_FASTER_WHISPER_ID, + model.hugging_face_model_id, + ) + case ModelType.HUGGING_FACE: + self.set_value( + Settings.Key.HUGGINGFACE_MODEL_ID, + model.hugging_face_model_id, + ) + + def load_custom_model_id(self, model) -> str: + from buzz.model_loader import ModelType + match model.model_type: + case ModelType.FASTER_WHISPER: + return self.value( + Settings.Key.CUSTOM_FASTER_WHISPER_ID, + "", + ) + case ModelType.HUGGING_FACE: + return self.value( + Settings.Key.HUGGINGFACE_MODEL_ID, + "", + ) + + return "" + def value( self, key: Key, diff --git a/buzz/transcriber/whisper_file_transcriber.py b/buzz/transcriber/whisper_file_transcriber.py index 0a8d68c21..b42ed366f 100644 --- a/buzz/transcriber/whisper_file_transcriber.py +++ b/buzz/transcriber/whisper_file_transcriber.py @@ -13,7 +13,7 @@ from PyQt6.QtCore import QObject from buzz.conn import pipe_stderr -from buzz.model_loader import ModelType +from buzz.model_loader import ModelType, WhisperModelSize from buzz.transformers_whisper import TransformersWhisper from buzz.transcriber.file_transcriber import FileTranscriber from buzz.transcriber.transcriber import FileTranscriptionTask, Segment @@ -131,8 +131,13 @@ def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]: @classmethod def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]: + if task.transcription_options.model.whisper_model_size == WhisperModelSize.CUSTOM: + model_size_or_path = task.transcription_options.model.hugging_face_model_id + else: + model_size_or_path = task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size() + model = faster_whisper.WhisperModel( - model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size() + model_size_or_path=model_size_or_path ) whisper_segments, info = model.transcribe( audio=task.file_path, diff --git a/buzz/widgets/preferences_dialog/models_preferences_widget.py b/buzz/widgets/preferences_dialog/models_preferences_widget.py index 57e81ac4f..7f6324efa 100644 --- a/buzz/widgets/preferences_dialog/models_preferences_widget.py +++ b/buzz/widgets/preferences_dialog/models_preferences_widget.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from PyQt6.QtCore import Qt, QThreadPool @@ -18,8 +19,13 @@ TranscriptionModel, ModelDownloader, ) +from buzz.settings.settings import Settings from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog from buzz.widgets.model_type_combo_box import ModelTypeComboBox +from buzz.widgets.line_edit import LineEdit +from buzz.widgets.transcriber.hugging_face_search_line_edit import ( + HuggingFaceSearchLineEdit, +) class ModelsPreferencesWidget(QWidget): @@ -32,6 +38,7 @@ def __init__( ): super().__init__(parent) + self.settings = Settings() self.model_downloader: Optional[ModelDownloader] = None model_types = [ @@ -67,6 +74,20 @@ def __init__( buttons_layout = QHBoxLayout() + self.custom_model_id_input = HuggingFaceSearchLineEdit() + self.custom_model_id_input.setObjectName("ModelIdInput") + + self.custom_model_id_input.setPlaceholderText(_("Huggingface ID of a Faster whisper model")) + self.custom_model_id_input.textChanged.connect(self.on_custom_model_id_input_changed) + layout.addRow("", self.custom_model_id_input) + self.custom_model_id_input.hide() + + self.custom_model_link_input = LineEdit() + self.custom_model_link_input.setObjectName("ModelLinkInput") + self.custom_model_link_input.textChanged.connect(self.on_custom_model_link_input_changed) + layout.addRow("", self.custom_model_link_input) + self.custom_model_link_input.hide() + self.download_button = QPushButton(_("Download")) self.download_button.setObjectName("DownloadButton") self.download_button.clicked.connect(self.on_download_button_clicked) @@ -100,17 +121,11 @@ def on_model_size_changed(self, current: QTreeWidgetItem, _: QTreeWidgetItem): self.model.whisper_model_size = item_data self.reset() - @staticmethod - def can_delete_model(model: TranscriptionModel): - return ( - model.model_type == ModelType.WHISPER - or model.model_type == ModelType.WHISPER_CPP - ) and model.get_local_model_path() is not None - def reset(self): # reset buttons path = self.model.get_local_model_path() self.download_button.setVisible(path is None) + self.download_button.setEnabled(self.model.whisper_model_size != WhisperModelSize.CUSTOM) self.delete_button.setVisible(self.model.is_deletable()) self.show_file_location_button.setVisible(self.model.is_deletable()) @@ -129,12 +144,45 @@ def reset(self): self.model_list_widget.setHeaderHidden(True) self.model_list_widget.setAlternatingRowColors(True) + self.model.hugging_face_model_id = self.settings.load_custom_model_id(self.model) + self.custom_model_id_input.setText(self.model.hugging_face_model_id) + + if (self.model.whisper_model_size == WhisperModelSize.CUSTOM + and self.model.model_type == ModelType.FASTER_WHISPER): + self.custom_model_id_input.show() + self.download_button.setEnabled( + self.model.hugging_face_model_id != "" + ) + else: + self.custom_model_id_input.hide() + + if self.model.model_type == ModelType.WHISPER_CPP: + self.custom_model_link_input.setPlaceholderText( + _("Download link to Whisper.cpp ggml model file") + ) + + if (self.model.whisper_model_size == WhisperModelSize.CUSTOM + and self.model.model_type == ModelType.WHISPER_CPP + and path is None): + self.custom_model_link_input.show() + self.download_button.setEnabled( + self.custom_model_link_input.text() != "") + else: + self.custom_model_link_input.hide() + if self.model is None: return for model_size in WhisperModelSize: + # Skip custom size for OpenAI Whisper + if (self.model.model_type == ModelType.WHISPER and + model_size == WhisperModelSize.CUSTOM): + continue + model = TranscriptionModel( - model_type=self.model.model_type, whisper_model_size=model_size + model_type=self.model.model_type, + whisper_model_size=WhisperModelSize(model_size), + hugging_face_model_id=self.model.hugging_face_model_id, ) model_path = model.get_local_model_path() parent = downloaded_item if model_path is not None else available_item @@ -149,6 +197,16 @@ def on_model_type_changed(self, model_type: ModelType): self.model.model_type = model_type self.reset() + def on_custom_model_id_input_changed(self, text): + self.model.hugging_face_model_id = text + self.settings.save_custom_model_id(self.model) + self.download_button.setEnabled( + self.model.hugging_face_model_id != "" + ) + + def on_custom_model_link_input_changed(self, text): + self.download_button.setEnabled(text != "") + def on_download_button_clicked(self): self.progress_dialog = ModelDownloadProgressDialog( model_type=self.model.model_type, @@ -159,7 +217,15 @@ def on_download_button_clicked(self): self.download_button.setEnabled(False) - self.model_downloader = ModelDownloader(model=self.model) + if (self.model.whisper_model_size == WhisperModelSize.CUSTOM and + self.model.model_type == ModelType.WHISPER_CPP): + self.model_downloader = ModelDownloader( + model=self.model, + custom_model_url=self.custom_model_link_input.text() + ) + else: + self.model_downloader = ModelDownloader(model=self.model) + self.model_downloader.signals.finished.connect(self.on_download_completed) self.model_downloader.signals.progress.connect(self.on_download_progress) self.model_downloader.signals.error.connect(self.on_download_error) @@ -185,10 +251,12 @@ def on_download_completed(self, _: str): def on_download_error(self, error: str): self.progress_dialog.cancel() + self.progress_dialog.close() self.progress_dialog = None self.download_button.setEnabled(True) self.reset() - QMessageBox.warning(self, _("Error"), f"{_('Download failed')}: {error}") + download_failed_label = _('Download failed') + QMessageBox.warning(self, _("Error"), f"{download_failed_label}: {error}") def on_download_progress(self, progress: tuple): self.progress_dialog.set_value(float(progress[0]) / progress[1]) diff --git a/buzz/widgets/transcriber/hugging_face_search_line_edit.py b/buzz/widgets/transcriber/hugging_face_search_line_edit.py index efb319bea..16da9e6a0 100644 --- a/buzz/widgets/transcriber/hugging_face_search_line_edit.py +++ b/buzz/widgets/transcriber/hugging_face_search_line_edit.py @@ -14,12 +14,11 @@ QEvent, ) from PyQt6.QtGui import QKeyEvent -from PyQt6.QtCore import QSettings from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem +from buzz.locale import _ from buzz.widgets.line_edit import LineEdit -from buzz.settings.settings import APP_NAME # Adapted from https://github.com/ismailsunni/scripts/blob/master/autocomplete_from_url.py @@ -29,11 +28,12 @@ class HuggingFaceSearchLineEdit(LineEdit): def __init__( self, - default_value: str, + default_value: str = "", network_access_manager: Optional[QNetworkAccessManager] = None, parent: Optional[QWidget] = None, ): super().__init__(default_value, parent) + self.setPlaceholderText(_("Huggingface ID of a model")) self.setMinimumWidth(150) diff --git a/buzz/widgets/transcriber/transcription_options_group_box.py b/buzz/widgets/transcriber/transcription_options_group_box.py index dee9ca8a5..beabc43ce 100644 --- a/buzz/widgets/transcriber/transcription_options_group_box.py +++ b/buzz/widgets/transcriber/transcription_options_group_box.py @@ -5,6 +5,7 @@ from PyQt6.QtWidgets import QGroupBox, QWidget, QFormLayout, QComboBox from buzz.locale import _ +from buzz.settings.settings import Settings from buzz.model_loader import ModelType, WhisperModelSize from buzz.transcriber.transcriber import TranscriptionOptions, Task from buzz.widgets.model_type_combo_box import ModelTypeComboBox @@ -29,6 +30,7 @@ def __init__( parent: Optional[QWidget] = None, ): super().__init__(title="", parent=parent) + self.settings = Settings() self.transcription_options = default_transcription_options self.form_layout = QFormLayout(self) @@ -49,12 +51,8 @@ def __init__( self.whisper_model_size_combo_box = QComboBox(self) self.whisper_model_size_combo_box.addItems( - [size.value.title() for size in WhisperModelSize] + [size.value.title() for size in WhisperModelSize if size != WhisperModelSize.CUSTOM] ) - if default_transcription_options.model.whisper_model_size is not None: - self.whisper_model_size_combo_box.setCurrentText( - default_transcription_options.model.whisper_model_size.value.title() - ) self.whisper_model_size_combo_box.currentTextChanged.connect( self.on_whisper_model_size_changed ) @@ -72,6 +70,7 @@ def __init__( self.hugging_face_search_line_edit.model_selected.connect( self.on_hugging_face_model_changed ) + self.hugging_face_search_line_edit.setVisible(False) self.tasks_combo_box = TasksComboBox( default_task=self.transcription_options.task, parent=self @@ -122,9 +121,40 @@ def on_transcription_options_changed( def reset_visible_rows(self): model_type = self.transcription_options.model.model_type + whisper_model_size = self.transcription_options.model.whisper_model_size + + if (model_type == ModelType.HUGGING_FACE + or (whisper_model_size == WhisperModelSize.CUSTOM + and model_type == ModelType.FASTER_WHISPER)): + self.transcription_options.model.hugging_face_model_id = ( + self.settings.load_custom_model_id(self.transcription_options.model)) + self.hugging_face_search_line_edit.setText( + self.transcription_options.model.hugging_face_model_id) + self.form_layout.setRowVisible( - self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE + self.hugging_face_search_line_edit, + (model_type == ModelType.HUGGING_FACE) + or (model_type == ModelType.FASTER_WHISPER + and whisper_model_size == WhisperModelSize.CUSTOM), + ) + + custom_model_index = (self.whisper_model_size_combo_box + .findText(WhisperModelSize.CUSTOM.value.title())) + if (model_type == ModelType.WHISPER + and whisper_model_size == WhisperModelSize.CUSTOM + and custom_model_index != -1): + self.whisper_model_size_combo_box.removeItem(custom_model_index) + + if ((model_type == ModelType.WHISPER_CPP or model_type == ModelType.FASTER_WHISPER) + and custom_model_index == -1): + self.whisper_model_size_combo_box.addItem( + WhisperModelSize.CUSTOM.value.title() + ) + + self.whisper_model_size_combo_box.setCurrentText( + self.transcription_options.model.whisper_model_size.value.title() ) + self.form_layout.setRowVisible( self.whisper_model_size_combo_box, (model_type == ModelType.WHISPER) @@ -146,8 +176,13 @@ def on_model_type_changed(self, model_type: ModelType): def on_whisper_model_size_changed(self, text: str): model_size = WhisperModelSize(text.lower()) self.transcription_options.model.whisper_model_size = model_size + + self.reset_visible_rows() + self.transcription_options_changed.emit(self.transcription_options) def on_hugging_face_model_changed(self, model: str): self.transcription_options.model.hugging_face_model_id = model self.transcription_options_changed.emit(self.transcription_options) + + self.settings.save_custom_model_id(self.transcription_options.model)