diff --git a/buzz/model_loader.py b/buzz/model_loader.py
index 7052544f6..45206fa01 100644
--- a/buzz/model_loader.py
+++ b/buzz/model_loader.py
@@ -20,6 +20,17 @@
 import faster_whisper
 import whisper
 
+# Catch exception from whisper.dll not getting loaded.
+# TODO: Remove flag and try-except when issue with loading
+# the DLL in some envs is fixed.
+LOADED_WHISPER_DLL = False
+try:
+    import buzz.whisper_cpp as whisper_cpp  # noqa: F401
+
+    LOADED_WHISPER_DLL = True
+except ImportError:
+    logging.exception("")
+
 
 class WhisperModelSize(str, enum.Enum):
     TINY = "tiny"
@@ -44,6 +55,38 @@ class ModelType(enum.Enum):
     FASTER_WHISPER = "Faster Whisper"
     OPEN_AI_WHISPER_API = "OpenAI Whisper API"
 
+    def supports_recording(self):
+        # Live transcription with OpenAI Whisper API not supported
+        return self != ModelType.OPEN_AI_WHISPER_API
+
+    def is_available(self):
+        if (
+            # Hide Whisper.cpp option if whisper.dll did not load correctly.
+            # See: https://github.com/chidiwilliams/buzz/issues/274,
+            # https://github.com/chidiwilliams/buzz/issues/197
+            (self == ModelType.WHISPER_CPP and not LOADED_WHISPER_DLL)
+            # Disable Whisper and Faster Whisper options
+            # on Linux due to execstack errors on Snap
+            or (
+                sys.platform == "linux"
+                and self
+                in (
+                    ModelType.WHISPER,
+                    ModelType.FASTER_WHISPER,
+                    ModelType.HUGGING_FACE,
+                )
+            )
+        ):
+            return False
+        return True
+
+    def is_manually_downloadable(self):
+        return self in (
+            ModelType.WHISPER,
+            ModelType.WHISPER_CPP,
+            ModelType.FASTER_WHISPER,
+        )
+
 
 @dataclass()
 class TranscriptionModel:
diff --git a/buzz/recording_transcriber.py b/buzz/recording_transcriber.py
index 811dfe157..24bad0c88 100644
--- a/buzz/recording_transcriber.py
+++ b/buzz/recording_transcriber.py
@@ -9,7 +9,7 @@
 from PyQt6.QtCore import QObject, pyqtSignal
 from sounddevice import PortAudioError
 
-from buzz import transformers_whisper
+from buzz import transformers_whisper, whisper_audio
 from buzz.model_loader import ModelType
 from buzz.transcriber import TranscriptionOptions, WhisperCpp, whisper_cpp_params
 from buzz.transformers_whisper import TransformersWhisper
@@ -23,6 +23,7 @@ class RecordingTranscriber(QObject):
     finished = pyqtSignal()
     error = pyqtSignal(str)
     is_running = False
+    SAMPLE_RATE = whisper_audio.SAMPLE_RATE
     MAX_QUEUE_SIZE = 10
 
     def __init__(
@@ -152,17 +153,15 @@ def get_device_sample_rate(device_id: Optional[int]) -> int:
         provided by Whisper if the microphone supports it, or else it uses the device's
         default sample rate.
         """
-        whisper_sample_rate = whisper.audio.SAMPLE_RATE
+        sample_rate = whisper_audio.SAMPLE_RATE
         try:
-            sounddevice.check_input_settings(
-                device=device_id, samplerate=whisper_sample_rate
-            )
-            return whisper_sample_rate
+            sounddevice.check_input_settings(device=device_id, samplerate=sample_rate)
+            return sample_rate
         except PortAudioError:
             device_info = sounddevice.query_devices(device=device_id)
             if isinstance(device_info, dict):
-                return int(device_info.get("default_samplerate", whisper_sample_rate))
-            return whisper_sample_rate
+                return int(device_info.get("default_samplerate", sample_rate))
+            return sample_rate
 
     def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
         # Try to enqueue the next block. If the queue is already full, drop the block.
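For context, a minimal standalone sketch of the sample-rate fallback that `get_device_sample_rate` implements above: prefer Whisper's rate when the input device accepts it, otherwise fall back to the device's default rate. It assumes `sounddevice` is installed and that `whisper_audio.SAMPLE_RATE` is Whisper's usual 16 kHz; the `probe_sample_rate` name is illustrative only.

```python
from typing import Optional

import sounddevice
from sounddevice import PortAudioError

# Assumption: whisper_audio.SAMPLE_RATE is Whisper's 16 kHz input rate.
WHISPER_SAMPLE_RATE = 16_000


def probe_sample_rate(device_id: Optional[int]) -> int:
    """Prefer Whisper's sample rate; fall back to the device default."""
    try:
        sounddevice.check_input_settings(device=device_id, samplerate=WHISPER_SAMPLE_RATE)
        return WHISPER_SAMPLE_RATE
    except PortAudioError:
        # The device rejected 16 kHz; ask PortAudio what it does support.
        device_info = sounddevice.query_devices(device=device_id)
        if isinstance(device_info, dict):
            return int(device_info.get("default_samplerate", WHISPER_SAMPLE_RATE))
        return WHISPER_SAMPLE_RATE


if __name__ == "__main__":
    # None asks PortAudio for the default input device.
    print(probe_sample_rate(device_id=None))
```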
diff --git a/buzz/transcriber.py b/buzz/transcriber.py
index f7cf4f214..eabf92a01 100644
--- a/buzz/transcriber.py
+++ b/buzz/transcriber.py
@@ -23,6 +23,7 @@
 from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
 from dataclasses_json import dataclass_json, config, Exclude
 
+from buzz.model_loader import whisper_cpp
 from . import transformers_whisper
 from .conn import pipe_stderr
 from .locale import _
@@ -33,17 +34,6 @@
 import whisper
 import stable_whisper
 
-# Catch exception from whisper.dll not getting loaded.
-# TODO: Remove flag and try-except when issue with loading
-# the DLL in some envs is fixed.
-LOADED_WHISPER_DLL = False
-try:
-    import buzz.whisper_cpp as whisper_cpp
-
-    LOADED_WHISPER_DLL = True
-except ImportError:
-    logging.exception("")
-
 
 DEFAULT_WHISPER_TEMPERATURE = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
 
diff --git a/buzz/widgets/model_type_combo_box.py b/buzz/widgets/model_type_combo_box.py
index f494bf3cf..31273f609 100644
--- a/buzz/widgets/model_type_combo_box.py
+++ b/buzz/widgets/model_type_combo_box.py
@@ -1,11 +1,9 @@
-import sys
 from typing import Optional, List
 
 from PyQt6.QtCore import pyqtSignal
 from PyQt6.QtWidgets import QComboBox, QWidget
 
 from buzz.model_loader import ModelType
-from buzz.transcriber import LOADED_WHISPER_DLL
 
 
 class ModelTypeComboBox(QComboBox):
@@ -20,28 +18,11 @@ def __init__(
         super().__init__(parent)
 
         if model_types is None:
-            model_types = [model_type for model_type in ModelType]
+            model_types = [
+                model_type for model_type in ModelType if model_type.is_available()
+            ]
 
         for model_type in model_types:
-            if (
-                # Hide Whisper.cpp option if whisper.dll did not load correctly.
-                # See: https://github.com/chidiwilliams/buzz/issues/274,
-                # https://github.com/chidiwilliams/buzz/issues/197
-                model_type == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False
-            ) or (
-                # Disable Whisper and Faster Whisper options
-                # on Linux due to execstack errors on Snap
-                (
-                    model_type
-                    in (
-                        ModelType.WHISPER,
-                        ModelType.FASTER_WHISPER,
-                        ModelType.HUGGING_FACE,
-                    )
-                )
-                and sys.platform == "linux"
-            ):
-                continue
             self.addItem(model_type.value)
 
         self.currentTextChanged.connect(self.on_text_changed)
diff --git a/buzz/widgets/preferences_dialog/models_preferences_widget.py b/buzz/widgets/preferences_dialog/models_preferences_widget.py
index 4b0a4b868..931f0fb78 100644
--- a/buzz/widgets/preferences_dialog/models_preferences_widget.py
+++ b/buzz/widgets/preferences_dialog/models_preferences_widget.py
@@ -23,6 +23,8 @@
 
 
 class ModelsPreferencesWidget(QWidget):
+    model: Optional[TranscriptionModel]
+
     def __init__(
         self,
         progress_dialog_modality=Qt.WindowModality.WindowModal,
@@ -31,8 +33,19 @@ def __init__(
         super().__init__(parent)
 
         self.model_downloader: Optional[ModelDownloader] = None
-        self.model = TranscriptionModel(
-            model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY
+
+        model_types = [
+            model_type
+            for model_type in ModelType
+            if model_type.is_available() and model_type.is_manually_downloadable()
+        ]
+
+        self.model = (
+            TranscriptionModel(
+                model_type=model_types[0], whisper_model_size=WhisperModelSize.TINY
+            )
+            if model_types[0] is not None
+            else None
         )
 
         self.progress_dialog_modality = progress_dialog_modality
@@ -40,12 +53,8 @@ def __init__(
         layout = QFormLayout()
 
         model_type_combo_box = ModelTypeComboBox(
-            model_types=[
-                ModelType.WHISPER,
-                ModelType.WHISPER_CPP,
-                ModelType.FASTER_WHISPER,
-            ],
-            default_model=self.model.model_type,
+            model_types=model_types,
+            default_model=self.model.model_type if self.model is not None else None,
             parent=self,
         )
         model_type_combo_box.changed.connect(self.on_model_type_changed)
@@ -119,6 +128,10 @@ def reset(self):
         self.model_list_widget.expandToDepth(2)
         self.model_list_widget.setHeaderHidden(True)
         self.model_list_widget.setAlternatingRowColors(True)
+
+        if self.model is None:
+            return
+
         for model_size in WhisperModelSize:
             model = TranscriptionModel(
                 model_type=self.model.model_type, whisper_model_size=model_size
diff --git a/buzz/widgets/recording_transcriber_widget.py b/buzz/widgets/recording_transcriber_widget.py
index c5215c97a..4a28fd893 100644
--- a/buzz/widgets/recording_transcriber_widget.py
+++ b/buzz/widgets/recording_transcriber_widget.py
@@ -12,14 +12,12 @@
     ModelDownloader,
     TranscriptionModel,
     ModelType,
-    WhisperModelSize,
 )
 from buzz.recording import RecordingAmplitudeListener
 from buzz.recording_transcriber import RecordingTranscriber
 from buzz.settings.settings import Settings
 from buzz.transcriber import (
     TranscriptionOptions,
-    LOADED_WHISPER_DLL,
     Task,
     DEFAULT_WHISPER_TEMPERATURE,
 )
@@ -65,15 +63,20 @@ def __init__(
         default_language = self.settings.value(
             key=Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, default_value=""
         )
+
+        model_types = [
+            model_type
+            for model_type in ModelType
+            if model_type.is_available() and model_type.supports_recording()
+        ]
+        default_model: Optional[TranscriptionModel] = None
+        if len(model_types) > 0:
+            default_model = TranscriptionModel(model_type=model_types[0])
+
         self.transcription_options = TranscriptionOptions(
             model=self.settings.value(
                 key=Settings.Key.RECORDING_TRANSCRIBER_MODEL,
-                default_value=TranscriptionModel(
-                    model_type=ModelType.WHISPER_CPP
-                    if LOADED_WHISPER_DLL
-                    else ModelType.WHISPER,
-                    whisper_model_size=WhisperModelSize.TINY,
-                ),
+                default_value=default_model,
             ),
             task=self.settings.value(
                 key=Settings.Key.RECORDING_TRANSCRIBER_TASK,
@@ -102,12 +105,7 @@ def __init__(
 
         transcription_options_group_box = TranscriptionOptionsGroupBox(
             default_transcription_options=self.transcription_options,
-            # Live transcription with OpenAI Whisper API not implemented
-            model_types=[
-                model_type
-                for model_type in ModelType
-                if model_type is not ModelType.OPEN_AI_WHISPER_API
-            ],
+            model_types=model_types,
             parent=self,
         )
         transcription_options_group_box.transcription_options_changed.connect(
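Taken together, the widget changes above follow one pattern: filter `ModelType` through the new helper methods, then derive a default model from whatever is left. A rough sketch of that pattern, assuming the `buzz` package is importable; the guard against an empty list is this sketch's own addition, not part of the patch.

```python
from buzz.model_loader import ModelType, TranscriptionModel, WhisperModelSize

# Models usable for live recording (what RecordingTranscriberWidget builds).
recording_model_types = [
    model_type
    for model_type in ModelType
    if model_type.is_available() and model_type.supports_recording()
]

# Models that can be downloaded ahead of time (what ModelsPreferencesWidget builds).
downloadable_model_types = [
    model_type
    for model_type in ModelType
    if model_type.is_available() and model_type.is_manually_downloadable()
]

# Derive a default, guarding against an empty list (e.g. whisper.dll failed to
# load and the platform filters removed everything else). This guard is an
# assumption of the sketch, not code from the diff.
default_model = (
    TranscriptionModel(
        model_type=recording_model_types[0],
        whisper_model_size=WhisperModelSize.TINY,
    )
    if recording_model_types
    else None
)

print([model_type.value for model_type in recording_model_types])
print(default_model)
```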