From 730baded0d7c230d58f413eecc1ff059132ec699 Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Thu, 21 Dec 2023 00:36:05 +0000 Subject: [PATCH 1/3] feat: update transcription tasks columns --- Makefile | 2 +- buzz/model_loader.py | 22 ++- buzz/settings/settings.py | 10 + buzz/transcriber.py | 37 ++++ .../transcription_tasks_table_widget.py | 178 +++++++++++++----- 5 files changed, 195 insertions(+), 54 deletions(-) diff --git a/Makefile b/Makefile index ed4acd333..cc69c51a8 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ else endif clean: - rm -f $(LIBWHISPER) + rm -f buzz/$(LIBWHISPER) rm -f buzz/whisper_cpp.py rm -rf dist/* || true diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 6f7b47226..129082030 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -31,6 +31,9 @@ def to_faster_whisper_model_size(self) -> str: return "large-v2" return self.value + def __str__(self): + return self.value.capitalize() + class ModelType(enum.Enum): WHISPER = "Whisper" @@ -46,6 +49,21 @@ class TranscriptionModel: whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY hugging_face_model_id: Optional[str] = None + def __str__(self): + match self.model_type: + case ModelType.WHISPER: + return f"Whisper ({self.whisper_model_size})" + case ModelType.WHISPER_CPP: + return f"Whisper.cpp ({self.whisper_model_size})" + case ModelType.HUGGING_FACE: + return f"Hugging Face ({self.hugging_face_model_id})" + case ModelType.FASTER_WHISPER: + return f"Faster Whisper ({self.whisper_model_size})" + case ModelType.OPEN_AI_WHISPER_API: + return "OpenAI Whisper API" + case _: + raise Exception("Unknown model type") + def is_deletable(self): return ( self.model_type == ModelType.WHISPER @@ -110,12 +128,12 @@ def get_local_model_path(self) -> Optional[str]: "base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe", "small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b", "medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208", - "large": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487", + "large": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2", } def get_hugging_face_file_url(author: str, repository_name: str, filename: str): - return f"https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}" + return f"https://huggingface.co/{author}/{repository_name}/resolve/bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c/{filename}" def get_whisper_cpp_file_path(size: WhisperModelSize) -> str: diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index 4ee430aea..90fc83b16 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -29,6 +29,10 @@ class Key(enum.Enum): SHORTCUTS = "shortcuts" + TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY = ( + "transcription-tasks-table/column-visibility" + ) + def set_value(self, key: Key, value: typing.Any) -> None: self.settings.setValue(key.value, value) @@ -46,3 +50,9 @@ def value( def clear(self): self.settings.clear() + + def begin_group(self, group: Key): + self.settings.beginGroup(group.value) + + def end_group(self): + self.settings.endGroup() diff --git a/buzz/transcriber.py b/buzz/transcriber.py index a87a86e7c..2902e00ae 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -27,6 +27,7 @@ from . import transformers_whisper from .conn import pipe_stderr +from .locale import _ from .model_loader import TranscriptionModel, ModelType # Catch exception from whisper.dll not getting loaded. @@ -71,6 +72,12 @@ class TranscriptionOptions: ) +def humanize_language(language: str) -> str: + if language == "": + return _("Detect Language") + return LANGUAGES[language].title() + + @dataclass() class FileTranscriptionOptions: file_paths: List[str] @@ -101,6 +108,36 @@ class Status(enum.Enum): started_at: Optional[datetime.datetime] = None completed_at: Optional[datetime.datetime] = None + def status_text(self) -> str: + if self.status == FileTranscriptionTask.Status.IN_PROGRESS: + return f'{_("In Progress")} ({self.fraction_completed :.0%})' + elif self.status == FileTranscriptionTask.Status.COMPLETED: + status = _("Completed") + if self.started_at is not None and self.completed_at is not None: + status += ( + f" ({self.format_timedelta(self.completed_at - self.started_at)})" + ) + return status + elif self.status == FileTranscriptionTask.Status.FAILED: + return f'{_("Failed")} ({self.error})' + elif self.status == FileTranscriptionTask.Status.CANCELED: + return _("Canceled") + elif self.status == FileTranscriptionTask.Status.QUEUED: + return _("Queued") + return "" + + @staticmethod + def format_timedelta(delta: datetime.timedelta): + mm, ss = divmod(delta.seconds, 60) + result = f"{ss}s" + if mm == 0: + return result + hh, mm = divmod(mm, 60) + result = f"{mm}m {result}" + if hh == 0: + return result + return f"{hh}h {result}" + class OutputFormat(enum.Enum): TXT = "txt" diff --git a/buzz/widgets/transcription_tasks_table_widget.py b/buzz/widgets/transcription_tasks_table_widget.py index cd79fcd3f..41a221562 100644 --- a/buzz/widgets/transcription_tasks_table_widget.py +++ b/buzz/widgets/transcription_tasks_table_widget.py @@ -1,21 +1,41 @@ -import datetime import enum import os +from dataclasses import dataclass from enum import auto -from typing import Optional +from typing import Optional, Callable from PyQt6 import QtGui from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex -from PyQt6.QtWidgets import QTableWidget, QWidget, QAbstractItemView, QTableWidgetItem +from PyQt6.QtWidgets import ( + QTableWidget, + QWidget, + QAbstractItemView, + QTableWidgetItem, + QMenu, +) from buzz.locale import _ -from buzz.transcriber import FileTranscriptionTask +from buzz.settings.settings import Settings +from buzz.transcriber import FileTranscriptionTask, humanize_language + + +@dataclass +class TableColDef: + id: str + header: str + column_index: int + value_getter: Callable[..., str] + width: Optional[int] = None + hidden: bool = False + hidden_toggleable: bool = True class TranscriptionTasksTableWidget(QTableWidget): class Column(enum.Enum): TASK_ID = 0 FILE_NAME = auto() + MODEL = auto() + TASK = auto() STATUS = auto() return_clicked = pyqtSignal() @@ -25,69 +45,125 @@ def __init__(self, parent: Optional[QWidget] = None): self.setRowCount(0) self.setAlternatingRowColors(True) + self.settings = Settings() + + self.column_definitions = [ + TableColDef( + id="id", + header=_("ID"), + column_index=self.Column.TASK_ID.value, + value_getter=lambda task: str(task.id), + width=0, + hidden=True, + hidden_toggleable=False, + ), + TableColDef( + id="file_name", + header=_("File Name"), + column_index=self.Column.FILE_NAME.value, + value_getter=lambda task: os.path.basename(task.file_path), + width=250, + hidden_toggleable=False, + ), + TableColDef( + id="model", + header=_("Model"), + column_index=self.Column.MODEL.value, + value_getter=lambda task: str(task.transcription_options.model), + width=180, + hidden=True, + ), + TableColDef( + id="task", + header=_("Task"), + column_index=self.Column.TASK.value, + value_getter=lambda task: self.get_task_label(task), + width=180, + hidden=True, + ), + TableColDef( + id="status", + header=_("Status"), + column_index=self.Column.STATUS.value, + value_getter=lambda task: task.status_text(), + width=180, + hidden_toggleable=False, + ), + ] - self.setColumnCount(3) - self.setColumnHidden(0, True) - + self.setColumnCount(len(self.column_definitions)) self.verticalHeader().hide() - self.setHorizontalHeaderLabels([_("ID"), _("File Name"), _("Status")]) - self.setColumnWidth(self.Column.FILE_NAME.value, 250) - self.setColumnWidth(self.Column.STATUS.value, 180) + self.setHorizontalHeaderLabels( + [definition.header for definition in self.column_definitions] + ) + for definition in self.column_definitions: + if definition.width is not None: + self.setColumnWidth(definition.column_index, definition.width) + self.load_column_visibility() + self.horizontalHeader().setMinimumSectionSize(180) self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + def contextMenuEvent(self, event): + menu = QMenu(self) + for definition in self.column_definitions: + if not definition.hidden_toggleable: + continue + action = menu.addAction(definition.header) + action.setCheckable(True) + action.setChecked(not self.isColumnHidden(definition.column_index)) + action.toggled.connect( + lambda checked, + column_index=definition.column_index: self.on_column_checked( + column_index, checked + ) + ) + menu.exec(event.globalPos()) + + def on_column_checked(self, column_index: int, checked: bool): + self.setColumnHidden(column_index, not checked) + self.save_column_visibility() + + def save_column_visibility(self): + self.settings.begin_group( + Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY + ) + for definition in self.column_definitions: + self.settings.settings.setValue( + definition.id, not self.isColumnHidden(definition.column_index) + ) + self.settings.end_group() + + def load_column_visibility(self): + self.settings.begin_group( + Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY + ) + for definition in self.column_definitions: + visible = self.settings.settings.value(definition.id, not definition.hidden) + self.setColumnHidden(definition.column_index, not visible) + self.settings.end_group() + def upsert_task(self, task: FileTranscriptionTask): task_row_index = self.task_row_index(task.id) if task_row_index is None: self.insertRow(self.rowCount()) row_index = self.rowCount() - 1 - task_id_widget_item = QTableWidgetItem(str(task.id)) - self.setItem(row_index, self.Column.TASK_ID.value, task_id_widget_item) - - file_name_widget_item = QTableWidgetItem(os.path.basename(task.file_path)) - file_name_widget_item.setFlags( - file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable - ) - self.setItem(row_index, self.Column.FILE_NAME.value, file_name_widget_item) - - status_widget_item = QTableWidgetItem(self.get_status_text(task)) - status_widget_item.setFlags( - status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable - ) - self.setItem(row_index, self.Column.STATUS.value, status_widget_item) + for definition in self.column_definitions: + item = QTableWidgetItem(definition.value_getter(task)) + item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.setItem(row_index, definition.column_index, item) else: status_widget = self.item(task_row_index, self.Column.STATUS.value) - status_widget.setText(self.get_status_text(task)) - - @staticmethod - def format_timedelta(delta: datetime.timedelta): - mm, ss = divmod(delta.seconds, 60) - result = f"{ss}s" - if mm == 0: - return result - hh, mm = divmod(mm, 60) - result = f"{mm}m {result}" - if hh == 0: - return result - return f"{hh}h {result}" + status_widget.setText(task.status_text()) @staticmethod - def get_status_text(task: FileTranscriptionTask): - if task.status == FileTranscriptionTask.Status.IN_PROGRESS: - return f'{_("In Progress")} ({task.fraction_completed :.0%})' - elif task.status == FileTranscriptionTask.Status.COMPLETED: - status = _("Completed") - if task.started_at is not None and task.completed_at is not None: - status += f" ({TranscriptionTasksTableWidget.format_timedelta(task.completed_at - task.started_at)})" - return status - elif task.status == FileTranscriptionTask.Status.FAILED: - return f'{_("Failed")} ({task.error})' - elif task.status == FileTranscriptionTask.Status.CANCELED: - return _("Canceled") - elif task.status == FileTranscriptionTask.Status.QUEUED: - return _("Queued") + def get_task_label(task: FileTranscriptionTask) -> str: + value = task.transcription_options.task.value.capitalize() + if task.transcription_options.language is not None: + value += f" ({humanize_language(task.transcription_options.language)})" + return value def clear_task(self, task_id: int): task_row_index = self.task_row_index(task_id) From 7a2295e6931bd7cbfd96b336d45bd696d771f8ec Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Thu, 21 Dec 2023 01:37:44 +0000 Subject: [PATCH 2/3] feat: update transcription tasks columns --- buzz/settings/settings.py | 4 ++ tests/widgets/main_window_test.py | 37 +++++++++------- .../transcription_tasks_table_widget_test.py | 44 +++++++++++++++---- 3 files changed, 62 insertions(+), 23 deletions(-) diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index 90fc83b16..f5e74a620 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -9,6 +9,7 @@ class Settings: def __init__(self): self.settings = QSettings(APP_NAME) + self.settings.sync() class Key(enum.Enum): RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task" @@ -56,3 +57,6 @@ def begin_group(self, group: Key): def end_group(self): self.settings.endGroup() + + def sync(self): + self.settings.sync() diff --git a/tests/widgets/main_window_test.py b/tests/widgets/main_window_test.py index 6a3fa9666..f9caf117b 100644 --- a/tests/widgets/main_window_test.py +++ b/tests/widgets/main_window_test.py @@ -10,6 +10,7 @@ from pytestqt.qtbot import QtBot from buzz.cache import TasksCache +from buzz.settings.settings import Settings from buzz.transcriber import ( FileTranscriptionTask, TranscriptionOptions, @@ -69,7 +70,7 @@ def test_should_run_transcription_task(self, qtbot: QtBot, tasks_cache): table_widget: QTableWidget = window.findChild(QTableWidget) qtbot.wait_until( - self._assert_task_status(table_widget, 0, "Completed"), + self.get_assert_task_status_callback(table_widget, 0, "Completed"), timeout=2 * 60 * 1000, ) @@ -88,19 +89,18 @@ def test_should_run_and_cancel_transcription_task(self, qtbot, tasks_cache): table_widget: QTableWidget = window.findChild(QTableWidget) - def assert_task_in_progress(): - assert table_widget.rowCount() > 0 - assert table_widget.item(0, 1).text() == "whisper-french.mp3" - assert "In Progress" in table_widget.item(0, 2).text() - - qtbot.wait_until(assert_task_in_progress, timeout=2 * 60 * 1000) + qtbot.wait_until( + self.get_assert_task_status_callback(table_widget, 0, "In Progress"), + timeout=2 * 60 * 1000, + ) # Stop task in progress table_widget.selectRow(0) window.toolbar.stop_transcription_action.trigger() qtbot.wait_until( - self._assert_task_status(table_widget, 0, "Canceled"), timeout=60 * 1000 + self.get_assert_task_status_callback(table_widget, 0, "Canceled"), + timeout=60 * 1000, ) table_widget.selectRow(0) @@ -117,15 +117,15 @@ def test_should_load_tasks_from_cache(self, qtbot, tasks_cache): table_widget: QTableWidget = window.findChild(QTableWidget) assert table_widget.rowCount() == 3 - assert table_widget.item(0, 2).text() == "Completed" + assert table_widget.item(0, 4).text() == "Completed" table_widget.selectRow(0) assert window.toolbar.open_transcript_action.isEnabled() - assert table_widget.item(1, 2).text() == "Canceled" + assert table_widget.item(1, 4).text() == "Canceled" table_widget.selectRow(1) assert window.toolbar.open_transcript_action.isEnabled() is False - assert table_widget.item(2, 2).text() == "Failed (Error)" + assert table_widget.item(2, 4).text() == "Failed (Error)" table_widget.selectRow(2) assert window.toolbar.open_transcript_action.isEnabled() is False window.close() @@ -226,15 +226,15 @@ def _start_new_transcription(window: MainWindow): run_button.click() @staticmethod - def _assert_task_status( + def get_assert_task_status_callback( table_widget: QTableWidget, row_index: int, expected_status: str ): - def assert_task_canceled(): + def assert_task_status(): assert table_widget.rowCount() > 0 assert table_widget.item(row_index, 1).text() == "whisper-french.mp3" - assert expected_status in table_widget.item(row_index, 2).text() + assert expected_status in table_widget.item(row_index, 4).text() - return assert_task_canceled + return assert_task_status @staticmethod def _get_toolbar_action(window: MainWindow, text: str): @@ -250,3 +250,10 @@ def tasks_cache(tmp_path, request: SubRequest): cache.save(tasks) yield cache cache.clear() + + +@pytest.fixture(autouse=True) +def reset_settings(): + settings = Settings() + settings.clear() + settings.sync() diff --git a/tests/widgets/transcription_tasks_table_widget_test.py b/tests/widgets/transcription_tasks_table_widget_test.py index abc449c13..865d02ec7 100644 --- a/tests/widgets/transcription_tasks_table_widget_test.py +++ b/tests/widgets/transcription_tasks_table_widget_test.py @@ -31,24 +31,37 @@ def test_upsert_task(self, qtbot: QtBot): widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == "whisper-french.mp3" - assert widget.item(0, 2).text() == "Queued" + self.assert_row_text( + widget, 0, "whisper-french.mp3", "Whisper (Tiny)", "Transcribe", "Queued" + ) task.status = FileTranscriptionTask.Status.IN_PROGRESS task.fraction_completed = 0.3524 widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == "whisper-french.mp3" - assert widget.item(0, 2).text() == "In Progress (35%)" + self.assert_row_text( + widget, + 0, + "whisper-french.mp3", + "Whisper (Tiny)", + "Transcribe", + "In Progress (35%)", + ) task.status = FileTranscriptionTask.Status.COMPLETED task.completed_at = datetime.datetime(2023, 4, 12, 0, 0, 10) widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == "whisper-french.mp3" - assert widget.item(0, 2).text() == "Completed (5s)" + self.assert_row_text( + widget, + 0, + "whisper-french.mp3", + "Whisper (Tiny)", + "Transcribe", + "Completed (5s)", + ) def test_upsert_task_no_timings(self, qtbot: QtBot): widget = TranscriptionTasksTableWidget() @@ -67,5 +80,20 @@ def test_upsert_task_no_timings(self, qtbot: QtBot): widget.upsert_task(task) assert widget.rowCount() == 1 - assert widget.item(0, 1).text() == "whisper-french.mp3" - assert widget.item(0, 2).text() == "Completed" + self.assert_row_text( + widget, 0, "whisper-french.mp3", "Whisper (Tiny)", "Transcribe", "Completed" + ) + + def assert_row_text( + self, + widget: TranscriptionTasksTableWidget, + row: int, + filename: str, + model: str, + task: str, + status: str, + ): + assert widget.item(row, 1).text() == filename + assert widget.item(row, 2).text() == model + assert widget.item(row, 3).text() == task + assert widget.item(row, 4).text() == status From 0f02bbb23d74e745b3bac7eb8db4bb44298098f9 Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Thu, 21 Dec 2023 19:46:00 +0000 Subject: [PATCH 3/3] feat: update transcription tasks columns --- tests/widgets/conftest.py | 11 +++++++++++ tests/widgets/main_window_test.py | 8 -------- .../transcription_tasks_table_widget_test.py | 16 ++++++++++++++++ 3 files changed, 27 insertions(+), 8 deletions(-) create mode 100644 tests/widgets/conftest.py diff --git a/tests/widgets/conftest.py b/tests/widgets/conftest.py new file mode 100644 index 000000000..e69cb0765 --- /dev/null +++ b/tests/widgets/conftest.py @@ -0,0 +1,11 @@ +import logging +import pytest +from buzz.settings.settings import Settings + + +@pytest.fixture(scope="package") +def reset_settings(): + settings = Settings() + settings.clear() + settings.sync() + logging.debug("Reset settings") diff --git a/tests/widgets/main_window_test.py b/tests/widgets/main_window_test.py index f9caf117b..190e901db 100644 --- a/tests/widgets/main_window_test.py +++ b/tests/widgets/main_window_test.py @@ -10,7 +10,6 @@ from pytestqt.qtbot import QtBot from buzz.cache import TasksCache -from buzz.settings.settings import Settings from buzz.transcriber import ( FileTranscriptionTask, TranscriptionOptions, @@ -250,10 +249,3 @@ def tasks_cache(tmp_path, request: SubRequest): cache.save(tasks) yield cache cache.clear() - - -@pytest.fixture(autouse=True) -def reset_settings(): - settings = Settings() - settings.clear() - settings.sync() diff --git a/tests/widgets/transcription_tasks_table_widget_test.py b/tests/widgets/transcription_tasks_table_widget_test.py index 865d02ec7..313565362 100644 --- a/tests/widgets/transcription_tasks_table_widget_test.py +++ b/tests/widgets/transcription_tasks_table_widget_test.py @@ -84,6 +84,22 @@ def test_upsert_task_no_timings(self, qtbot: QtBot): widget, 0, "whisper-french.mp3", "Whisper (Tiny)", "Transcribe", "Completed" ) + def test_toggle_column_visibility(self, qtbot, reset_settings): + widget = TranscriptionTasksTableWidget() + qtbot.add_widget(widget) + + assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.TASK_ID.value) + assert not widget.isColumnHidden( + TranscriptionTasksTableWidget.Column.FILE_NAME.value + ) + assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.MODEL.value) + assert widget.isColumnHidden(TranscriptionTasksTableWidget.Column.TASK.value) + assert not widget.isColumnHidden( + TranscriptionTasksTableWidget.Column.STATUS.value + ) + + # TODO: open context menu and toggle column visibility + def assert_row_text( self, widget: TranscriptionTasksTableWidget,