Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update transcription tasks columns #649

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ else
endif

clean:
rm -f $(LIBWHISPER)
rm -f buzz/$(LIBWHISPER)
rm -f buzz/whisper_cpp.py
rm -rf dist/* || true

Expand Down
22 changes: 20 additions & 2 deletions buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
return "large-v2"
return self.value

def __str__(self):
return self.value.capitalize()


class ModelType(enum.Enum):
WHISPER = "Whisper"
Expand All @@ -46,6 +49,21 @@
whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
hugging_face_model_id: Optional[str] = None

def __str__(self):
match self.model_type:
case ModelType.WHISPER:
return f"Whisper ({self.whisper_model_size})"
case ModelType.WHISPER_CPP:
return f"Whisper.cpp ({self.whisper_model_size})"
case ModelType.HUGGING_FACE:
return f"Hugging Face ({self.hugging_face_model_id})"
case ModelType.FASTER_WHISPER:
return f"Faster Whisper ({self.whisper_model_size})"
case ModelType.OPEN_AI_WHISPER_API:
return "OpenAI Whisper API"
case _:
raise Exception("Unknown model type")

Check warning on line 65 in buzz/model_loader.py

View check run for this annotation

Codecov / codecov/patch

buzz/model_loader.py#L56-L65

Added lines #L56 - L65 were not covered by tests

def is_deletable(self):
return (
self.model_type == ModelType.WHISPER
Expand Down Expand Up @@ -110,12 +128,12 @@
"base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe",
"small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b",
"medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208",
"large": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
"large": "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2",
}


def get_hugging_face_file_url(author: str, repository_name: str, filename: str):
return f"https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}"
return f"https://huggingface.co/{author}/{repository_name}/resolve/bf8b606c2fcd9173605cdf6bd2ac8a75a8141b6c/{filename}"


def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
Expand Down
14 changes: 14 additions & 0 deletions buzz/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class Settings:
def __init__(self):
self.settings = QSettings(APP_NAME)
self.settings.sync()

class Key(enum.Enum):
RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task"
Expand All @@ -29,6 +30,10 @@ class Key(enum.Enum):

SHORTCUTS = "shortcuts"

TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY = (
"transcription-tasks-table/column-visibility"
)

def set_value(self, key: Key, value: typing.Any) -> None:
self.settings.setValue(key.value, value)

Expand All @@ -46,3 +51,12 @@ def value(

def clear(self):
self.settings.clear()

def begin_group(self, group: Key):
self.settings.beginGroup(group.value)

def end_group(self):
self.settings.endGroup()

def sync(self):
self.settings.sync()
37 changes: 37 additions & 0 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from . import transformers_whisper
from .conn import pipe_stderr
from .locale import _
from .model_loader import TranscriptionModel, ModelType

# Catch exception from whisper.dll not getting loaded.
Expand Down Expand Up @@ -71,6 +72,12 @@
)


def humanize_language(language: str) -> str:
if language == "":
return _("Detect Language")
return LANGUAGES[language].title()

Check warning on line 78 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L76-L78

Added lines #L76 - L78 were not covered by tests


@dataclass()
class FileTranscriptionOptions:
file_paths: List[str]
Expand Down Expand Up @@ -101,6 +108,36 @@
started_at: Optional[datetime.datetime] = None
completed_at: Optional[datetime.datetime] = None

def status_text(self) -> str:
if self.status == FileTranscriptionTask.Status.IN_PROGRESS:
return f'{_("In Progress")} ({self.fraction_completed :.0%})'
elif self.status == FileTranscriptionTask.Status.COMPLETED:
status = _("Completed")
if self.started_at is not None and self.completed_at is not None:
status += (
f" ({self.format_timedelta(self.completed_at - self.started_at)})"
)
return status
elif self.status == FileTranscriptionTask.Status.FAILED:
return f'{_("Failed")} ({self.error})'
elif self.status == FileTranscriptionTask.Status.CANCELED:
return _("Canceled")
elif self.status == FileTranscriptionTask.Status.QUEUED:
return _("Queued")
return ""

Check warning on line 127 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L127

Added line #L127 was not covered by tests

@staticmethod
def format_timedelta(delta: datetime.timedelta):
mm, ss = divmod(delta.seconds, 60)
result = f"{ss}s"
if mm == 0:
return result
hh, mm = divmod(mm, 60)
result = f"{mm}m {result}"
if hh == 0:
return result
return f"{hh}h {result}"

Check warning on line 139 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L135-L139

Added lines #L135 - L139 were not covered by tests


class OutputFormat(enum.Enum):
TXT = "txt"
Expand Down
178 changes: 127 additions & 51 deletions buzz/widgets/transcription_tasks_table_widget.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
import datetime
import enum
import os
from dataclasses import dataclass
from enum import auto
from typing import Optional
from typing import Optional, Callable

from PyQt6 import QtGui
from PyQt6.QtCore import pyqtSignal, Qt, QModelIndex
from PyQt6.QtWidgets import QTableWidget, QWidget, QAbstractItemView, QTableWidgetItem
from PyQt6.QtWidgets import (
QTableWidget,
QWidget,
QAbstractItemView,
QTableWidgetItem,
QMenu,
)

from buzz.locale import _
from buzz.transcriber import FileTranscriptionTask
from buzz.settings.settings import Settings
from buzz.transcriber import FileTranscriptionTask, humanize_language


@dataclass
class TableColDef:
id: str
header: str
column_index: int
value_getter: Callable[..., str]
width: Optional[int] = None
hidden: bool = False
hidden_toggleable: bool = True


class TranscriptionTasksTableWidget(QTableWidget):
class Column(enum.Enum):
TASK_ID = 0
FILE_NAME = auto()
MODEL = auto()
TASK = auto()
STATUS = auto()

return_clicked = pyqtSignal()
Expand All @@ -25,69 +45,125 @@

self.setRowCount(0)
self.setAlternatingRowColors(True)
self.settings = Settings()

self.column_definitions = [
TableColDef(
id="id",
header=_("ID"),
column_index=self.Column.TASK_ID.value,
value_getter=lambda task: str(task.id),
width=0,
hidden=True,
hidden_toggleable=False,
),
TableColDef(
id="file_name",
header=_("File Name"),
column_index=self.Column.FILE_NAME.value,
value_getter=lambda task: os.path.basename(task.file_path),
width=250,
hidden_toggleable=False,
),
TableColDef(
id="model",
header=_("Model"),
column_index=self.Column.MODEL.value,
value_getter=lambda task: str(task.transcription_options.model),
width=180,
hidden=True,
),
TableColDef(
id="task",
header=_("Task"),
column_index=self.Column.TASK.value,
value_getter=lambda task: self.get_task_label(task),
width=180,
hidden=True,
),
TableColDef(
id="status",
header=_("Status"),
column_index=self.Column.STATUS.value,
value_getter=lambda task: task.status_text(),
width=180,
hidden_toggleable=False,
),
]

self.setColumnCount(3)
self.setColumnHidden(0, True)

self.setColumnCount(len(self.column_definitions))
self.verticalHeader().hide()
self.setHorizontalHeaderLabels([_("ID"), _("File Name"), _("Status")])
self.setColumnWidth(self.Column.FILE_NAME.value, 250)
self.setColumnWidth(self.Column.STATUS.value, 180)
self.setHorizontalHeaderLabels(
[definition.header for definition in self.column_definitions]
)
for definition in self.column_definitions:
if definition.width is not None:
self.setColumnWidth(definition.column_index, definition.width)
self.load_column_visibility()

self.horizontalHeader().setMinimumSectionSize(180)

self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)

def contextMenuEvent(self, event):
menu = QMenu(self)
for definition in self.column_definitions:
if not definition.hidden_toggleable:
continue
action = menu.addAction(definition.header)
action.setCheckable(True)
action.setChecked(not self.isColumnHidden(definition.column_index))
action.toggled.connect(

Check warning on line 116 in buzz/widgets/transcription_tasks_table_widget.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/transcription_tasks_table_widget.py#L109-L116

Added lines #L109 - L116 were not covered by tests
lambda checked,
column_index=definition.column_index: self.on_column_checked(
column_index, checked
)
)
menu.exec(event.globalPos())

Check warning on line 122 in buzz/widgets/transcription_tasks_table_widget.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/transcription_tasks_table_widget.py#L122

Added line #L122 was not covered by tests

def on_column_checked(self, column_index: int, checked: bool):
self.setColumnHidden(column_index, not checked)
self.save_column_visibility()

Check warning on line 126 in buzz/widgets/transcription_tasks_table_widget.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/transcription_tasks_table_widget.py#L125-L126

Added lines #L125 - L126 were not covered by tests

def save_column_visibility(self):
self.settings.begin_group(

Check warning on line 129 in buzz/widgets/transcription_tasks_table_widget.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/transcription_tasks_table_widget.py#L129

Added line #L129 was not covered by tests
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
)
for definition in self.column_definitions:
self.settings.settings.setValue(

Check warning on line 133 in buzz/widgets/transcription_tasks_table_widget.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/transcription_tasks_table_widget.py#L132-L133

Added lines #L132 - L133 were not covered by tests
definition.id, not self.isColumnHidden(definition.column_index)
)
self.settings.end_group()

Check warning on line 136 in buzz/widgets/transcription_tasks_table_widget.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/transcription_tasks_table_widget.py#L136

Added line #L136 was not covered by tests

def load_column_visibility(self):
self.settings.begin_group(
Settings.Key.TRANSCRIPTION_TASKS_TABLE_COLUMN_VISIBILITY
)
for definition in self.column_definitions:
visible = self.settings.settings.value(definition.id, not definition.hidden)
self.setColumnHidden(definition.column_index, not visible)
self.settings.end_group()

def upsert_task(self, task: FileTranscriptionTask):
task_row_index = self.task_row_index(task.id)
if task_row_index is None:
self.insertRow(self.rowCount())

row_index = self.rowCount() - 1
task_id_widget_item = QTableWidgetItem(str(task.id))
self.setItem(row_index, self.Column.TASK_ID.value, task_id_widget_item)

file_name_widget_item = QTableWidgetItem(os.path.basename(task.file_path))
file_name_widget_item.setFlags(
file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
)
self.setItem(row_index, self.Column.FILE_NAME.value, file_name_widget_item)

status_widget_item = QTableWidgetItem(self.get_status_text(task))
status_widget_item.setFlags(
status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
)
self.setItem(row_index, self.Column.STATUS.value, status_widget_item)
for definition in self.column_definitions:
item = QTableWidgetItem(definition.value_getter(task))
item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
self.setItem(row_index, definition.column_index, item)
else:
status_widget = self.item(task_row_index, self.Column.STATUS.value)
status_widget.setText(self.get_status_text(task))

@staticmethod
def format_timedelta(delta: datetime.timedelta):
mm, ss = divmod(delta.seconds, 60)
result = f"{ss}s"
if mm == 0:
return result
hh, mm = divmod(mm, 60)
result = f"{mm}m {result}"
if hh == 0:
return result
return f"{hh}h {result}"
status_widget.setText(task.status_text())

@staticmethod
def get_status_text(task: FileTranscriptionTask):
if task.status == FileTranscriptionTask.Status.IN_PROGRESS:
return f'{_("In Progress")} ({task.fraction_completed :.0%})'
elif task.status == FileTranscriptionTask.Status.COMPLETED:
status = _("Completed")
if task.started_at is not None and task.completed_at is not None:
status += f" ({TranscriptionTasksTableWidget.format_timedelta(task.completed_at - task.started_at)})"
return status
elif task.status == FileTranscriptionTask.Status.FAILED:
return f'{_("Failed")} ({task.error})'
elif task.status == FileTranscriptionTask.Status.CANCELED:
return _("Canceled")
elif task.status == FileTranscriptionTask.Status.QUEUED:
return _("Queued")
def get_task_label(task: FileTranscriptionTask) -> str:
value = task.transcription_options.task.value.capitalize()
if task.transcription_options.language is not None:
value += f" ({humanize_language(task.transcription_options.language)})"

Check warning on line 165 in buzz/widgets/transcription_tasks_table_widget.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/transcription_tasks_table_widget.py#L165

Added line #L165 was not covered by tests
return value

def clear_task(self, task_id: int):
task_row_index = self.task_row_index(task_id)
Expand Down
11 changes: 11 additions & 0 deletions tests/widgets/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import logging
import pytest
from buzz.settings.settings import Settings


@pytest.fixture(scope="package")
def reset_settings():
settings = Settings()
settings.clear()
settings.sync()
logging.debug("Reset settings")
Loading
Loading