From 730baded0d7c230d58f413eecc1ff059132ec699 Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Thu, 21 Dec 2023 00:36:05 +0000 Subject: [PATCH] feat: update transcription tasks columns --- Makefile | 2 +- buzz/model_loader.py | 22 ++- buzz/settings/settings.py | 10 + buzz/transcriber.py | 37 ++++ .../transcription_tasks_table_widget.py | 178 +++++++++++++----- 5 files changed, 195 insertions(+), 54 deletions(-) diff --git a/Makefile b/Makefile index ed4acd333..cc69c51a8 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ else endif clean: - rm -f $(LIBWHISPER) + rm -f buzz/$(LIBWHISPER) rm -f buzz/whisper_cpp.py rm -rf dist/* || true diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 6f7b47226..129082030 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -31,6 +31,9 @@ def to_faster_whisper_model_size(self) -> str: return "large-v2" return self.value + def __str__(self): + return self.value.capitalize() + class ModelType(enum.Enum): WHISPER = "Whisper" @@ -46,6 +49,21 @@ class TranscriptionModel: whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY hugging_face_model_id: Optional[str] = None + def __str__(self): + match self.model_type: + case ModelType.WHISPER: + return f"Whisper ({self.whisper_model_size})" + case ModelType.WHISPER_CPP: + return f"Whisper.cpp ({self.whisper_model_size})" + case ModelType.HUGGING_FACE: + return f"Hugging Face ({self.hugging_face_model_id})" + case ModelType.FASTER_WHISPER: + return f"Faster Whisper ({self.whisper_model_size})" + case ModelType.OPEN_AI_WHISPER_API: + return "OpenAI Whisper API" + case _: + raise Exception("Unknown model type") + def is_deletable(self): return ( self.model_type == ModelType.WHISPER @@ -110,12 +128,12 @@ def get_local_model_path(self) -> Optional[str]: "base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe", "small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b", "medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208", - "large": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487", + "large": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2", } def get_hugging_face_file_url(author: str, repository_name: str, filename: str): - return f"https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}" + return f"https://huggingface.co/{author}/{repository_name}/resolve/bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c/{filename}" def get_whisper_cpp_file_path(size: WhisperModelSize) -> str: diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py index 4ee430aea..90fc83b16 100644 --- a/buzz/settings/settings.py +++ b/buzz/settings/settings.py @@ -29,6 +29,10 @@ class Key(enum.Enum): SHORTCUTS = "shortcuts" + TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY = ( + "transcription-tasks-table/column-visibility" + ) + def set_value(self, key: Key, value: typing.Any) -> None: self.settings.setValue(key.value, value) @@ -46,3 +50,9 @@ def value( def clear(self): self.settings.clear() + + def begin_group(self, group: Key): + self.settings.beginGroup(group.value) + + def end_group(self): + self.settings.endGroup() diff --git a/buzz/transcriber.py b/buzz/transcriber.py index a87a86e7c..2902e00ae 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -27,6 +27,7 @@ from . import transformers_whisper from .conn import pipe_stderr +from .locale import _ from .model_loader import TranscriptionModel, ModelType # Catch exception from whisper.dll not getting loaded. @@ -71,6 +72,12 @@ class TranscriptionOptions: ) +def humanize_language(language: str) -> str: + if language == "": + return _("Detect Language") + return LANGUAGES[language].title() + + @dataclass() class FileTranscriptionOptions: file_paths: List[str] @@ -101,6 +108,36 @@ class Status(enum.Enum): started_at: Optional[datetime.datetime] = None completed_at: Optional[datetime.datetime] = None + def status_text(self) -> str: + if self.status == FileTranscriptionTask.Status.IN_PROGRESS: + return f'{_("In Progress")} ({self.fraction_completed :.0%})' + elif self.status == FileTranscriptionTask.Status.COMPLETED: + status = _("Completed") + if self.started_at is not None and self.completed_at is not None: + status += ( + f" ({self.format_timedelta(self.completed_at - self.started_at)})" + ) + return status + elif self.status == FileTranscriptionTask.Status.FAILED: + return f'{_("Failed")} ({self.error})' + elif self.status == FileTranscriptionTask.Status.CANCELED: + return _("Canceled") + elif self.status == FileTranscriptionTask.Status.QUEUED: + return _("Queued") + return "" + + @staticmethod + def format_timedelta(delta: datetime.timedelta): + mm, ss = divmod(delta.seconds, 60) + result = f"{ss}s" + if mm == 0: + return result + hh, mm = divmod(mm, 60) + result = f"{mm}m {result}" + if hh == 0: + return result + return f"{hh}h {result}" + class OutputFormat(enum.Enum): TXT = "txt" diff --git a/buzz/widgets/transcription_tasks_table_widget.py b/buzz/widgets/transcription_tasks_table_widget.py index cd79fcd3f..41a221562 100644 --- a/buzz/widgets/transcription_tasks_table_widget.py +++ b/buzz/widgets/transcription_tasks_table_widget.py @@ -1,21 +1,41 @@ -import datetime import enum import os +from dataclasses import dataclass from enum import auto -from typing import Optional +from typing import Optional, Callable from PyQt6 import QtGui from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex -from PyQt6.QtWidgets import QTableWidget, QWidget, QAbstractItemView, QTableWidgetItem +from PyQt6.QtWidgets import ( + QTableWidget, + QWidget, + QAbstractItemView, + QTableWidgetItem, + QMenu, +) from buzz.locale import _ -from buzz.transcriber import FileTranscriptionTask +from buzz.settings.settings import Settings +from buzz.transcriber import FileTranscriptionTask, humanize_language + + +@dataclass +class TableColDef: + id: str + header: str + column_index: int + value_getter: Callable[..., str] + width: Optional[int] = None + hidden: bool = False + hidden_toggleable: bool = True class TranscriptionTasksTableWidget(QTableWidget): class Column(enum.Enum): TASK_ID = 0 FILE_NAME = auto() + MODEL = auto() + TASK = auto() STATUS = auto() return_clicked = pyqtSignal() @@ -25,69 +45,125 @@ def __init__(self, parent: Optional[QWidget] = None): self.setRowCount(0) self.setAlternatingRowColors(True) + self.settings = Settings() + + self.column_definitions = [ + TableColDef( + id="id", + header=_("ID"), + column_index=self.Column.TASK_ID.value, + value_getter=lambda task: str(task.id), + width=0, + hidden=True, + hidden_toggleable=False, + ), + TableColDef( + id="file_name", + header=_("File Name"), + column_index=self.Column.FILE_NAME.value, + value_getter=lambda task: os.path.basename(task.file_path), + width=250, + hidden_toggleable=False, + ), + TableColDef( + id="model", + header=_("Model"), + column_index=self.Column.MODEL.value, + value_getter=lambda task: str(task.transcription_options.model), + width=180, + hidden=True, + ), + TableColDef( + id="task", + header=_("Task"), + column_index=self.Column.TASK.value, + value_getter=lambda task: self.get_task_label(task), + width=180, + hidden=True, + ), + TableColDef( + id="status", + header=_("Status"), + column_index=self.Column.STATUS.value, + value_getter=lambda task: task.status_text(), + width=180, + hidden_toggleable=False, + ), + ] - self.setColumnCount(3) - self.setColumnHidden(0, True) - + self.setColumnCount(len(self.column_definitions)) self.verticalHeader().hide() - self.setHorizontalHeaderLabels([_("ID"), _("File Name"), _("Status")]) - self.setColumnWidth(self.Column.FILE_NAME.value, 250) - self.setColumnWidth(self.Column.STATUS.value, 180) + self.setHorizontalHeaderLabels( + [definition.header for definition in self.column_definitions] + ) + for definition in self.column_definitions: + if definition.width is not None: + self.setColumnWidth(definition.column_index, definition.width) + self.load_column_visibility() + self.horizontalHeader().setMinimumSectionSize(180) self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) + def contextMenuEvent(self, event): + menu = QMenu(self) + for definition in self.column_definitions: + if not definition.hidden_toggleable: + continue + action = menu.addAction(definition.header) + action.setCheckable(True) + action.setChecked(not self.isColumnHidden(definition.column_index)) + action.toggled.connect( + lambda checked, + column_index=definition.column_index: self.on_column_checked( + column_index, checked + ) + ) + menu.exec(event.globalPos()) + + def on_column_checked(self, column_index: int, checked: bool): + self.setColumnHidden(column_index, not checked) + self.save_column_visibility() + + def save_column_visibility(self): + self.settings.begin_group( + Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY + ) + for definition in self.column_definitions: + self.settings.settings.setValue( + definition.id, not self.isColumnHidden(definition.column_index) + ) + self.settings.end_group() + + def load_column_visibility(self): + self.settings.begin_group( + Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY + ) + for definition in self.column_definitions: + visible = self.settings.settings.value(definition.id, not definition.hidden) + self.setColumnHidden(definition.column_index, not visible) + self.settings.end_group() + def upsert_task(self, task: FileTranscriptionTask): task_row_index = self.task_row_index(task.id) if task_row_index is None: self.insertRow(self.rowCount()) row_index = self.rowCount() - 1 - task_id_widget_item = QTableWidgetItem(str(task.id)) - self.setItem(row_index, self.Column.TASK_ID.value, task_id_widget_item) - - file_name_widget_item = QTableWidgetItem(os.path.basename(task.file_path)) - file_name_widget_item.setFlags( - file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable - ) - self.setItem(row_index, self.Column.FILE_NAME.value, file_name_widget_item) - - status_widget_item = QTableWidgetItem(self.get_status_text(task)) - status_widget_item.setFlags( - status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable - ) - self.setItem(row_index, self.Column.STATUS.value, status_widget_item) + for definition in self.column_definitions: + item = QTableWidgetItem(definition.value_getter(task)) + item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.setItem(row_index, definition.column_index, item) else: status_widget = self.item(task_row_index, self.Column.STATUS.value) - status_widget.setText(self.get_status_text(task)) - - @staticmethod - def format_timedelta(delta: datetime.timedelta): - mm, ss = divmod(delta.seconds, 60) - result = f"{ss}s" - if mm == 0: - return result - hh, mm = divmod(mm, 60) - result = f"{mm}m {result}" - if hh == 0: - return result - return f"{hh}h {result}" + status_widget.setText(task.status_text()) @staticmethod - def get_status_text(task: FileTranscriptionTask): - if task.status == FileTranscriptionTask.Status.IN_PROGRESS: - return f'{_("In Progress")} ({task.fraction_completed :.0%})' - elif task.status == FileTranscriptionTask.Status.COMPLETED: - status = _("Completed") - if task.started_at is not None and task.completed_at is not None: - status += f" ({TranscriptionTasksTableWidget.format_timedelta(task.completed_at - task.started_at)})" - return status - elif task.status == FileTranscriptionTask.Status.FAILED: - return f'{_("Failed")} ({task.error})' - elif task.status == FileTranscriptionTask.Status.CANCELED: - return _("Canceled") - elif task.status == FileTranscriptionTask.Status.QUEUED: - return _("Queued") + def get_task_label(task: FileTranscriptionTask) -> str: + value = task.transcription_options.task.value.capitalize() + if task.transcription_options.language is not None: + value += f" ({humanize_language(task.transcription_options.language)})" + return value def clear_task(self, task_id: int): task_row_index = self.task_row_index(task_id)