From 2b839c35cbed92f2916dcdc352a3606ac77957b6 Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Wed, 27 Dec 2023 10:25:13 +0000 Subject: [PATCH] feat: add folder watch (#655) --- buzz/settings/settings.py | 4 +- buzz/transcriber.py | 35 +++- buzz/widgets/audio_player.py | 2 +- buzz/widgets/icon.py | 45 +++-- buzz/widgets/main_window.py | 65 +++++-- buzz/widgets/menu_bar.py | 28 ++- .../folder_watch_preferences_widget.py | 137 +++++++++++++++ .../general_preferences_widget.py | 2 +- .../preferences_dialog/models/__init__.py | 0 .../models/file_transcription_preferences.py | 96 ++++++++++ .../models/folder_watch_preferences.py | 38 ++++ .../preferences_dialog/models/preferences.py | 24 +++ .../preferences_dialog/preferences_dialog.py | 26 ++- .../transcriber/file_transcriber_widget.py | 164 ++++-------------- .../file_transcription_form_widget.py | 110 ++++++++++++ .../transcription_task_folder_watcher.py | 76 ++++++++ .../export_transcription_button.py | 4 +- .../transcription_viewer_widget.py | 2 +- tests/transcriber_test.py | 92 ++++++++-- tests/widgets/menu_bar_test.py | 25 +++ .../folder_watch_preferences_widget_test.py | 55 ++++++ .../preferences_dialog_test.py | 11 +- .../transcription_task_folder_watcher_test.py | 105 +++++++++++ 23 files changed, 952 insertions(+), 194 deletions(-) create mode 100644 buzz/widgets/preferences_dialog/folder_watch_preferences_widget.py create mode 100644 buzz/widgets/preferences_dialog/models/__init__.py create mode 100644 buzz/widgets/preferences_dialog/models/file_transcription_preferences.py create mode 100644 buzz/widgets/preferences_dialog/models/folder_watch_preferences.py create mode 100644 buzz/widgets/preferences_dialog/models/preferences.py create mode 100644 buzz/widgets/transcriber/file_transcription_form_widget.py create mode 100644 buzz/widgets/transcription_task_folder_watcher.py create mode 100644 tests/widgets/menu_bar_test.py create mode 100644 tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py create mode 100644 tests/widgets/transcription_task_folder_watcher_test.py diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index c450c7438..ea9d819f4 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -54,10 +54,10 @@ def value( def clear(self): self.settings.clear() - def begin_group(self, group: Key): + def begin_group(self, group: Key) -> None: self.settings.beginGroup(group.value) - def end_group(self): + def end_group(self) -> None: self.settings.endGroup() def sync(self): diff --git a/buzz/transcriber.py b/buzz/transcriber.py index c8e3f835a..4013934fa 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -6,6 +6,7 @@ import math import multiprocessing import os +import shutil import subprocess import sys import tempfile @@ -96,6 +97,10 @@ class Status(enum.Enum): FAILED = "failed" CANCELED = "canceled" + class Source(enum.Enum): + FILE_IMPORT = "file_import" + FOLDER_WATCH = "folder_watch" + file_path: str transcription_options: TranscriptionOptions file_transcription_options: FileTranscriptionOptions @@ -108,6 +113,8 @@ class Status(enum.Enum): queued_at: Optional[datetime.datetime] = None started_at: Optional[datetime.datetime] = None completed_at: Optional[datetime.datetime] = None + output_directory: Optional[str] = None + source: Source = Source.FILE_IMPORT def status_text(self) -> str: if self.status == FileTranscriptionTask.Status.IN_PROGRESS: @@ -169,7 +176,7 @@ def run(self): for ( output_format ) in self.transcription_task.file_transcription_options.output_formats: - default_path = get_default_output_file_path( + default_path = get_output_file_path( task=self.transcription_task, output_format=output_format ) @@ -177,6 +184,15 @@ def run(self): path=default_path, segments=segments, output_format=output_format ) + if self.transcription_task.source == FileTranscriptionTask.Source.FOLDER_WATCH: + shutil.move( + self.transcription_task.file_path, + os.path.join( + self.transcription_task.output_directory, + os.path.basename(self.transcription_task.file_path), + ), + ) + @abstractmethod def transcribe(self) -> List[Segment]: ... @@ -644,24 +660,22 @@ def segments_to_text(segments: List[Segment]) -> str: def to_timestamp(ms: float, ms_separator=".") -> str: hr = int(ms / (1000 * 60 * 60)) - ms = ms - hr * (1000 * 60 * 60) + ms -= hr * (1000 * 60 * 60) min = int(ms / (1000 * 60)) - ms = ms - min * (1000 * 60) + ms -= min * (1000 * 60) sec = int(ms / 1000) ms = int(ms - sec * 1000) return f"{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}" -SUPPORTED_OUTPUT_FORMATS = "Audio files (*.mp3 *.wav *.m4a *.ogg);;\ +SUPPORTED_AUDIO_FORMATS = "Audio files (*.mp3 *.wav *.m4a *.ogg);;\ Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)" -def get_default_output_file_path( - task: FileTranscriptionTask, output_format: OutputFormat -): - input_file_name = os.path.splitext(task.file_path)[0] +def get_output_file_path(task: FileTranscriptionTask, output_format: OutputFormat): + input_file_name = os.path.splitext(os.path.basename(task.file_path))[0] date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S") - return ( + output_file_name = ( task.file_transcription_options.default_output_file_name.replace( "{{ input_file_name }}", input_file_name ) @@ -678,6 +692,9 @@ def get_default_output_file_path( + f".{output_format.value}" ) + output_directory = task.output_directory or os.path.dirname(task.file_path) + return os.path.join(output_directory, output_file_name) + def whisper_cpp_params( language: str, diff --git a/buzz/widgets/audio_player.py b/buzz/widgets/audio_player.py index f34e324ab..cc714a0ea 100644 --- a/buzz/widgets/audio_player.py +++ b/buzz/widgets/audio_player.py @@ -53,7 +53,7 @@ def __init__(self, file_path: str): self.media_player.playbackStateChanged.connect(self.on_playback_state_changed) self.media_player.mediaStatusChanged.connect(self.on_media_status_changed) - self.update_time_label() + self.on_duration_changed(self.media_player.duration()) def on_duration_changed(self, duration_ms: int): self.scrubber.setRange(0, duration_ms) diff --git a/buzz/widgets/icon.py b/buzz/widgets/icon.py index f8cfef9aa..0788f8496 100644 --- a/buzz/widgets/icon.py +++ b/buzz/widgets/icon.py @@ -4,27 +4,48 @@ from buzz.assets import get_asset_path -# TODO: move icons to Qt resources: https://stackoverflow.com/a/52341917/9830227 class Icon(QIcon): - LIGHT_THEME_BACKGROUND = "#555" - DARK_THEME_BACKGROUND = "#EEE" + LIGHT_THEME_COLOR = "#555" + DARK_THEME_COLOR = "#EEE" def __init__(self, path: str, parent: QWidget): - # Adapted from https://stackoverflow.com/questions/15123544/change-the-color-of-an-svg-in-qt - is_dark_theme = parent.palette().window().color().black() > 127 - color = self.get_color(is_dark_theme) - + super().__init__() + self.path = path + self.parent = parent + + self.color = self.get_color() + normal_pixmap = self.create_default_pixmap(self.path, self.color) + disabled_pixmap = self.create_disabled_pixmap(normal_pixmap, self.color) + self.addPixmap(normal_pixmap, QIcon.Mode.Normal) + self.addPixmap(disabled_pixmap, QIcon.Mode.Disabled) + + # https://stackoverflow.com/questions/15123544/change-the-color-of-an-svg-in-qt + def create_default_pixmap(self, path, color): pixmap = QPixmap(path) painter = QPainter(pixmap) painter.setCompositionMode(QPainter.CompositionMode.CompositionMode_SourceIn) - painter.fillRect(pixmap.rect(), QColor(color)) + painter.fillRect(pixmap.rect(), color) painter.end() + return pixmap + + def create_disabled_pixmap(self, pixmap, color): + disabled_pixmap = QPixmap(pixmap.size()) + disabled_pixmap.fill(QColor(0, 0, 0, 0)) - super().__init__(pixmap) + painter = QPainter(disabled_pixmap) + painter.setOpacity(0.4) + painter.drawPixmap(0, 0, pixmap) + painter.setCompositionMode( + QPainter.CompositionMode.CompositionMode_DestinationIn + ) + painter.fillRect(disabled_pixmap.rect(), color) + painter.end() + return disabled_pixmap - def get_color(self, is_dark_theme): - return ( - self.DARK_THEME_BACKGROUND if is_dark_theme else self.LIGHT_THEME_BACKGROUND + def get_color(self) -> QColor: + is_dark_theme = self.parent.palette().window().color().black() > 127 + return QColor( + self.DARK_THEME_COLOR if is_dark_theme else self.LIGHT_THEME_COLOR ) diff --git a/buzz/widgets/main_window.py b/buzz/widgets/main_window.py index 022df9bea..a61cf950f 100644 --- a/buzz/widgets/main_window.py +++ b/buzz/widgets/main_window.py @@ -1,7 +1,11 @@ -from typing import Dict, Optional, Tuple, List +from typing import Dict, Tuple, List from PyQt6 import QtGui -from PyQt6.QtCore import pyqtSignal, Qt, QThread, QModelIndex +from PyQt6.QtCore import ( + Qt, + QThread, + QModelIndex, +) from PyQt6.QtGui import QIcon from PyQt6.QtWidgets import QMainWindow, QMessageBox, QFileDialog @@ -15,12 +19,16 @@ FileTranscriptionTask, TranscriptionOptions, FileTranscriptionOptions, - SUPPORTED_OUTPUT_FORMATS, + SUPPORTED_AUDIO_FORMATS, ) from buzz.widgets.icon import BUZZ_ICON_PATH from buzz.widgets.main_window_toolbar import MainWindowToolbar from buzz.widgets.menu_bar import MenuBar +from buzz.widgets.preferences_dialog.models.preferences import Preferences from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget +from buzz.widgets.transcription_task_folder_watcher import ( + TranscriptionTaskFolderWatcher, +) from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget from buzz.widgets.transcription_viewer.transcription_viewer_widget import ( TranscriptionViewerWidget, @@ -30,8 +38,6 @@ class MainWindow(QMainWindow): table_widget: TranscriptionTasksTableWidget tasks: Dict[int, "FileTranscriptionTask"] - tasks_changed = pyqtSignal() - openai_access_token: Optional[str] def __init__(self, tasks_cache=TasksCache()): super().__init__(flags=Qt.WindowType.Window) @@ -54,7 +60,6 @@ def __init__(self, tasks_cache=TasksCache()): ) self.tasks = {} - self.tasks_changed.connect(self.on_tasks_changed) self.toolbar = MainWindowToolbar(shortcuts=self.shortcuts, parent=self) self.toolbar.new_transcription_action_triggered.connect( @@ -72,9 +77,11 @@ def __init__(self, tasks_cache=TasksCache()): self.addToolBar(self.toolbar) self.setUnifiedTitleAndToolBarOnMac(True) + self.preferences = self.load_preferences(settings=self.settings) self.menu_bar = MenuBar( shortcuts=self.shortcuts, default_export_file_name=self.default_export_file_name, + preferences=self.preferences, parent=self, ) self.menu_bar.import_action_triggered.connect( @@ -87,6 +94,7 @@ def __init__(self, tasks_cache=TasksCache()): self.menu_bar.default_export_file_name_changed.connect( self.default_export_file_name_changed ) + self.menu_bar.preferences_changed.connect(self.on_preferences_changed) self.setMenuBar(self.menu_bar) self.table_widget = TranscriptionTasksTableWidget(self) @@ -113,6 +121,31 @@ def __init__(self, tasks_cache=TasksCache()): self.load_geometry() + self.folder_watcher = TranscriptionTaskFolderWatcher( + tasks=self.tasks, + preferences=self.preferences.folder_watch, + default_export_file_name=self.default_export_file_name, + ) + self.folder_watcher.task_found.connect(self.add_task) + self.folder_watcher.find_tasks() + + def on_preferences_changed(self, preferences: Preferences): + self.preferences = preferences + self.save_preferences(preferences) + self.folder_watcher.set_preferences(preferences.folder_watch) + self.folder_watcher.find_tasks() + + def save_preferences(self, preferences: Preferences): + self.settings.settings.beginGroup("preferences") + preferences.save(self.settings.settings) + self.settings.settings.endGroup() + + def load_preferences(self, settings: Settings): + settings.settings.beginGroup("preferences") + preferences = Preferences.load(settings.settings) + settings.settings.endGroup() + return preferences + def dragEnterEvent(self, event): # Accept file drag events if event.mimeData().hasUrls(): @@ -134,13 +167,13 @@ def on_file_transcriber_triggered( ) self.add_task(task) - def load_task(self, task: FileTranscriptionTask): + def upsert_task_in_table(self, task: FileTranscriptionTask): self.table_widget.upsert_task(task) self.tasks[task.id] = task def update_task_table_row(self, task: FileTranscriptionTask): - self.load_task(task=task) - self.tasks_changed.emit() + self.upsert_task_in_table(task=task) + self.on_tasks_changed() @staticmethod def task_completed_or_errored(task: FileTranscriptionTask): @@ -158,7 +191,8 @@ def on_clear_history_action_triggered(self): self, _("Clear History"), _( - "Are you sure you want to delete the selected transcription(s)? This action cannot be undone." + "Are you sure you want to delete the selected transcription(s)? " + "This action cannot be undone." ), ) if reply == QMessageBox.StandardButton.Yes: @@ -169,7 +203,7 @@ def on_clear_history_action_triggered(self): for task_id in task_ids: self.table_widget.clear_task(task_id) self.tasks.pop(task_id) - self.tasks_changed.emit() + self.on_tasks_changed() def on_stop_transcription_action_triggered(self): selected_rows = self.table_widget.selectionModel().selectedRows() @@ -178,13 +212,13 @@ def on_stop_transcription_action_triggered(self): task = self.tasks[task_id] task.status = FileTranscriptionTask.Status.CANCELED - self.tasks_changed.emit() + self.on_tasks_changed() self.transcriber_worker.cancel_task(task_id) self.table_widget.upsert_task(task) def on_new_transcription_action_triggered(self): (file_paths, __) = QFileDialog.getOpenFileNames( - self, _("Select audio file"), "", SUPPORTED_OUTPUT_FORMATS + self, _("Select audio file"), "", SUPPORTED_AUDIO_FORMATS ) if len(file_paths) == 0: return @@ -213,6 +247,7 @@ def default_export_file_name_changed(self, default_export_file_name: str): self.settings.set_value( Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name ) + self.folder_watcher.default_export_file_name = default_export_file_name def open_transcript_viewer(self): selected_rows = self.table_widget.selectionModel().selectedRows() @@ -291,9 +326,9 @@ def load_tasks_from_cache(self): or task.status == FileTranscriptionTask.Status.IN_PROGRESS ): task.status = None - self.transcriber_worker.add_task(task) + self.add_task(task) else: - self.load_task(task=task) + self.upsert_task_in_table(task=task) def save_tasks_to_cache(self): self.tasks_cache.save(list(self.tasks.values())) diff --git a/buzz/widgets/menu_bar.py b/buzz/widgets/menu_bar.py index 14e61ef5d..932dfb659 100644 --- a/buzz/widgets/menu_bar.py +++ b/buzz/widgets/menu_bar.py @@ -1,15 +1,18 @@ import webbrowser -from typing import Dict +from typing import Dict, Optional from PyQt6.QtCore import pyqtSignal from PyQt6.QtGui import QAction, QKeySequence from PyQt6.QtWidgets import QMenuBar, QWidget -from buzz.widgets.about_dialog import AboutDialog from buzz.locale import _ from buzz.settings.settings import APP_NAME from buzz.settings.shortcut import Shortcut -from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog +from buzz.widgets.about_dialog import AboutDialog +from buzz.widgets.preferences_dialog.models.preferences import Preferences +from buzz.widgets.preferences_dialog.preferences_dialog import ( + PreferencesDialog, +) class MenuBar(QMenuBar): @@ -17,14 +20,21 @@ class MenuBar(QMenuBar): shortcuts_changed = pyqtSignal(dict) openai_api_key_changed = pyqtSignal(str) default_export_file_name_changed = pyqtSignal(str) + preferences_changed = pyqtSignal(Preferences) + preferences_dialog: Optional[PreferencesDialog] = None def __init__( - self, shortcuts: Dict[str, str], default_export_file_name: str, parent: QWidget + self, + shortcuts: Dict[str, str], + default_export_file_name: str, + preferences: Preferences, + parent: Optional[QWidget] = None, ): super().__init__(parent) self.shortcuts = shortcuts self.default_export_file_name = default_export_file_name + self.preferences = preferences self.import_action = QAction(_("Import Media File..."), self) self.import_action.triggered.connect(self.on_import_action_triggered) @@ -59,6 +69,7 @@ def on_preferences_action_triggered(self): preferences_dialog = PreferencesDialog( shortcuts=self.shortcuts, default_export_file_name=self.default_export_file_name, + preferences=self.preferences, parent=self, ) preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed) @@ -66,8 +77,17 @@ def on_preferences_action_triggered(self): preferences_dialog.default_export_file_name_changed.connect( self.default_export_file_name_changed ) + preferences_dialog.finished.connect(self.on_preferences_dialog_finished) preferences_dialog.open() + self.preferences_dialog = preferences_dialog + + def on_preferences_dialog_finished(self, result): + if result == self.preferences_dialog.DialogCode.Accepted: + updated_preferences = self.preferences_dialog.updated_preferences + self.preferences = updated_preferences + self.preferences_changed.emit(updated_preferences) + def on_help_action_triggered(self): webbrowser.open("https://chidiwilliams.github.io/buzz/docs") diff --git a/buzz/widgets/preferences_dialog/folder_watch_preferences_widget.py b/buzz/widgets/preferences_dialog/folder_watch_preferences_widget.py new file mode 100644 index 000000000..8ade77735 --- /dev/null +++ b/buzz/widgets/preferences_dialog/folder_watch_preferences_widget.py @@ -0,0 +1,137 @@ +from typing import Tuple, Optional + +from PyQt6.QtCore import pyqtSignal +from PyQt6.QtWidgets import ( + QWidget, + QPushButton, + QFormLayout, + QHBoxLayout, + QFileDialog, + QCheckBox, + QVBoxLayout, +) + +from buzz.store.keyring_store import KeyringStore +from buzz.transcriber import ( + TranscriptionOptions, + FileTranscriptionOptions, +) +from buzz.widgets.line_edit import LineEdit +from buzz.widgets.preferences_dialog.models.file_transcription_preferences import ( + FileTranscriptionPreferences, +) +from buzz.widgets.preferences_dialog.models.folder_watch_preferences import ( + FolderWatchPreferences, +) +from buzz.widgets.transcriber.file_transcription_form_widget import ( + FileTranscriptionFormWidget, +) + + +class FolderWatchPreferencesWidget(QWidget): + config_changed = pyqtSignal(FolderWatchPreferences) + + def __init__( + self, config: FolderWatchPreferences, parent: Optional[QWidget] = None + ): + super().__init__(parent) + + self.config = config + + checkbox = QCheckBox("Enable folder watch") + checkbox.setChecked(config.enabled) + checkbox.setObjectName("EnableFolderWatchCheckbox") + checkbox.stateChanged.connect(self.on_enable_changed) + + input_folder_browse_button = QPushButton("Browse") + input_folder_browse_button.clicked.connect(self.on_click_browse_input_folder) + + output_folder_browse_button = QPushButton("Browse") + output_folder_browse_button.clicked.connect(self.on_click_browse_output_folder) + + input_folder_row = QHBoxLayout() + self.input_folder_line_edit = LineEdit(config.input_directory, self) + self.input_folder_line_edit.setPlaceholderText("/path/to/input/folder") + self.input_folder_line_edit.textChanged.connect(self.on_input_folder_changed) + self.input_folder_line_edit.setObjectName("InputFolderLineEdit") + + input_folder_row.addWidget(self.input_folder_line_edit) + input_folder_row.addWidget(input_folder_browse_button) + + output_folder_row = QHBoxLayout() + self.output_folder_line_edit = LineEdit(config.output_directory, self) + self.output_folder_line_edit.setPlaceholderText("/path/to/output/folder") + self.output_folder_line_edit.textChanged.connect(self.on_output_folder_changed) + self.output_folder_line_edit.setObjectName("OutputFolderLineEdit") + + output_folder_row.addWidget(self.output_folder_line_edit) + output_folder_row.addWidget(output_folder_browse_button) + + openai_access_token = KeyringStore().get_password( + KeyringStore.Key.OPENAI_API_KEY + ) + ( + transcription_options, + file_transcription_options, + ) = config.file_transcription_options.to_transcription_options( + openai_access_token=openai_access_token, + file_paths=[], + ) + + transcription_form_widget = FileTranscriptionFormWidget( + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + parent=self, + ) + transcription_form_widget.transcription_options_changed.connect( + self.on_transcription_options_changed + ) + + layout = QVBoxLayout(self) + + folders_form_layout = QFormLayout() + + folders_form_layout.addRow("", checkbox) + folders_form_layout.addRow("Input folder", input_folder_row) + folders_form_layout.addRow("Output folder", output_folder_row) + folders_form_layout.addWidget(transcription_form_widget) + + layout.addLayout(folders_form_layout) + layout.addWidget(transcription_form_widget) + layout.addStretch() + + self.setLayout(layout) + + def on_click_browse_input_folder(self): + folder = QFileDialog.getExistingDirectory(self, "Select Input Folder") + self.input_folder_line_edit.setText(folder) + self.on_input_folder_changed(folder) + + def on_input_folder_changed(self, folder): + self.config.input_directory = folder + self.config_changed.emit(self.config) + + def on_click_browse_output_folder(self): + folder = QFileDialog.getExistingDirectory(self, "Select Output Folder") + self.output_folder_line_edit.setText(folder) + self.on_output_folder_changed(folder) + + def on_output_folder_changed(self, folder): + self.config.output_directory = folder + self.config_changed.emit(self.config) + + def on_enable_changed(self, state: int): + self.config.enabled = state == 2 + self.config_changed.emit(self.config) + + def on_transcription_options_changed( + self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions] + ): + transcription_options, file_transcription_options = options + self.config.file_transcription_options = ( + FileTranscriptionPreferences.from_transcription_options( + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + ) + ) + self.config_changed.emit(self.config) diff --git a/buzz/widgets/preferences_dialog/general_preferences_widget.py b/buzz/widgets/preferences_dialog/general_preferences_widget.py index ef244aeff..bca825632 100644 --- a/buzz/widgets/preferences_dialog/general_preferences_widget.py +++ b/buzz/widgets/preferences_dialog/general_preferences_widget.py @@ -40,7 +40,7 @@ def __init__( ) self.update_test_openai_api_key_button() - layout.addRow("OpenAI API Key", self.openai_api_key_line_edit) + layout.addRow("OpenAI API key", self.openai_api_key_line_edit) layout.addRow("", self.test_openai_api_key_button) default_export_file_name_line_edit = LineEdit(default_export_file_name, self) diff --git a/buzz/widgets/preferences_dialog/models/__init__.py b/buzz/widgets/preferences_dialog/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py b/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py new file mode 100644 index 000000000..f3749f3d0 --- /dev/null +++ b/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py @@ -0,0 +1,96 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Set, List + +from PyQt6.QtCore import QSettings + +from buzz.model_loader import TranscriptionModel +from buzz.transcriber import ( + Task, + OutputFormat, + DEFAULT_WHISPER_TEMPERATURE, + TranscriptionOptions, + FileTranscriptionOptions, +) + + +@dataclass() +class FileTranscriptionPreferences: + language: Optional[str] + task: Task + model: TranscriptionModel + word_level_timings: bool + temperature: Tuple[float, ...] + initial_prompt: str + output_formats: Set["OutputFormat"] + + def save(self, settings: QSettings) -> None: + settings.setValue("language", self.language) + settings.setValue("task", self.task) + settings.setValue("model", self.model) + settings.setValue("word_level_timings", self.word_level_timings) + settings.setValue("temperature", self.temperature) + settings.setValue("initial_prompt", self.initial_prompt) + settings.setValue( + "output_formats", + [output_format.value for output_format in self.output_formats], + ) + + @classmethod + def load(cls, settings: QSettings) -> "FileTranscriptionPreferences": + language = settings.value("language", None) + task = settings.value("task", Task.TRANSCRIBE) + model = settings.value("model", TranscriptionModel()) + word_level_timings = settings.value("word_level_timings", False) + temperature = settings.value("temperature", DEFAULT_WHISPER_TEMPERATURE) + initial_prompt = settings.value("initial_prompt", "") + output_formats = settings.value("output_formats", []) + return FileTranscriptionPreferences( + language=language, + task=task, + model=model, + word_level_timings=word_level_timings, + temperature=temperature, + initial_prompt=initial_prompt, + output_formats=set( + [OutputFormat(output_format) for output_format in output_formats] + ), + ) + + @classmethod + def from_transcription_options( + cls, + transcription_options: TranscriptionOptions, + file_transcription_options: FileTranscriptionOptions, + ) -> "FileTranscriptionPreferences": + return FileTranscriptionPreferences( + task=transcription_options.task, + language=transcription_options.language, + temperature=transcription_options.temperature, + initial_prompt=transcription_options.initial_prompt, + word_level_timings=transcription_options.word_level_timings, + model=transcription_options.model, + output_formats=file_transcription_options.output_formats, + ) + + def to_transcription_options( + self, + openai_access_token: Optional[str], + file_paths: List[str], + default_output_file_name: str = "", + ) -> Tuple[TranscriptionOptions, FileTranscriptionOptions]: + return ( + TranscriptionOptions( + task=self.task, + language=self.language, + temperature=self.temperature, + initial_prompt=self.initial_prompt, + word_level_timings=self.word_level_timings, + model=self.model, + openai_access_token=openai_access_token, + ), + FileTranscriptionOptions( + output_formats=self.output_formats, + file_paths=file_paths, + default_output_file_name=default_output_file_name, + ), + ) diff --git a/buzz/widgets/preferences_dialog/models/folder_watch_preferences.py b/buzz/widgets/preferences_dialog/models/folder_watch_preferences.py new file mode 100644 index 000000000..c0062cc2c --- /dev/null +++ b/buzz/widgets/preferences_dialog/models/folder_watch_preferences.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +from PyQt6.QtCore import QSettings + +from buzz.widgets.preferences_dialog.models.file_transcription_preferences import ( + FileTranscriptionPreferences, +) + + +@dataclass +class FolderWatchPreferences: + enabled: bool + input_directory: str + output_directory: str + file_transcription_options: FileTranscriptionPreferences + + def save(self, settings: QSettings): + settings.setValue("enabled", self.enabled) + settings.setValue("input_folder", self.input_directory) + settings.setValue("output_directory", self.output_directory) + settings.beginGroup("file_transcription_options") + self.file_transcription_options.save(settings) + settings.endGroup() + + @classmethod + def load(cls, settings: QSettings) -> "FolderWatchPreferences": + enabled = settings.value("enabled", defaultValue=False, type=bool) + input_folder = settings.value("input_folder", defaultValue="", type=str) + output_folder = settings.value("output_directory", defaultValue="", type=str) + settings.beginGroup("file_transcription_options") + file_transcription_options = FileTranscriptionPreferences.load(settings) + settings.endGroup() + return FolderWatchPreferences( + enabled=enabled, + input_directory=input_folder, + output_directory=output_folder, + file_transcription_options=file_transcription_options, + ) diff --git a/buzz/widgets/preferences_dialog/models/preferences.py b/buzz/widgets/preferences_dialog/models/preferences.py new file mode 100644 index 000000000..b3d8d0fa4 --- /dev/null +++ b/buzz/widgets/preferences_dialog/models/preferences.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass + +from PyQt6.QtCore import QSettings + +from buzz.widgets.preferences_dialog.models.folder_watch_preferences import ( + FolderWatchPreferences, +) + + +@dataclass +class Preferences: + folder_watch: FolderWatchPreferences + + def save(self, settings: QSettings): + settings.beginGroup("folder_watch") + self.folder_watch.save(settings) + settings.endGroup() + + @classmethod + def load(cls, settings: QSettings) -> "Preferences": + settings.beginGroup("folder_watch") + folder_watch = FolderWatchPreferences.load(settings) + settings.endGroup() + return Preferences(folder_watch=folder_watch) diff --git a/buzz/widgets/preferences_dialog/preferences_dialog.py b/buzz/widgets/preferences_dialog/preferences_dialog.py index d04a6ceae..1874696f6 100644 --- a/buzz/widgets/preferences_dialog/preferences_dialog.py +++ b/buzz/widgets/preferences_dialog/preferences_dialog.py @@ -1,12 +1,20 @@ +import copy from typing import Dict, Optional from PyQt6.QtCore import pyqtSignal from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QTabWidget, QDialogButtonBox from buzz.locale import _ +from buzz.widgets.preferences_dialog.folder_watch_preferences_widget import ( + FolderWatchPreferencesWidget, +) from buzz.widgets.preferences_dialog.general_preferences_widget import ( GeneralPreferencesWidget, ) +from buzz.widgets.preferences_dialog.models.folder_watch_preferences import ( + FolderWatchPreferences, +) +from buzz.widgets.preferences_dialog.models.preferences import Preferences from buzz.widgets.preferences_dialog.models_preferences_widget import ( ModelsPreferencesWidget, ) @@ -18,16 +26,22 @@ class PreferencesDialog(QDialog): shortcuts_changed = pyqtSignal(dict) openai_api_key_changed = pyqtSignal(str) + folder_watch_config_changed = pyqtSignal(FolderWatchPreferences) default_export_file_name_changed = pyqtSignal(str) + preferences_changed = pyqtSignal(Preferences) def __init__( self, + # TODO: move shortcuts and default export file name into preferences shortcuts: Dict[str, str], default_export_file_name: str, + preferences: Preferences, parent: Optional[QWidget] = None, ) -> None: super().__init__(parent) + self.updated_preferences = copy.deepcopy(preferences) + self.setWindowTitle("Preferences") layout = QVBoxLayout(self) @@ -49,8 +63,15 @@ def __init__( shortcuts_table_widget.shortcuts_changed.connect(self.shortcuts_changed) tab_widget.addTab(shortcuts_table_widget, _("Shortcuts")) + folder_watch_widget = FolderWatchPreferencesWidget( + config=self.updated_preferences.folder_watch, parent=self + ) + folder_watch_widget.config_changed.connect(self.folder_watch_config_changed) + tab_widget.addTab(folder_watch_widget, _("Folder Watch")) + button_box = QDialogButtonBox( - QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel, + self, ) button_box.accepted.connect(self.accept) button_box.rejected.connect(self.reject) @@ -60,4 +81,5 @@ def __init__( self.setLayout(layout) - self.setFixedSize(self.sizeHint()) + self.setMinimumHeight(500) + self.setMinimumWidth(500) diff --git a/buzz/widgets/transcriber/file_transcriber_widget.py b/buzz/widgets/transcriber/file_transcriber_widget.py index 09ff5d7a3..bef88e7ec 100644 --- a/buzz/widgets/transcriber/file_transcriber_widget.py +++ b/buzz/widgets/transcriber/file_transcriber_widget.py @@ -5,28 +5,25 @@ from PyQt6.QtWidgets import ( QWidget, QVBoxLayout, - QCheckBox, - QFormLayout, - QHBoxLayout, QPushButton, ) from buzz.dialogs import show_model_download_error_dialog from buzz.locale import _ -from buzz.model_loader import ModelDownloader, TranscriptionModel, ModelType +from buzz.model_loader import ModelDownloader from buzz.paths import file_paths_as_title from buzz.settings.settings import Settings from buzz.store.keyring_store import KeyringStore from buzz.transcriber import ( FileTranscriptionOptions, TranscriptionOptions, - Task, - DEFAULT_WHISPER_TEMPERATURE, - OutputFormat, ) from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog -from buzz.widgets.transcriber.transcription_options_group_box import ( - TranscriptionOptionsGroupBox, +from buzz.widgets.preferences_dialog.models.file_transcription_preferences import ( + FileTranscriptionPreferences, +) +from buzz.widgets.transcriber.file_transcription_form_widget import ( + FileTranscriptionFormWidget, ) @@ -57,89 +54,34 @@ def __init__( ) self.file_paths = file_paths - default_language = self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value="" - ) - self.transcription_options = TranscriptionOptions( + + preferences = self.load_preferences() + + ( + self.transcription_options, + self.file_transcription_options, + ) = preferences.to_transcription_options( openai_access_token=openai_access_token, - model=self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_MODEL, - default_value=TranscriptionModel(), - ), - task=self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE - ), - language=default_language if default_language != "" else None, - initial_prompt=self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value="" - ), - temperature=self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, - default_value=DEFAULT_WHISPER_TEMPERATURE, - ), - word_level_timings=self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - default_value=False, - ), - ) - default_export_format_states: List[str] = self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, default_value=[] - ) - self.file_transcription_options = FileTranscriptionOptions( file_paths=self.file_paths, - output_formats=set( - [ - OutputFormat(output_format) - for output_format in default_export_format_states - ] - ), default_output_file_name=default_output_file_name, ) layout = QVBoxLayout(self) - transcription_options_group_box = TranscriptionOptionsGroupBox( - default_transcription_options=self.transcription_options, parent=self + self.form_widget = FileTranscriptionFormWidget( + transcription_options=self.transcription_options, + file_transcription_options=self.file_transcription_options, + parent=self, ) - transcription_options_group_box.transcription_options_changed.connect( - self.on_transcription_options_changed + self.form_widget.openai_access_token_changed.connect( + self.openai_access_token_changed ) - self.word_level_timings_checkbox = QCheckBox(_("Word-level timings")) - self.word_level_timings_checkbox.setChecked( - self.settings.value( - key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - default_value=False, - ) - ) - self.word_level_timings_checkbox.stateChanged.connect( - self.on_word_level_timings_changed - ) - - file_transcription_layout = QFormLayout() - file_transcription_layout.addRow("", self.word_level_timings_checkbox) - - export_format_layout = QHBoxLayout() - for output_format in OutputFormat: - export_format_checkbox = QCheckBox( - f"{output_format.value.upper()}", parent=self - ) - export_format_checkbox.setChecked( - output_format in self.file_transcription_options.output_formats - ) - export_format_checkbox.stateChanged.connect( - self.get_on_checkbox_state_changed_callback(output_format) - ) - export_format_layout.addWidget(export_format_checkbox) - - file_transcription_layout.addRow("Export:", export_format_layout) - self.run_button = QPushButton(_("Run"), self) self.run_button.setDefault(True) self.run_button.clicked.connect(self.on_click_run) - layout.addWidget(transcription_options_group_box) - layout.addLayout(file_transcription_layout) + layout.addWidget(self.form_widget) layout.addWidget(self.run_button, 0, Qt.AlignmentFlag.AlignRight) self.setLayout(layout) @@ -147,23 +89,19 @@ def __init__( self.reset_transcriber_controls() - def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat): - def on_checkbox_state_changed(state: int): - if state == Qt.CheckState.Checked.value: - self.file_transcription_options.output_formats.add(output_format) - elif state == Qt.CheckState.Unchecked.value: - self.file_transcription_options.output_formats.remove(output_format) - - return on_checkbox_state_changed - - def on_transcription_options_changed( - self, transcription_options: TranscriptionOptions - ): - self.transcription_options = transcription_options - if self.transcription_options.openai_access_token != "": - self.openai_access_token_changed.emit( - self.transcription_options.openai_access_token - ) + def load_preferences(self): + self.settings.settings.beginGroup("file_transcriber") + preferences = FileTranscriptionPreferences.load(settings=self.settings.settings) + self.settings.settings.endGroup() + return preferences + + def save_preferences(self): + self.settings.settings.beginGroup("file_transcriber") + preferences = FileTranscriptionPreferences.from_transcription_options( + self.transcription_options, self.file_transcription_options + ) + preferences.save(settings=self.settings.settings) + self.settings.settings.endGroup() def on_click_run(self): self.run_button.setDisabled(True) @@ -210,11 +148,6 @@ def on_download_model_error(self, error: str): def reset_transcriber_controls(self): self.run_button.setDisabled(False) - self.word_level_timings_checkbox.setDisabled( - self.transcription_options.model.model_type == ModelType.HUGGING_FACE - or self.transcription_options.model.model_type - == ModelType.OPEN_AI_WHISPER_API - ) def on_cancel_model_progress_dialog(self): if self.model_loader is not None: @@ -234,34 +167,5 @@ def on_word_level_timings_changed(self, value: int): def closeEvent(self, event: QtGui.QCloseEvent) -> None: if self.model_loader is not None: self.model_loader.cancel() - - self.settings.set_value( - Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language - ) - self.settings.set_value( - Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task - ) - self.settings.set_value( - Settings.Key.FILE_TRANSCRIBER_TEMPERATURE, - self.transcription_options.temperature, - ) - self.settings.set_value( - Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, - self.transcription_options.initial_prompt, - ) - self.settings.set_value( - Settings.Key.FILE_TRANSCRIBER_MODEL, self.transcription_options.model - ) - self.settings.set_value( - key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS, - value=self.transcription_options.word_level_timings, - ) - self.settings.set_value( - key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, - value=[ - export_format.value - for export_format in self.file_transcription_options.output_formats - ], - ) - + self.save_preferences() super().closeEvent(event) diff --git a/buzz/widgets/transcriber/file_transcription_form_widget.py b/buzz/widgets/transcriber/file_transcription_form_widget.py new file mode 100644 index 000000000..ef0c09682 --- /dev/null +++ b/buzz/widgets/transcriber/file_transcription_form_widget.py @@ -0,0 +1,110 @@ +from typing import Optional + +from PyQt6.QtCore import pyqtSignal, Qt +from PyQt6.QtWidgets import QWidget, QVBoxLayout, QCheckBox, QFormLayout, QHBoxLayout + +from buzz.locale import _ +from buzz.model_loader import ModelType +from buzz.transcriber import ( + TranscriptionOptions, + FileTranscriptionOptions, + OutputFormat, +) +from buzz.widgets.transcriber.transcription_options_group_box import ( + TranscriptionOptionsGroupBox, +) + + +class FileTranscriptionFormWidget(QWidget): + openai_access_token_changed = pyqtSignal(str) + transcription_options_changed = pyqtSignal(tuple) + + def __init__( + self, + transcription_options: TranscriptionOptions, + file_transcription_options: FileTranscriptionOptions, + parent: Optional[QWidget] = None, + ): + super().__init__(parent) + + self.transcription_options = transcription_options + self.file_transcription_options = file_transcription_options + + layout = QVBoxLayout(self) + + transcription_options_group_box = TranscriptionOptionsGroupBox( + default_transcription_options=self.transcription_options, parent=self + ) + transcription_options_group_box.transcription_options_changed.connect( + self.on_transcription_options_changed + ) + + self.word_level_timings_checkbox = QCheckBox(_("Word-level timings")) + self.word_level_timings_checkbox.setChecked( + self.transcription_options.word_level_timings + ) + self.word_level_timings_checkbox.stateChanged.connect( + self.on_word_level_timings_changed + ) + + file_transcription_layout = QFormLayout() + file_transcription_layout.addRow("", self.word_level_timings_checkbox) + + export_format_layout = QHBoxLayout() + for output_format in OutputFormat: + export_format_checkbox = QCheckBox( + f"{output_format.value.upper()}", parent=self + ) + export_format_checkbox.setChecked( + output_format in self.file_transcription_options.output_formats + ) + export_format_checkbox.stateChanged.connect( + self.get_on_checkbox_state_changed_callback(output_format) + ) + export_format_layout.addWidget(export_format_checkbox) + + file_transcription_layout.addRow("Export:", export_format_layout) + + layout.addWidget(transcription_options_group_box) + layout.addLayout(file_transcription_layout) + self.setLayout(layout) + + self.reset_word_level_timings() + + def on_transcription_options_changed( + self, transcription_options: TranscriptionOptions + ): + self.transcription_options = transcription_options + self.reset_word_level_timings() + self.transcription_options_changed.emit( + (self.transcription_options, self.file_transcription_options) + ) + if self.transcription_options.openai_access_token != "": + self.openai_access_token_changed.emit( + self.transcription_options.openai_access_token + ) + + def on_word_level_timings_changed(self, value: int): + self.transcription_options.word_level_timings = ( + value == Qt.CheckState.Checked.value + ) + + def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat): + def on_checkbox_state_changed(state: int): + if state == Qt.CheckState.Checked.value: + self.file_transcription_options.output_formats.add(output_format) + elif state == Qt.CheckState.Unchecked.value: + self.file_transcription_options.output_formats.remove(output_format) + + self.transcription_options_changed.emit( + (self.transcription_options, self.file_transcription_options) + ) + + return on_checkbox_state_changed + + def reset_word_level_timings(self): + self.word_level_timings_checkbox.setDisabled( + self.transcription_options.model.model_type == ModelType.HUGGING_FACE + or self.transcription_options.model.model_type + == ModelType.OPEN_AI_WHISPER_API + ) diff --git a/buzz/widgets/transcription_task_folder_watcher.py b/buzz/widgets/transcription_task_folder_watcher.py new file mode 100644 index 000000000..e1d59d1f1 --- /dev/null +++ b/buzz/widgets/transcription_task_folder_watcher.py @@ -0,0 +1,76 @@ +import logging +import os +from typing import Dict + +from PyQt6.QtCore import QFileSystemWatcher, pyqtSignal, QObject + +from buzz.store.keyring_store import KeyringStore +from buzz.transcriber import FileTranscriptionTask +from buzz.widgets.preferences_dialog.models.folder_watch_preferences import ( + FolderWatchPreferences, +) + + +class TranscriptionTaskFolderWatcher(QFileSystemWatcher): + preferences: FolderWatchPreferences + task_found = pyqtSignal(FileTranscriptionTask) + + def __init__( + self, + tasks: Dict[int, FileTranscriptionTask], + preferences: FolderWatchPreferences, + default_export_file_name: str, + parent: QObject = None, + ): + super().__init__(parent) + self.tasks = tasks + self.default_export_file_name = default_export_file_name + self.set_preferences(preferences) + self.directoryChanged.connect(self.find_tasks) + + def set_preferences(self, preferences: FolderWatchPreferences): + self.preferences = preferences + if len(self.directories()) > 0: + self.removePaths(self.directories()) + if preferences.enabled: + self.addPath(preferences.input_directory) + logging.debug( + 'Watching for media files in "%s"', preferences.input_directory + ) + + def find_tasks(self): + input_directory = self.preferences.input_directory + tasks = {task.file_path: task for task in self.tasks.values()} + for dirpath, dirnames, filenames in os.walk(input_directory): + for filename in filenames: + file_path = os.path.join(dirpath, filename) + if ( + filename.startswith(".") # hidden files + or file_path in tasks # file already in tasks + ): + continue + + openai_access_token = KeyringStore().get_password( + KeyringStore.Key.OPENAI_API_KEY + ) + ( + transcription_options, + file_transcription_options, + ) = self.preferences.file_transcription_options.to_transcription_options( + openai_access_token=openai_access_token, + default_output_file_name=self.default_export_file_name, + file_paths=[file_path], + ) + model_path = transcription_options.model.get_local_model_path() + task = FileTranscriptionTask( + file_path=file_path, + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + model_path=model_path, + output_directory=self.preferences.output_directory, + source=FileTranscriptionTask.Source.FOLDER_WATCH, + ) + self.task_found.emit(task) + + # Don't traverse into subdirectories + break diff --git a/buzz/widgets/transcription_viewer/export_transcription_button.py b/buzz/widgets/transcription_viewer/export_transcription_button.py index 24aa91f2f..e240c75cf 100644 --- a/buzz/widgets/transcription_viewer/export_transcription_button.py +++ b/buzz/widgets/transcription_viewer/export_transcription_button.py @@ -5,7 +5,7 @@ from buzz.transcriber import ( FileTranscriptionTask, OutputFormat, - get_default_output_file_path, + get_output_file_path, write_output, ) from buzz.widgets.icon import FileDownloadIcon @@ -30,7 +30,7 @@ def __init__(self, transcription_task: FileTranscriptionTask, parent: QWidget): def on_menu_triggered(self, action: QAction): output_format = OutputFormat[action.text()] - default_path = get_default_output_file_path( + default_path = get_output_file_path( task=self.transcription_task, output_format=output_format ) diff --git a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py index 210f667c9..16be8320a 100644 --- a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py +++ b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py @@ -62,7 +62,6 @@ def set_segment_text(self, text: str): self.task_changed.emit() -# TODO: Fix player duration and add spacer below class TranscriptionViewerWidget(QWidget): transcription_task: FileTranscriptionTask task_changed = pyqtSignal() @@ -114,6 +113,7 @@ def __init__( self.current_segment_label = QLabel() self.current_segment_label.setText("") self.current_segment_label.setAlignment(Qt.AlignmentFlag.AlignHCenter) + self.current_segment_label.setContentsMargins(0, 0, 0, 10) buttons_layout = QHBoxLayout() buttons_layout.addStretch() diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py index 24e4bc48e..8cf9dade0 100644 --- a/tests/transcriber_test.py +++ b/tests/transcriber_test.py @@ -2,6 +2,7 @@ import os import pathlib import platform +import shutil import tempfile import time from typing import List @@ -21,7 +22,7 @@ WhisperCpp, WhisperCppFileTranscriber, WhisperFileTranscriber, - get_default_output_file_path, + get_output_file_path, to_timestamp, whisper_cpp_params, write_output, @@ -159,24 +160,34 @@ def test_transcribe( class TestWhisperFileTranscriber: @pytest.mark.parametrize( - "output_format,expected_file_path,default_output_file_name", + "file_path,output_format,expected_file_path,default_output_file_name", [ - ( + pytest.param( + "/a/b/c.mp4", OutputFormat.SRT, "/a/b/c-translate--Whisper-tiny.srt", "{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}", + marks=pytest.mark.skipif(platform.system() == "Windows", reason=""), + ), + pytest.param( + "C:\\a\\b\\c.mp4", + OutputFormat.SRT, + "C:\\a\\b\\c-translate--Whisper-tiny.srt", + "{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}", + marks=pytest.mark.skipif(platform.system() != "Windows", reason=""), ), ], ) - def test_default_output_file2( + def test_default_output_file( self, + file_path: str, output_format: OutputFormat, expected_file_path: str, default_output_file_name: str, ): - file_path = get_default_output_file_path( + file_path = get_output_file_path( task=FileTranscriptionTask( - file_path="/a/b/c.mp4", + file_path=file_path, transcription_options=TranscriptionOptions(task=Task.TRANSLATE), file_transcription_options=FileTranscriptionOptions( file_paths=[], default_output_file_name=default_output_file_name @@ -187,10 +198,27 @@ def test_default_output_file2( ) assert file_path == expected_file_path - def test_default_output_file(self): - srt = get_default_output_file_path( + @pytest.mark.parametrize( + "file_path,expected_starts_with", + [ + pytest.param( + "/a/b/c.mp4", + "/a/b/c (Translated on ", + marks=pytest.mark.skipif(platform.system() == "Windows", reason=""), + ), + pytest.param( + "C:\\a\\b\\c.mp4", + "C:\\a\\b\\c (Translated on ", + marks=pytest.mark.skipif(platform.system() != "Windows", reason=""), + ), + ], + ) + def test_default_output_file_with_date( + self, file_path: str, expected_starts_with: str + ): + srt = get_output_file_path( task=FileTranscriptionTask( - file_path="/a/b/c.mp4", + file_path=file_path, transcription_options=TranscriptionOptions(task=Task.TRANSLATE), file_transcription_options=FileTranscriptionOptions( file_paths=[], @@ -200,12 +228,13 @@ def test_default_output_file(self): ), output_format=OutputFormat.TXT, ) - assert srt.startswith("/a/b/c (Translated on ") + + assert srt.startswith(expected_starts_with) assert srt.endswith(".txt") - srt = get_default_output_file_path( + srt = get_output_file_path( task=FileTranscriptionTask( - file_path="/a/b/c.mp4", + file_path=file_path, transcription_options=TranscriptionOptions(task=Task.TRANSLATE), file_transcription_options=FileTranscriptionOptions( file_paths=[], @@ -215,7 +244,7 @@ def test_default_output_file(self): ), output_format=OutputFormat.SRT, ) - assert srt.startswith("/a/b/c (Translated on ") + assert srt.startswith(expected_starts_with) assert srt.endswith(".srt") @pytest.mark.parametrize( @@ -327,6 +356,43 @@ def test_transcribe( assert len(segments[i].text) > 0 logging.debug(f"{segments[i].start} {segments[i].end} {segments[i].text}") + def test_transcribe_from_folder_watch_source(self, qtbot): + file_path = tempfile.mktemp(suffix=".mp3") + shutil.copy("testdata/whisper-french.mp3", file_path) + + file_transcription_options = FileTranscriptionOptions( + file_paths=[file_path], + output_formats={OutputFormat.TXT}, + default_output_file_name="{{ input_file_name }}", + ) + transcription_options = TranscriptionOptions() + model_path = get_model_path(transcription_options.model) + + output_directory = tempfile.mkdtemp() + transcriber = WhisperFileTranscriber( + task=FileTranscriptionTask( + model_path=model_path, + transcription_options=transcription_options, + file_transcription_options=file_transcription_options, + file_path=file_path, + output_directory=output_directory, + source=FileTranscriptionTask.Source.FOLDER_WATCH, + ) + ) + with qtbot.wait_signal(transcriber.completed, timeout=10 * 6000): + transcriber.run() + + assert not os.path.isfile(file_path) + assert os.path.isfile( + os.path.join(output_directory, os.path.basename(file_path)) + ) + assert os.path.isfile( + os.path.join( + output_directory, + os.path.splitext(os.path.basename(file_path))[0] + ".txt", + ) + ) + @pytest.mark.skip() def test_transcribe_stop(self): output_file_path = os.path.join(tempfile.gettempdir(), "whisper.txt") diff --git a/tests/widgets/menu_bar_test.py b/tests/widgets/menu_bar_test.py new file mode 100644 index 000000000..27b9bf079 --- /dev/null +++ b/tests/widgets/menu_bar_test.py @@ -0,0 +1,25 @@ +from PyQt6.QtCore import QSettings + +from buzz.settings.settings import Settings +from buzz.settings.shortcut_settings import ShortcutSettings +from buzz.widgets.menu_bar import MenuBar +from buzz.widgets.preferences_dialog.models.preferences import Preferences +from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog + + +class TestMenuBar: + def test_open_preferences_dialog(self, qtbot): + menu_bar = MenuBar( + shortcuts=ShortcutSettings(Settings()).load(), + default_export_file_name="", + preferences=Preferences.load(QSettings()), + ) + qtbot.add_widget(menu_bar) + + preferences_dialog = menu_bar.findChild(PreferencesDialog) + assert preferences_dialog is None + + menu_bar.preferences_action.trigger() + + preferences_dialog = menu_bar.findChild(PreferencesDialog) + assert isinstance(preferences_dialog, PreferencesDialog) diff --git a/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py b/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py new file mode 100644 index 000000000..d19d43fb4 --- /dev/null +++ b/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py @@ -0,0 +1,55 @@ +from unittest.mock import Mock + +from PyQt6.QtWidgets import QCheckBox, QLineEdit + +from buzz.model_loader import TranscriptionModel +from buzz.transcriber import Task, DEFAULT_WHISPER_TEMPERATURE +from buzz.widgets.preferences_dialog.folder_watch_preferences_widget import ( + FolderWatchPreferencesWidget, +) +from buzz.widgets.preferences_dialog.models.file_transcription_preferences import ( + FileTranscriptionPreferences, +) +from buzz.widgets.preferences_dialog.models.folder_watch_preferences import ( + FolderWatchPreferences, +) + + +class TestFolderWatchPreferencesWidget: + def test_edit_folder_watch_preferences(self, qtbot): + widget = FolderWatchPreferencesWidget( + config=FolderWatchPreferences( + enabled=False, + input_directory="", + output_directory="", + file_transcription_options=FileTranscriptionPreferences( + language=None, + task=Task.TRANSCRIBE, + model=TranscriptionModel(), + word_level_timings=False, + temperature=DEFAULT_WHISPER_TEMPERATURE, + initial_prompt="", + output_formats=set(), + ), + ), + ) + mock_config_changed = Mock() + widget.config_changed.connect(mock_config_changed) + qtbot.add_widget(widget) + + checkbox = widget.findChild(QCheckBox, "EnableFolderWatchCheckbox") + input_folder_line_edit = widget.findChild(QLineEdit, "InputFolderLineEdit") + output_folder_line_edit = widget.findChild(QLineEdit, "OutputFolderLineEdit") + + assert not checkbox.isChecked() + assert input_folder_line_edit.text() == "" + assert output_folder_line_edit.text() == "" + + checkbox.setChecked(True) + input_folder_line_edit.setText("test/input/folder") + output_folder_line_edit.setText("test/output/folder") + + last_config_changed_call = mock_config_changed.call_args_list[-1] + assert last_config_changed_call[0][0].enabled + assert last_config_changed_call[0][0].input_directory == "test/input/folder" + assert last_config_changed_call[0][0].output_directory == "test/output/folder" diff --git a/tests/widgets/preferences_dialog/preferences_dialog_test.py b/tests/widgets/preferences_dialog/preferences_dialog_test.py index 7c15402c4..c89e7fe00 100644 --- a/tests/widgets/preferences_dialog/preferences_dialog_test.py +++ b/tests/widgets/preferences_dialog/preferences_dialog_test.py @@ -1,19 +1,26 @@ +from PyQt6.QtCore import QSettings from PyQt6.QtWidgets import QTabWidget from pytestqt.qtbot import QtBot +from buzz.widgets.preferences_dialog.models.preferences import Preferences from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog class TestPreferencesDialog: def test_create(self, qtbot: QtBot): - dialog = PreferencesDialog(shortcuts={}, default_export_file_name="") + dialog = PreferencesDialog( + shortcuts={}, + default_export_file_name="", + preferences=Preferences.load(QSettings()), + ) qtbot.add_widget(dialog) assert dialog.windowTitle() == "Preferences" tab_widget = dialog.findChild(QTabWidget) assert isinstance(tab_widget, QTabWidget) - assert tab_widget.count() == 3 + assert tab_widget.count() == 4 assert tab_widget.tabText(0) == "General" assert tab_widget.tabText(1) == "Models" assert tab_widget.tabText(2) == "Shortcuts" + assert tab_widget.tabText(3) == "Folder Watch" diff --git a/tests/widgets/transcription_task_folder_watcher_test.py b/tests/widgets/transcription_task_folder_watcher_test.py new file mode 100644 index 000000000..faafdd7f3 --- /dev/null +++ b/tests/widgets/transcription_task_folder_watcher_test.py @@ -0,0 +1,105 @@ +import os +import shutil +from tempfile import mkdtemp + +from pytestqt.qtbot import QtBot + +from buzz.model_loader import TranscriptionModel +from buzz.transcriber import ( + Task, + DEFAULT_WHISPER_TEMPERATURE, + FileTranscriptionTask, + TranscriptionOptions, + FileTranscriptionOptions, +) +from buzz.widgets.preferences_dialog.models.file_transcription_preferences import ( + FileTranscriptionPreferences, +) +from buzz.widgets.preferences_dialog.models.folder_watch_preferences import ( + FolderWatchPreferences, +) +from buzz.widgets.transcription_task_folder_watcher import ( + TranscriptionTaskFolderWatcher, +) + + +class TestTranscriptionTaskFolderWatcher: + def test_should_add_task_not_in_tasks(self, qtbot: QtBot): + input_directory = mkdtemp() + watcher = TranscriptionTaskFolderWatcher( + tasks={}, + preferences=FolderWatchPreferences( + enabled=True, + input_directory=input_directory, + output_directory="/path/to/output/folder", + file_transcription_options=FileTranscriptionPreferences( + language=None, + task=Task.TRANSCRIBE, + model=TranscriptionModel(), + word_level_timings=False, + temperature=DEFAULT_WHISPER_TEMPERATURE, + initial_prompt="", + output_formats=set(), + ), + ), + default_export_file_name="", + ) + + shutil.copy("testdata/whisper-french.mp3", input_directory) + + with qtbot.wait_signal(watcher.task_found, timeout=10_000) as blocker: + pass + + task: FileTranscriptionTask = blocker.args[0] + assert task.file_path == os.path.join(input_directory, "whisper-french.mp3") + assert task.source == FileTranscriptionTask.Source.FOLDER_WATCH + assert task.output_directory == "/path/to/output/folder" + + def test_should_not_add_task_in_tasks(self, qtbot): + input_directory = mkdtemp() + tasks = { + 1: FileTranscriptionTask( + file_path=os.path.join(input_directory, "whisper-french.mp3"), + transcription_options=TranscriptionOptions(), + file_transcription_options=FileTranscriptionOptions(file_paths=[]), + output_directory="/path/to/output/folder", + model_path="", + ), + } + + watcher = TranscriptionTaskFolderWatcher( + tasks=tasks, + preferences=FolderWatchPreferences( + enabled=True, + input_directory=input_directory, + output_directory="/path/to/output/folder", + file_transcription_options=FileTranscriptionPreferences( + language=None, + task=Task.TRANSCRIBE, + model=TranscriptionModel(), + word_level_timings=False, + temperature=DEFAULT_WHISPER_TEMPERATURE, + initial_prompt="", + output_formats=set(), + ), + ), + default_export_file_name="", + ) + + # Ignored because already in tasks + shutil.copy( + "testdata/whisper-french.mp3", + os.path.join(input_directory, "whisper-french.mp3"), + ) + shutil.copy( + "testdata/whisper-french.mp3", + os.path.join(input_directory, "whisper-french2.mp3"), + ) + + with qtbot.wait_signal(watcher.task_found, timeout=10_000) as blocker: + pass + + task: FileTranscriptionTask = blocker.args[0] + assert task.file_path == os.path.join(input_directory, "whisper-french2.mp3") + assert task.source == FileTranscriptionTask.Source.FOLDER_WATCH + assert task.output_directory == "/path/to/output/folder"