Skip to content

Commit

Permalink
feat: save transcriptions to sqlite (#682)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Mar 14, 2024
1 parent dfac983 commit ae5af30
Show file tree
Hide file tree
Showing 36 changed files with 2,036 additions and 1,340 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ clean:
rm -f buzz/whisper_cpp.py
rm -rf dist/* || true

COVERAGE_THRESHOLD := 76
COVERAGE_THRESHOLD := 77
ifeq ($(UNAME_S),Linux)
COVERAGE_THRESHOLD := 71
COVERAGE_THRESHOLD := 72
endif

test: buzz/whisper_cpp.py translation_mo
Expand Down
11 changes: 9 additions & 2 deletions buzz/buzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from typing import TextIO

from appdirs import user_log_dir
from platformdirs import user_log_dir

# Check for segfaults if not running in frozen mode
if getattr(sys, "frozen", False) is False:
Expand Down Expand Up @@ -57,7 +57,14 @@ def main():

from buzz.cli import parse_command_line
from buzz.widgets.application import Application
from buzz.db.dao.transcription_dao import TranscriptionDAO
from buzz.db.dao.transcription_segment_dao import TranscriptionSegmentDAO
from buzz.db.service.transcription_service import TranscriptionService
from buzz.db.db import setup_app_db

app = Application()
db = setup_app_db()
app = Application(
TranscriptionService(TranscriptionDAO(db), TranscriptionSegmentDAO(db))
)
parse_command_line(app)
sys.exit(app.exec())
Empty file added buzz/db/__init__.py
Empty file.
Empty file added buzz/db/dao/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions buzz/db/dao/dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Adapted from https://github.com/zhiyiYo/Groove
from abc import ABC
from typing import TypeVar, Generic, Any, Type

from PyQt6.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord

from buzz.db.entity.entity import Entity

T = TypeVar("T", bound=Entity)


class DAO(ABC, Generic[T]):
    """Generic data-access object bound to one table of a Qt SQL database.

    Subclasses set ``entity`` to the class that fetched rows are mapped to.
    """

    entity: Type[T]

    def __init__(self, table: str, db: QSqlDatabase):
        self.db = db
        self.table = table

    def insert(self, record: T):
        """Insert ``record``, using its instance attributes as column names.

        Every attribute in ``record.__dict__`` becomes both a column in the
        INSERT statement and a named placeholder bound to its value.

        Raises:
            Exception: carrying the query's last error text if it fails.
        """
        statement = self._create_query()
        attributes = record.__dict__
        column_list = ", ".join(attributes.keys())
        placeholder_list = ", ".join([f":{name}" for name in attributes.keys()])
        statement.prepare(
            f"""
            INSERT INTO {self.table} ({column_list})
            VALUES ({placeholder_list})
        """
        )
        for name, value in attributes.items():
            statement.bindValue(f":{name}", value)
        if not statement.exec():
            raise Exception(statement.lastError().text())

    def find_by_id(self, id: Any) -> T | None:
        """Return the row whose ``id`` column equals ``id``, or None."""
        lookup = self._create_query()
        lookup.prepare(f"SELECT * FROM {self.table} WHERE id = :id")
        lookup.bindValue(":id", id)
        return self._execute(lookup)

    def to_entity(self, record: QSqlRecord) -> T:
        """Create an ``entity`` instance with one attribute per record field."""
        instance = self.entity()
        pairs = [
            (record.fieldName(index), record.value(index))
            for index in range(record.count())
        ]
        for name, value in pairs:
            setattr(instance, name, value)
        return instance

    def _execute(self, query: QSqlQuery) -> T | None:
        """Run a prepared query and map its first row, if any, to an entity.

        Raises:
            Exception: carrying the query's last error text if it fails.
        """
        if not query.exec():
            raise Exception(query.lastError().text())
        # first() is False when the result set is empty.
        return self.to_entity(query.record()) if query.first() else None

    def _create_query(self):
        """Return a new query object attached to this DAO's database."""
        return QSqlQuery(self.db)
159 changes: 159 additions & 0 deletions buzz/db/dao/transcription_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from datetime import datetime
from uuid import UUID

from PyQt6.QtSql import QSqlDatabase

from buzz.db.dao.dao import DAO
from buzz.db.entity.transcription import Transcription
from buzz.transcriber.transcriber import FileTranscriptionTask


class TranscriptionDAO(DAO[Transcription]):
    """Data-access object for the ``transcription`` table.

    Rows are keyed by the string form of the task's UUID. Every write goes
    through ``_exec_statement`` so that preparing, binding, executing and
    error reporting live in one place instead of being repeated per method.
    """

    entity = Transcription

    def __init__(self, db: QSqlDatabase):
        super().__init__("transcription", db)

    def _exec_statement(self, sql: str, bindings: dict[str, object]) -> None:
        """Prepare ``sql``, bind the named placeholders and execute it.

        Args:
            sql: statement text using ``:name`` placeholders.
            bindings: placeholder (including the leading colon) to value.

        Raises:
            Exception: carrying the query's last error text if it fails.
        """
        query = self._create_query()
        query.prepare(sql)
        for placeholder, value in bindings.items():
            query.bindValue(placeholder, value)
        if not query.exec():
            raise Exception(query.lastError().text())

    def create_transcription(self, task: FileTranscriptionTask):
        """Insert a row for ``task`` with status QUEUED and the current time.

        Raises:
            Exception: carrying the query's last error text if it fails.
        """
        options = task.transcription_options
        model = options.model
        self._exec_statement(
            """
            INSERT INTO transcription (
                id,
                export_formats,
                file,
                output_folder,
                language,
                model_type,
                source,
                status,
                task,
                time_queued,
                url,
                whisper_model_size
            ) VALUES (
                :id,
                :export_formats,
                :file,
                :output_folder,
                :language,
                :model_type,
                :source,
                :status,
                :task,
                :time_queued,
                :url,
                :whisper_model_size
            )
        """,
            {
                ":id": str(task.uid),
                # Output formats are stored as one comma-separated string.
                ":export_formats": ", ".join(
                    [
                        output_format.value
                        for output_format in task.file_transcription_options.output_formats
                    ]
                ),
                ":file": task.file_path,
                ":output_folder": task.output_directory,
                ":language": options.language,
                ":model_type": model.model_type.value,
                ":source": task.source.value,
                ":status": FileTranscriptionTask.Status.QUEUED.value,
                ":task": options.task.value,
                ":time_queued": datetime.now().isoformat(),
                ":url": task.url,
                # The size is optional; bind None when it is not set.
                ":whisper_model_size": model.whisper_model_size.value
                if model.whisper_model_size
                else None,
            },
        )

    def update_transcription_as_started(self, id: UUID):
        """Set status IN_PROGRESS and ``time_started`` to the current time."""
        self._exec_statement(
            """
            UPDATE transcription
            SET status = :status, time_started = :time_started
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.IN_PROGRESS.value,
                ":time_started": datetime.now().isoformat(),
            },
        )

    def update_transcription_as_failed(self, id: UUID, error: str):
        """Set status FAILED, ``time_ended`` to now, and store ``error``."""
        self._exec_statement(
            """
            UPDATE transcription
            SET status = :status, time_ended = :time_ended, error_message = :error_message
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.FAILED.value,
                ":time_ended": datetime.now().isoformat(),
                ":error_message": error,
            },
        )

    def update_transcription_as_canceled(self, id: UUID):
        """Set status CANCELED and ``time_ended`` to the current time."""
        self._exec_statement(
            """
            UPDATE transcription
            SET status = :status, time_ended = :time_ended
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.CANCELED.value,
                ":time_ended": datetime.now().isoformat(),
            },
        )

    def update_transcription_progress(self, id: UUID, progress: float):
        """Store ``progress`` and set status IN_PROGRESS."""
        self._exec_statement(
            """
            UPDATE transcription
            SET status = :status, progress = :progress
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.IN_PROGRESS.value,
                ":progress": progress,
            },
        )

    def update_transcription_as_completed(self, id: UUID):
        """Set status COMPLETED and ``time_ended`` to the current time."""
        self._exec_statement(
            """
            UPDATE transcription
            SET status = :status, time_ended = :time_ended
            WHERE id = :id
        """,
            {
                ":id": str(id),
                ":status": FileTranscriptionTask.Status.COMPLETED.value,
                ":time_ended": datetime.now().isoformat(),
            },
        )
11 changes: 11 additions & 0 deletions buzz/db/dao/transcription_segment_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from PyQt6.QtSql import QSqlDatabase

from buzz.db.dao.dao import DAO
from buzz.db.entity.transcription_segment import TranscriptionSegment


class TranscriptionSegmentDAO(DAO[TranscriptionSegment]):
    """Data-access object for the ``transcription_segment`` table."""

    entity = TranscriptionSegment

    def __init__(self, db: QSqlDatabase):
        super().__init__(table="transcription_segment", db=db)
39 changes: 39 additions & 0 deletions buzz/db/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging
import os
import sqlite3
import tempfile

from PyQt6.QtSql import QSqlDatabase
from platformdirs import user_data_dir

from buzz.db.helpers import (
run_sqlite_migrations,
copy_transcriptions_from_json_to_sqlite,
mark_in_progress_and_queued_transcriptions_as_canceled,
)

APP_DB_PATH = os.path.join(user_data_dir("Buzz"), "Buzz.sqlite")


def setup_app_db() -> QSqlDatabase:
    """Set up and return the database stored at ``APP_DB_PATH``."""
    app_db = _setup_db(path=APP_DB_PATH)
    return app_db


def setup_test_db() -> QSqlDatabase:
    """Set up and return a database backed by a new temporary file.

    The file is not removed by this function.
    """
    # mkstemp() creates the file atomically. The deprecated mktemp() only
    # returns a name, so another process could claim the path before use.
    fd, path = tempfile.mkstemp()
    # Only the path is needed; the database layers open the file themselves.
    os.close(fd)
    return _setup_db(path)


def _setup_db(path: str) -> QSqlDatabase:
    """Prepare the SQLite file at ``path`` and open it as a Qt database.

    The migration and maintenance helpers run over a plain ``sqlite3``
    connection first; that connection is closed before Qt opens the file.

    Args:
        path: location of the SQLite database file.

    Returns:
        The opened ``QSQLITE`` connection.

    Raises:
        RuntimeError: if Qt fails to open the database.
    """
    # sqlite3.connect() cannot create missing parent directories, so make
    # sure the target directory exists (e.g. on a first launch).
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Run migrations
    connection = sqlite3.connect(path)
    try:
        run_sqlite_migrations(connection)
        copy_transcriptions_from_json_to_sqlite(connection)
        mark_in_progress_and_queued_transcriptions_as_canceled(connection)
    finally:
        # Release the file even when a helper raises.
        connection.close()

    db = QSqlDatabase.addDatabase("QSQLITE")
    db.setDatabaseName(path)
    if not db.open():
        raise RuntimeError(f"Failed to open database connection: {db.databaseName()}")
    logging.debug("Database connection opened: %s", db.databaseName())
    return db
Empty file added buzz/db/entity/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions buzz/db/entity/entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC

from PyQt6.QtSql import QSqlRecord


class Entity(ABC):
    """Base class for objects that can be populated from a SQL record."""

    @classmethod
    def from_record(cls, record: QSqlRecord):
        """Build an instance with one attribute per field of ``record``.

        The class is instantiated without arguments, then each field name
        is assigned its value in field order.
        """
        instance = cls()
        pairs = [
            (record.fieldName(position), record.value(position))
            for position in range(record.count())
        ]
        for name, value in pairs:
            setattr(instance, name, value)
        return instance
54 changes: 54 additions & 0 deletions buzz/db/entity/transcription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import datetime
import os
import uuid
from dataclasses import dataclass, field

from buzz.db.entity.entity import Entity
from buzz.model_loader import ModelType
from buzz.settings.settings import Settings
from buzz.transcriber.transcriber import OutputFormat, Task, FileTranscriptionTask


@dataclass
class Transcription(Entity):
    """A transcription row, with helpers for typed access and export paths."""

    status: str = FileTranscriptionTask.Status.QUEUED.value
    task: str = Task.TRANSCRIBE.value
    model_type: str = ModelType.WHISPER.value
    whisper_model_size: str | None = None
    language: str | None = None
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    error_message: str | None = None
    file: str | None = None
    # A factory is required here: a plain default would be evaluated once,
    # when the class is defined, and shared by every instance.
    time_queued: str = field(
        default_factory=lambda: datetime.datetime.now().isoformat()
    )

    @property
    def id_as_uuid(self):
        """The ``id`` string parsed as a ``uuid.UUID``."""
        return uuid.UUID(hex=self.id)

    @property
    def status_as_status(self):
        """The ``status`` string converted to ``FileTranscriptionTask.Status``."""
        return FileTranscriptionTask.Status(self.status)

    def get_output_file_path(
        self,
        output_format: OutputFormat,
        output_directory: str | None = None,
    ):
        """Return the export path for this transcription in ``output_format``.

        The file name comes from the default export file template in
        Settings, with its ``{{ ... }}`` placeholders substituted and the
        format's value appended as the extension. ``self.file`` must be set:
        it supplies the input file name and the fallback directory.

        Args:
            output_format: format whose value is used as the file extension.
            output_directory: target directory; defaults to the directory
                of ``self.file`` when empty or None.
        """
        input_file_name = os.path.splitext(os.path.basename(self.file))[0]

        date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")

        export_file_name_template = Settings().get_default_export_file_template()

        output_file_name = (
            export_file_name_template.replace("{{ input_file_name }}", input_file_name)
            .replace("{{ task }}", self.task)
            .replace("{{ language }}", self.language or "")
            .replace("{{ model_type }}", self.model_type)
            .replace("{{ model_size }}", self.whisper_model_size or "")
            .replace("{{ date_time }}", date_time_now)
            + f".{output_format.value}"
        )

        output_directory = output_directory or os.path.dirname(self.file)
        return os.path.join(output_directory, output_file_name)
11 changes: 11 additions & 0 deletions buzz/db/entity/transcription_segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataclasses import dataclass

from buzz.db.entity.entity import Entity


@dataclass
class TranscriptionSegment(Entity):
    """One segment of text belonging to a transcription."""

    # Segment boundaries; units are not shown here (presumably
    # milliseconds, TODO confirm against the code that writes them).
    start_time: int
    end_time: int
    # Text of this segment.
    text: str
    # Presumably the id of the owning transcription row; verify against schema.
    transcription_id: str
Loading

0 comments on commit ae5af30

Please sign in to comment.