-
Notifications
You must be signed in to change notification settings - Fork 977
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: save transcriptions to sqlite (#682)
- Loading branch information
1 parent
dfac983
commit ae5af30
Showing
36 changed files
with
2,036 additions
and
1,340 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Adapted from https://github.com/zhiyiYo/Groove | ||
from abc import ABC | ||
from typing import TypeVar, Generic, Any, Type | ||
|
||
from PyQt6.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord | ||
|
||
from buzz.db.entity.entity import Entity | ||
|
||
T = TypeVar("T", bound=Entity) | ||
|
||
|
||
class DAO(ABC, Generic[T]): | ||
entity: Type[T] | ||
|
||
def __init__(self, table: str, db: QSqlDatabase): | ||
self.db = db | ||
self.table = table | ||
|
||
def insert(self, record: T): | ||
query = self._create_query() | ||
keys = record.__dict__.keys() | ||
query.prepare( | ||
f""" | ||
INSERT INTO {self.table} ({", ".join(keys)}) | ||
VALUES ({", ".join([f":{key}" for key in keys])}) | ||
""" | ||
) | ||
for key, value in record.__dict__.items(): | ||
query.bindValue(f":{key}", value) | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) | ||
|
||
def find_by_id(self, id: Any) -> T | None: | ||
query = self._create_query() | ||
query.prepare(f"SELECT * FROM {self.table} WHERE id = :id") | ||
query.bindValue(":id", id) | ||
return self._execute(query) | ||
|
||
def to_entity(self, record: QSqlRecord) -> T: | ||
entity = self.entity() | ||
for i in range(record.count()): | ||
setattr(entity, record.fieldName(i), record.value(i)) | ||
return entity | ||
|
||
def _execute(self, query: QSqlQuery) -> T | None: | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) | ||
if not query.first(): | ||
return None | ||
return self.to_entity(query.record()) | ||
|
||
def _create_query(self): | ||
return QSqlQuery(self.db) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from datetime import datetime | ||
from uuid import UUID | ||
|
||
from PyQt6.QtSql import QSqlDatabase | ||
|
||
from buzz.db.dao.dao import DAO | ||
from buzz.db.entity.transcription import Transcription | ||
from buzz.transcriber.transcriber import FileTranscriptionTask | ||
|
||
|
||
class TranscriptionDAO(DAO[Transcription]): | ||
entity = Transcription | ||
|
||
def __init__(self, db: QSqlDatabase): | ||
super().__init__("transcription", db) | ||
|
||
def create_transcription(self, task: FileTranscriptionTask): | ||
query = self._create_query() | ||
query.prepare( | ||
""" | ||
INSERT INTO transcription ( | ||
id, | ||
export_formats, | ||
file, | ||
output_folder, | ||
language, | ||
model_type, | ||
source, | ||
status, | ||
task, | ||
time_queued, | ||
url, | ||
whisper_model_size | ||
) VALUES ( | ||
:id, | ||
:export_formats, | ||
:file, | ||
:output_folder, | ||
:language, | ||
:model_type, | ||
:source, | ||
:status, | ||
:task, | ||
:time_queued, | ||
:url, | ||
:whisper_model_size | ||
) | ||
""" | ||
) | ||
query.bindValue(":id", str(task.uid)) | ||
query.bindValue( | ||
":export_formats", | ||
", ".join( | ||
[ | ||
output_format.value | ||
for output_format in task.file_transcription_options.output_formats | ||
] | ||
), | ||
) | ||
query.bindValue(":file", task.file_path) | ||
query.bindValue(":output_folder", task.output_directory) | ||
query.bindValue(":language", task.transcription_options.language) | ||
query.bindValue( | ||
":model_type", task.transcription_options.model.model_type.value | ||
) | ||
query.bindValue(":source", task.source.value) | ||
query.bindValue(":status", FileTranscriptionTask.Status.QUEUED.value) | ||
query.bindValue(":task", task.transcription_options.task.value) | ||
query.bindValue(":time_queued", datetime.now().isoformat()) | ||
query.bindValue(":url", task.url) | ||
query.bindValue( | ||
":whisper_model_size", | ||
task.transcription_options.model.whisper_model_size.value | ||
if task.transcription_options.model.whisper_model_size | ||
else None, | ||
) | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) | ||
|
||
def update_transcription_as_started(self, id: UUID): | ||
query = self._create_query() | ||
query.prepare( | ||
""" | ||
UPDATE transcription | ||
SET status = :status, time_started = :time_started | ||
WHERE id = :id | ||
""" | ||
) | ||
|
||
query.bindValue(":id", str(id)) | ||
query.bindValue(":status", FileTranscriptionTask.Status.IN_PROGRESS.value) | ||
query.bindValue(":time_started", datetime.now().isoformat()) | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) | ||
|
||
def update_transcription_as_failed(self, id: UUID, error: str): | ||
query = self._create_query() | ||
query.prepare( | ||
""" | ||
UPDATE transcription | ||
SET status = :status, time_ended = :time_ended, error_message = :error_message | ||
WHERE id = :id | ||
""" | ||
) | ||
|
||
query.bindValue(":id", str(id)) | ||
query.bindValue(":status", FileTranscriptionTask.Status.FAILED.value) | ||
query.bindValue(":time_ended", datetime.now().isoformat()) | ||
query.bindValue(":error_message", error) | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) | ||
|
||
def update_transcription_as_canceled(self, id: UUID): | ||
query = self._create_query() | ||
query.prepare( | ||
""" | ||
UPDATE transcription | ||
SET status = :status, time_ended = :time_ended | ||
WHERE id = :id | ||
""" | ||
) | ||
|
||
query.bindValue(":id", str(id)) | ||
query.bindValue(":status", FileTranscriptionTask.Status.CANCELED.value) | ||
query.bindValue(":time_ended", datetime.now().isoformat()) | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) | ||
|
||
def update_transcription_progress(self, id: UUID, progress: float): | ||
query = self._create_query() | ||
query.prepare( | ||
""" | ||
UPDATE transcription | ||
SET status = :status, progress = :progress | ||
WHERE id = :id | ||
""" | ||
) | ||
|
||
query.bindValue(":id", str(id)) | ||
query.bindValue(":status", FileTranscriptionTask.Status.IN_PROGRESS.value) | ||
query.bindValue(":progress", progress) | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) | ||
|
||
def update_transcription_as_completed(self, id: UUID): | ||
query = self._create_query() | ||
query.prepare( | ||
""" | ||
UPDATE transcription | ||
SET status = :status, time_ended = :time_ended | ||
WHERE id = :id | ||
""" | ||
) | ||
|
||
query.bindValue(":id", str(id)) | ||
query.bindValue(":status", FileTranscriptionTask.Status.COMPLETED.value) | ||
query.bindValue(":time_ended", datetime.now().isoformat()) | ||
if not query.exec(): | ||
raise Exception(query.lastError().text()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from PyQt6.QtSql import QSqlDatabase | ||
|
||
from buzz.db.dao.dao import DAO | ||
from buzz.db.entity.transcription_segment import TranscriptionSegment | ||
|
||
|
||
class TranscriptionSegmentDAO(DAO[TranscriptionSegment]): | ||
entity = TranscriptionSegment | ||
|
||
def __init__(self, db: QSqlDatabase): | ||
super().__init__("transcription_segment", db) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import logging | ||
import os | ||
import sqlite3 | ||
import tempfile | ||
|
||
from PyQt6.QtSql import QSqlDatabase | ||
from platformdirs import user_data_dir | ||
|
||
from buzz.db.helpers import ( | ||
run_sqlite_migrations, | ||
copy_transcriptions_from_json_to_sqlite, | ||
mark_in_progress_and_queued_transcriptions_as_canceled, | ||
) | ||
|
||
APP_DB_PATH = os.path.join(user_data_dir("Buzz"), "Buzz.sqlite") | ||
|
||
|
||
def setup_app_db() -> QSqlDatabase: | ||
return _setup_db(APP_DB_PATH) | ||
|
||
|
||
def setup_test_db() -> QSqlDatabase: | ||
return _setup_db(tempfile.mktemp()) | ||
|
||
|
||
def _setup_db(path: str) -> QSqlDatabase: | ||
# Run migrations | ||
db = sqlite3.connect(path) | ||
run_sqlite_migrations(db) | ||
copy_transcriptions_from_json_to_sqlite(db) | ||
mark_in_progress_and_queued_transcriptions_as_canceled(db) | ||
db.close() | ||
|
||
db = QSqlDatabase.addDatabase("QSQLITE") | ||
db.setDatabaseName(path) | ||
if not db.open(): | ||
raise RuntimeError(f"Failed to open database connection: {db.databaseName()}") | ||
logging.debug("Database connection opened: %s", db.databaseName()) | ||
return db |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from abc import ABC | ||
|
||
from PyQt6.QtSql import QSqlRecord | ||
|
||
|
||
class Entity(ABC): | ||
@classmethod | ||
def from_record(cls, record: QSqlRecord): | ||
entity = cls() | ||
for i in range(record.count()): | ||
setattr(entity, record.fieldName(i), record.value(i)) | ||
return entity |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import datetime | ||
import os | ||
import uuid | ||
from dataclasses import dataclass, field | ||
|
||
from buzz.db.entity.entity import Entity | ||
from buzz.model_loader import ModelType | ||
from buzz.settings.settings import Settings | ||
from buzz.transcriber.transcriber import OutputFormat, Task, FileTranscriptionTask | ||
|
||
|
||
@dataclass | ||
class Transcription(Entity): | ||
status: str = FileTranscriptionTask.Status.QUEUED.value | ||
task: str = Task.TRANSCRIBE.value | ||
model_type: str = ModelType.WHISPER.value | ||
whisper_model_size: str | None = None | ||
language: str | None = None | ||
id: str = field(default_factory=lambda: str(uuid.uuid4())) | ||
error_message: str | None = None | ||
file: str | None = None | ||
time_queued: str = datetime.datetime.now().isoformat() | ||
|
||
@property | ||
def id_as_uuid(self): | ||
return uuid.UUID(hex=self.id) | ||
|
||
@property | ||
def status_as_status(self): | ||
return FileTranscriptionTask.Status(self.status) | ||
|
||
def get_output_file_path( | ||
self, | ||
output_format: OutputFormat, | ||
output_directory: str | None = None, | ||
): | ||
input_file_name = os.path.splitext(os.path.basename(self.file))[0] | ||
|
||
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S") | ||
|
||
export_file_name_template = Settings().get_default_export_file_template() | ||
|
||
output_file_name = ( | ||
export_file_name_template.replace("{{ input_file_name }}", input_file_name) | ||
.replace("{{ task }}", self.task) | ||
.replace("{{ language }}", self.language or "") | ||
.replace("{{ model_type }}", self.model_type) | ||
.replace("{{ model_size }}", self.whisper_model_size or "") | ||
.replace("{{ date_time }}", date_time_now) | ||
+ f".{output_format.value}" | ||
) | ||
|
||
output_directory = output_directory or os.path.dirname(self.file) | ||
return os.path.join(output_directory, output_file_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from dataclasses import dataclass | ||
|
||
from buzz.db.entity.entity import Entity | ||
|
||
|
||
@dataclass | ||
class TranscriptionSegment(Entity): | ||
start_time: int | ||
end_time: int | ||
text: str | ||
transcription_id: str |
Oops, something went wrong.