Skip to content

Commit

Permalink
feat: update transcription tasks columns (#649)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Dec 21, 2023
1 parent 6820797 commit f163aab
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ else
endif

clean:
rm -f $(LIBWHISPER)
rm -f buzz/$(LIBWHISPER)
rm -f buzz/whisper_cpp.py
rm -rf dist/* || true

Expand Down
22 changes: 20 additions & 2 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def to_faster_whisper_model_size(self) -> str:
return "large-v2"
return self.value

def __str__(self):
return self.value.capitalize()


class ModelType(enum.Enum):
WHISPER = "Whisper"
Expand All @@ -46,6 +49,21 @@ class TranscriptionModel:
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
hugging_face_model_id: Optional[str] = None

def __str__(self):
match self.model_type:
case ModelType.WHISPER:
return f"Whisper ({self.whisper_model_size})"
case ModelType.WHISPER_CPP:
return f"Whisper.cpp ({self.whisper_model_size})"
case ModelType.HUGGING_FACE:
return f"Hugging Face ({self.hugging_face_model_id})"
case ModelType.FASTER_WHISPER:
return f"Faster Whisper ({self.whisper_model_size})"
case ModelType.OPEN_AI_WHISPER_API:
return "OpenAI Whisper API"
case _:
raise Exception("Unknown model type")

def is_deletable(self):
return (
self.model_type == ModelType.WHISPER
Expand Down Expand Up @@ -110,12 +128,12 @@ def get_local_model_path(self) -> Optional[str]:
"base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe",
"small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b",
"medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208",
"large": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
"large": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2",
}


def get_hugging_face_file_url(author: str, repository_name: str, filename: str):
return f"https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}"
return f"https://huggingface.co/{author}/{repository_name}/resolve/bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c/{filename}"


def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
Expand Down
14 changes: 14 additions & 0 deletions buzz/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class Settings:
def __init__(self):
self.settings = QSettings(APP_NAME)
self.settings.sync()

class Key(enum.Enum):
RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task"
Expand All @@ -29,6 +30,10 @@ class Key(enum.Enum):

SHORTCUTS = "shortcuts"

TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY = (
"transcription-tasks-table/column-visibility"
)

def set_value(self, key: Key, value: typing.Any) -> None:
self.settings.setValue(key.value, value)

Expand All @@ -46,3 +51,12 @@ def value(

def clear(self):
self.settings.clear()

def begin_group(self, group: Key):
self.settings.beginGroup(group.value)

def end_group(self):
self.settings.endGroup()

def sync(self):
self.settings.sync()
37 changes: 37 additions & 0 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from . import transformers_whisper
from .conn import pipe_stderr
from .locale import _
from .model_loader import TranscriptionModel, ModelType

# Catch exception from whisper.dll not getting loaded.
Expand Down Expand Up @@ -71,6 +72,12 @@ class TranscriptionOptions:
)


def humanize_language(language: str) -> str:
if language == "":
return _("Detect Language")
return LANGUAGES[language].title()


@dataclass()
class FileTranscriptionOptions:
file_paths: List[str]
Expand Down Expand Up @@ -101,6 +108,36 @@ class Status(enum.Enum):
started_at: Optional[datetime.datetime] = None
completed_at: Optional[datetime.datetime] = None

def status_text(self) -> str:
if self.status == FileTranscriptionTask.Status.IN_PROGRESS:
return f'{_("In Progress")} ({self.fraction_completed :.0%})'
elif self.status == FileTranscriptionTask.Status.COMPLETED:
status = _("Completed")
if self.started_at is not None and self.completed_at is not None:
status += (
f" ({self.format_timedelta(self.completed_at - self.started_at)})"
)
return status
elif self.status == FileTranscriptionTask.Status.FAILED:
return f'{_("Failed")} ({self.error})'
elif self.status == FileTranscriptionTask.Status.CANCELED:
return _("Canceled")
elif self.status == FileTranscriptionTask.Status.QUEUED:
return _("Queued")
return ""

@staticmethod
def format_timedelta(delta: datetime.timedelta):
mm, ss = divmod(delta.seconds, 60)
result = f"{ss}s"
if mm == 0:
return result
hh, mm = divmod(mm, 60)
result = f"{mm}m {result}"
if hh == 0:
return result
return f"{hh}h {result}"


class OutputFormat(enum.Enum):
TXT = "txt"
Expand Down
178 changes: 127 additions & 51 deletions buzz/widgets/transcription_tasks_table_widget.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
import datetime
import enum
import os
from dataclasses import dataclass
from enum import auto
from typing import Optional
from typing import Optional, Callable

from PyQt6 import QtGui
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex
from PyQt6.QtWidgets import QTableWidget, QWidget, QAbstractItemView, QTableWidgetItem
from PyQt6.QtWidgets import (
QTableWidget,
QWidget,
QAbstractItemView,
QTableWidgetItem,
QMenu,
)

from buzz.locale import _
from buzz.transcriber import FileTranscriptionTask
from buzz.settings.settings import Settings
from buzz.transcriber import FileTranscriptionTask, humanize_language


@dataclass
class TableColDef:
id: str
header: str
column_index: int
value_getter: Callable[..., str]
width: Optional[int] = None
hidden: bool = False
hidden_toggleable: bool = True


class TranscriptionTasksTableWidget(QTableWidget):
class Column(enum.Enum):
TASK_ID = 0
FILE_NAME = auto()
MODEL = auto()
TASK = auto()
STATUS = auto()

return_clicked = pyqtSignal()
Expand All @@ -25,69 +45,125 @@ def __init__(self, parent: Optional[QWidget] = None):

self.setRowCount(0)
self.setAlternatingRowColors(True)
self.settings = Settings()

self.column_definitions = [
TableColDef(
id="id",
header=_("ID"),
column_index=self.Column.TASK_ID.value,
value_getter=lambda task: str(task.id),
width=0,
hidden=True,
hidden_toggleable=False,
),
TableColDef(
id="file_name",
header=_("File Name"),
column_index=self.Column.FILE_NAME.value,
value_getter=lambda task: os.path.basename(task.file_path),
width=250,
hidden_toggleable=False,
),
TableColDef(
id="model",
header=_("Model"),
column_index=self.Column.MODEL.value,
value_getter=lambda task: str(task.transcription_options.model),
width=180,
hidden=True,
),
TableColDef(
id="task",
header=_("Task"),
column_index=self.Column.TASK.value,
value_getter=lambda task: self.get_task_label(task),
width=180,
hidden=True,
),
TableColDef(
id="status",
header=_("Status"),
column_index=self.Column.STATUS.value,
value_getter=lambda task: task.status_text(),
width=180,
hidden_toggleable=False,
),
]

self.setColumnCount(3)
self.setColumnHidden(0, True)

self.setColumnCount(len(self.column_definitions))
self.verticalHeader().hide()
self.setHorizontalHeaderLabels([_("ID"), _("File Name"), _("Status")])
self.setColumnWidth(self.Column.FILE_NAME.value, 250)
self.setColumnWidth(self.Column.STATUS.value, 180)
self.setHorizontalHeaderLabels(
[definition.header for definition in self.column_definitions]
)
for definition in self.column_definitions:
if definition.width is not None:
self.setColumnWidth(definition.column_index, definition.width)
self.load_column_visibility()

self.horizontalHeader().setMinimumSectionSize(180)

self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)

def contextMenuEvent(self, event):
menu = QMenu(self)
for definition in self.column_definitions:
if not definition.hidden_toggleable:
continue
action = menu.addAction(definition.header)
action.setCheckable(True)
action.setChecked(not self.isColumnHidden(definition.column_index))
action.toggled.connect(
lambda checked,
column_index=definition.column_index: self.on_column_checked(
column_index, checked
)
)
menu.exec(event.globalPos())

def on_column_checked(self, column_index: int, checked: bool):
self.setColumnHidden(column_index, not checked)
self.save_column_visibility()

def save_column_visibility(self):
self.settings.begin_group(
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
)
for definition in self.column_definitions:
self.settings.settings.setValue(
definition.id, not self.isColumnHidden(definition.column_index)
)
self.settings.end_group()

def load_column_visibility(self):
self.settings.begin_group(
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
)
for definition in self.column_definitions:
visible = self.settings.settings.value(definition.id, not definition.hidden)
self.setColumnHidden(definition.column_index, not visible)
self.settings.end_group()

def upsert_task(self, task: FileTranscriptionTask):
task_row_index = self.task_row_index(task.id)
if task_row_index is None:
self.insertRow(self.rowCount())

row_index = self.rowCount() - 1
task_id_widget_item = QTableWidgetItem(str(task.id))
self.setItem(row_index, self.Column.TASK_ID.value, task_id_widget_item)

file_name_widget_item = QTableWidgetItem(os.path.basename(task.file_path))
file_name_widget_item.setFlags(
file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
)
self.setItem(row_index, self.Column.FILE_NAME.value, file_name_widget_item)

status_widget_item = QTableWidgetItem(self.get_status_text(task))
status_widget_item.setFlags(
status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
)
self.setItem(row_index, self.Column.STATUS.value, status_widget_item)
for definition in self.column_definitions:
item = QTableWidgetItem(definition.value_getter(task))
item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
self.setItem(row_index, definition.column_index, item)
else:
status_widget = self.item(task_row_index, self.Column.STATUS.value)
status_widget.setText(self.get_status_text(task))

@staticmethod
def format_timedelta(delta: datetime.timedelta):
mm, ss = divmod(delta.seconds, 60)
result = f"{ss}s"
if mm == 0:
return result
hh, mm = divmod(mm, 60)
result = f"{mm}m {result}"
if hh == 0:
return result
return f"{hh}h {result}"
status_widget.setText(task.status_text())

@staticmethod
def get_status_text(task: FileTranscriptionTask):
if task.status == FileTranscriptionTask.Status.IN_PROGRESS:
return f'{_("In Progress")} ({task.fraction_completed :.0%})'
elif task.status == FileTranscriptionTask.Status.COMPLETED:
status = _("Completed")
if task.started_at is not None and task.completed_at is not None:
status += f" ({TranscriptionTasksTableWidget.format_timedelta(task.completed_at - task.started_at)})"
return status
elif task.status == FileTranscriptionTask.Status.FAILED:
return f'{_("Failed")} ({task.error})'
elif task.status == FileTranscriptionTask.Status.CANCELED:
return _("Canceled")
elif task.status == FileTranscriptionTask.Status.QUEUED:
return _("Queued")
def get_task_label(task: FileTranscriptionTask) -> str:
value = task.transcription_options.task.value.capitalize()
if task.transcription_options.language is not None:
value += f" ({humanize_language(task.transcription_options.language)})"
return value

def clear_task(self, task_id: int):
task_row_index = self.task_row_index(task_id)
Expand Down
11 changes: 11 additions & 0 deletions tests/widgets/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import logging
import pytest
from buzz.settings.settings import Settings


@pytest.fixture(scope="package")
def reset_settings():
settings = Settings()
settings.clear()
settings.sync()
logging.debug("Reset settings")
Loading

0 comments on commit f163aab

Please sign in to comment.