From b756f27bdbadea43f5eba16bf12da863959ff0c4 Mon Sep 17 00:00:00 2001
From: Chidi Williams
Date: Sat, 23 Dec 2023 09:36:49 +0000
Subject: [PATCH 1/2] fix: openai api transcriber

---
 buzz/settings/settings.py                     |   2 +
 buzz/transcriber.py                           | 115 ++++++++++++------
 buzz/widgets/main_window.py                   |  19 +++
 .../transcriber/file_transcriber_widget.py    |  12 +-
 .../transcription_tasks_table_widget.py       |  33 ++++-
 .../transcription_viewer_widget.py            |   1 +
 tests/transcriber_test.py                     |  24 ++++
 7 files changed, 162 insertions(+), 44 deletions(-)

diff --git a/buzz/settings/settings.py b/buzz/settings/settings.py
index f5e74a620..c450c7438 100644
--- a/buzz/settings/settings.py
+++ b/buzz/settings/settings.py
@@ -34,6 +34,8 @@ class Key(enum.Enum):
             "transcription-tasks-table/column-visibility"
         )
 
+        MAIN_WINDOW = "main-window"
+
     def set_value(self, key: Key, value: typing.Any) -> None:
         self.settings.setValue(key.value, value)
 
diff --git a/buzz/transcriber.py b/buzz/transcriber.py
index 2902e00ae..7a7716f27 100644
--- a/buzz/transcriber.py
+++ b/buzz/transcriber.py
@@ -3,6 +3,7 @@
 import enum
 import json
 import logging
+import math
 import multiprocessing
 import os
 import subprocess
@@ -286,21 +287,9 @@ def transcribe(self) -> List[Segment]:
             self.task,
         )
 
-        wav_file = tempfile.mktemp() + ".wav"
+        mp3_file = tempfile.mktemp() + ".mp3"
 
-        # fmt: off
-        cmd = [
-            "ffmpeg",
-            "-nostdin",
-            "-threads", "0",
-            "-i", self.file_path,
-            "-f", "s16le",
-            "-ac", "1",
-            "-acodec", "pcm_s16le",
-            "-ar", str(whisper.audio.SAMPLE_RATE),
-            wav_file,
-        ]
-        # fmt: on
+        cmd = ["ffmpeg", "-i", self.file_path, mp3_file]
 
         try:
             subprocess.run(cmd, capture_output=True, check=True)
@@ -308,34 +297,90 @@ def transcribe(self) -> List[Segment]:
             logging.exception("")
             raise Exception(exc.stderr.decode("utf-8"))
 
-        # TODO: Check if file size is more than 25MB (2.5 minutes), then chunk
-        audio_file = open(wav_file, "rb")
+        # fmt: off
+        cmd = [
+            "ffprobe",
+            "-v", "error",
+            "-show_entries", "format=duration",
+            "-of", "default=noprint_wrappers=1:nokey=1",
+            mp3_file,
+        ]
+        # fmt: on
+        duration_secs = float(
+            subprocess.run(cmd, capture_output=True, check=True).stdout.decode("utf-8")
+        )
+
+        total_size = os.path.getsize(mp3_file)
+        max_chunk_size = 25 * 1024 * 1024
+
         openai.api_key = (
             self.transcription_task.transcription_options.openai_access_token
         )
-        language = self.transcription_task.transcription_options.language
-        response_format = "verbose_json"
-        if self.transcription_task.transcription_options.task == Task.TRANSLATE:
-            transcript = openai.Audio.translate(
-                "whisper-1",
-                audio_file,
-                response_format=response_format,
-                language=language,
-            )
-        else:
-            transcript = openai.Audio.transcribe(
-                "whisper-1",
-                audio_file,
-                response_format=response_format,
-                language=language,
+
+        logging.debug("File size is %s", total_size)
+
+        self.progress.emit((0, 100))
+
+        if total_size < max_chunk_size:
+            return self.get_segments_for_file(mp3_file)
+
+        num_chunks = math.ceil(total_size / max_chunk_size)
+        chunk_duration = duration_secs / num_chunks
+
+        segments = []
+
+        for i in range(num_chunks):
+            chunk_start = i * chunk_duration
+            chunk_end = min((i + 1) * chunk_duration, duration_secs)
+
+            chunk_file = tempfile.mktemp() + ".mp3"
+
+            # fmt: off
+            cmd = [
+                "ffmpeg",
+                "-i", mp3_file,
+                "-ss", str(chunk_start),
+                "-to", str(chunk_end),
+                "-c", "copy",
+                chunk_file,
+            ]
+            # fmt: on
+            subprocess.run(cmd, capture_output=True, check=True)
+            logging.debug('Created chunk file "%s"', chunk_file)
+
+            segments.extend(
+                self.get_segments_for_file(
+                    chunk_file, offset_ms=int(chunk_start * 1000)
+                )
             )
+            os.remove(chunk_file)
+            self.progress.emit((i + 1, num_chunks))
 
-        segments = [
-            Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"])
-            for segment in transcript["segments"]
-        ]
         return segments
 
+    def get_segments_for_file(self, file: str, offset_ms: int = 0):
+        with open(file, "rb") as audio_file:
+            kwargs = {
+                "model": "whisper-1",
+                "file": audio_file,
+                "response_format": "verbose_json",
+                "language": self.transcription_task.transcription_options.language,
+            }
+            transcript = (
+                openai.Audio.translate(**kwargs)
+                if self.transcription_task.transcription_options.task == Task.TRANSLATE
+                else openai.Audio.transcribe(**kwargs)
+            )
+
+        return [
+            Segment(
+                int(segment["start"] * 1000 + offset_ms),
+                int(segment["end"] * 1000 + offset_ms),
+                segment["text"],
+            )
+            for segment in transcript["segments"]
+        ]
+
     def stop(self):
         pass
 
diff --git a/buzz/widgets/main_window.py b/buzz/widgets/main_window.py
index cdc2c3205..022df9bea 100644
--- a/buzz/widgets/main_window.py
+++ b/buzz/widgets/main_window.py
@@ -111,6 +111,8 @@ def __init__(self, tasks_cache=TasksCache()):
 
         self.load_tasks_from_cache()
 
+        self.load_geometry()
+
     def dragEnterEvent(self, event):
         # Accept file drag events
         if event.mimeData().hasUrls():
@@ -314,10 +316,27 @@ def on_shortcuts_changed(self, shortcuts: dict):
         self.toolbar.set_shortcuts(shortcuts=self.shortcuts)
         self.shortcut_settings.save(shortcuts=self.shortcuts)
 
+    def resizeEvent(self, event):
+        self.save_geometry()
+
     def closeEvent(self, event: QtGui.QCloseEvent) -> None:
+        self.save_geometry()
+
         self.transcriber_worker.stop()
         self.transcriber_thread.quit()
         self.transcriber_thread.wait()
         self.save_tasks_to_cache()
         self.shortcut_settings.save(shortcuts=self.shortcuts)
         super().closeEvent(event)
+
+    def save_geometry(self):
+        self.settings.begin_group(Settings.Key.MAIN_WINDOW)
+        self.settings.settings.setValue("geometry", self.saveGeometry())
+        self.settings.end_group()
+
+    def load_geometry(self):
+        self.settings.begin_group(Settings.Key.MAIN_WINDOW)
+        geometry = self.settings.settings.value("geometry")
+        if geometry is not None:
+            self.restoreGeometry(geometry)
+        self.settings.end_group()
diff --git a/buzz/widgets/transcriber/file_transcriber_widget.py b/buzz/widgets/transcriber/file_transcriber_widget.py
index 2627530d4..09ff5d7a3 100644
--- a/buzz/widgets/transcriber/file_transcriber_widget.py
+++ b/buzz/widgets/transcriber/file_transcriber_widget.py
@@ -145,6 +145,8 @@ def __init__(
         self.setLayout(layout)
         self.setFixedSize(self.sizeHint())
 
+        self.reset_transcriber_controls()
+
     def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
         def on_checkbox_state_changed(state: int):
             if state == Qt.CheckState.Checked.value:
@@ -158,11 +160,6 @@ def on_transcription_options_changed(
         self, transcription_options: TranscriptionOptions
     ):
         self.transcription_options = transcription_options
-        self.word_level_timings_checkbox.setDisabled(
-            self.transcription_options.model.model_type == ModelType.HUGGING_FACE
-            or self.transcription_options.model.model_type
-            == ModelType.OPEN_AI_WHISPER_API
-        )
         if self.transcription_options.openai_access_token != "":
             self.openai_access_token_changed.emit(
                 self.transcription_options.openai_access_token
@@ -213,6 +210,11 @@ def on_download_model_error(self, error: str):
 
     def reset_transcriber_controls(self):
         self.run_button.setDisabled(False)
+        self.word_level_timings_checkbox.setDisabled(
+            self.transcription_options.model.model_type == ModelType.HUGGING_FACE
+            or self.transcription_options.model.model_type
+            == ModelType.OPEN_AI_WHISPER_API
+        )
 
     def on_cancel_model_progress_dialog(self):
         if self.model_loader is not None:
diff --git a/buzz/widgets/transcription_tasks_table_widget.py b/buzz/widgets/transcription_tasks_table_widget.py
index 41a221562..1b112ea99 100644
--- a/buzz/widgets/transcription_tasks_table_widget.py
+++ b/buzz/widgets/transcription_tasks_table_widget.py
@@ -24,7 +24,7 @@ class TableColDef:
     id: str
     header: str
    column_index: int
-    value_getter: Callable[..., str]
+    value_getter: Callable[[FileTranscriptionTask], str]
     width: Optional[int] = None
    hidden: bool = False
    hidden_toggleable: bool = True
@@ -37,6 +37,8 @@ class Column(enum.Enum):
         MODEL = auto()
         TASK = auto()
         STATUS = auto()
+        DATE_ADDED = auto()
+        DATE_COMPLETED = auto()
 
     return_clicked = pyqtSignal()
 
@@ -78,17 +80,39 @@ def __init__(self, parent: Optional[QWidget] = None):
                 header=_("Task"),
                 column_index=self.Column.TASK.value,
                 value_getter=lambda task: self.get_task_label(task),
-                width=180,
+                width=120,
                 hidden=True,
             ),
             TableColDef(
                 id="status",
                 header=_("Status"),
                 column_index=self.Column.STATUS.value,
                 value_getter=lambda task: task.status_text(),
                 width=180,
                 hidden_toggleable=False,
             ),
+            TableColDef(
+                id="date_added",
+                header=_("Date Added"),
+                column_index=self.Column.DATE_ADDED.value,
+                value_getter=lambda task: task.queued_at.strftime("%Y-%m-%d %H:%M:%S")
+                if task.queued_at is not None
+                else "",
+                width=180,
+                hidden=True,
+            ),
+            TableColDef(
+                id="date_completed",
+                header=_("Date Completed"),
+                column_index=self.Column.DATE_COMPLETED.value,
+                value_getter=lambda task: task.completed_at.strftime(
+                    "%Y-%m-%d %H:%M:%S"
+                )
+                if task.completed_at is not None
+                else "",
+                width=180,
+                hidden=True,
+            ),
         ]
 
         self.setColumnCount(len(self.column_definitions))
@@ -155,8 +179,9 @@ def upsert_task(self, task: FileTranscriptionTask):
                 item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
                 self.setItem(row_index, definition.column_index, item)
         else:
-            status_widget = self.item(task_row_index, self.Column.STATUS.value)
-            status_widget.setText(task.status_text())
+            for definition in self.column_definitions:
+                item = self.item(task_row_index, definition.column_index)
+                item.setText(definition.value_getter(task))
 
     @staticmethod
     def get_task_label(task: FileTranscriptionTask) -> str:
diff --git a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
index a04fcc4e1..210f667c9 100644
--- a/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
+++ b/buzz/widgets/transcription_viewer/transcription_viewer_widget.py
@@ -62,6 +62,7 @@ def set_segment_text(self, text: str):
         self.task_changed.emit()
 
 
+# TODO: Fix player duration and add spacer below
 class TranscriptionViewerWidget(QWidget):
     transcription_task: FileTranscriptionTask
     task_changed = pyqtSignal()
diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py
index 7ac85445c..6904de55e 100644
--- a/tests/transcriber_test.py
+++ b/tests/transcriber_test.py
@@ -26,6 +26,7 @@
     whisper_cpp_params,
     write_output,
     TranscriptionOptions,
+    OpenAIWhisperAPIFileTranscriber,
 )
 from buzz.recording_transcriber import RecordingTranscriber
 from tests.mock_sounddevice import MockInputStream
@@ -70,6 +71,29 @@ def test_should_transcribe(self, qtbot):
         assert "Bienvenue dans Passe" in text
 
 
+class TestOpenAIWhisperAPIFileTranscriber:
+    @pytest.mark.skip()
+    def test_transcribe(self):
+        file_path = "testdata/whisper-french.mp3"
+        transcriber = OpenAIWhisperAPIFileTranscriber(
+            task=FileTranscriptionTask(
+                file_path=file_path,
+                transcription_options=(
+                    TranscriptionOptions(
+                        openai_access_token=os.getenv("OPENAI_ACCESS_TOKEN"),
+                    )
+                ),
+                file_transcription_options=(
+                    FileTranscriptionOptions(file_paths=[file_path])
+                ),
+                model_path="",
+            )
+        )
+        transcriber.completed.connect(lambda segments: print(segments))
+        transcriber.error.connect(lambda error: print(error))
+        transcriber.run()
+
+
 class TestWhisperCppFileTranscriber:
     @pytest.mark.parametrize(
         "word_level_timings,expected_segments",

From 594b9e0ae3d99b08a53908f575b08c73d25a3166 Mon Sep 17 00:00:00 2001
From: Chidi Williams
Date: Sat, 23 Dec 2023 09:57:07 +0000
Subject: [PATCH 2/2] fix: openai api transcriber

---
 buzz/transcriber.py       |  4 ++--
 tests/transcriber_test.py | 16 ++++++++++++----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/buzz/transcriber.py b/buzz/transcriber.py
index 7a7716f27..c8e3f835a 100644
--- a/buzz/transcriber.py
+++ b/buzz/transcriber.py
@@ -317,13 +317,13 @@ def transcribe(self) -> List[Segment]:
             self.transcription_task.transcription_options.openai_access_token
         )
 
-        logging.debug("File size is %s", total_size)
-
         self.progress.emit((0, 100))
 
         if total_size < max_chunk_size:
             return self.get_segments_for_file(mp3_file)
 
+        # If the file is larger than 25MB, split into chunks
+        # and transcribe each chunk separately
         num_chunks = math.ceil(total_size / max_chunk_size)
         chunk_duration = duration_secs / num_chunks
 
diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py
index 6904de55e..24e4bc48e 100644
--- a/tests/transcriber_test.py
+++ b/tests/transcriber_test.py
@@ -72,7 +72,6 @@ def test_should_transcribe(self, qtbot):
 
 
 class TestOpenAIWhisperAPIFileTranscriber:
-    @pytest.mark.skip()
     def test_transcribe(self):
         file_path = "testdata/whisper-french.mp3"
         transcriber = OpenAIWhisperAPIFileTranscriber(
@@ -89,9 +88,18 @@ def test_transcribe(self):
                 model_path="",
             )
         )
-        transcriber.completed.connect(lambda segments: print(segments))
-        transcriber.error.connect(lambda error: print(error))
-        transcriber.run()
+        mock_completed = Mock()
+        transcriber.completed.connect(mock_completed)
+        mock_openai_result = {"segments": [{"start": 0, "end": 6.56, "text": "Hello"}]}
+        with patch("openai.Audio.transcribe", return_value=mock_openai_result):
+            transcriber.run()
+
+        called_segments = mock_completed.call_args[0][0]
+
+        assert len(called_segments) == 1
+        assert called_segments[0].start == 0
+        assert called_segments[0].end == 6560
+        assert called_segments[0].text == "Hello"
 
 
 class TestWhisperCppFileTranscriber:
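
Note: the chunking added in transcribe() reduces to a small piece of
arithmetic. A file under 25 MB (the Whisper API's per-request upload limit)
goes up in a single request; anything larger is cut into
ceil(size / 25 MB) equal-duration pieces, and each piece's segment
timestamps are shifted by its start time in milliseconds, the same
seconds-to-milliseconds conversion the updated test pins down with
end == 6560. A minimal, self-contained sketch of that planning step
(plan_chunks is an illustrative helper, not part of the patch):

    import math

    MAX_CHUNK_SIZE = 25 * 1024 * 1024  # Whisper API per-request upload limit

    def plan_chunks(total_size: int, duration_secs: float):
        """Return (start_sec, end_sec, offset_ms) per chunk, mirroring the
        split logic in OpenAIWhisperAPIFileTranscriber.transcribe()."""
        if total_size < MAX_CHUNK_SIZE:
            return [(0.0, duration_secs, 0)]
        num_chunks = math.ceil(total_size / MAX_CHUNK_SIZE)
        chunk_duration = duration_secs / num_chunks
        chunks = []
        for i in range(num_chunks):
            chunk_start = i * chunk_duration
            chunk_end = min((i + 1) * chunk_duration, duration_secs)
            # Segments transcribed from this chunk are shifted by offset_ms
            # so the merged transcript keeps absolute timestamps.
            chunks.append((chunk_start, chunk_end, int(chunk_start * 1000)))
        return chunks

    # A 60 MB, 30-minute file yields three 10-minute chunks:
    # [(0.0, 600.0, 0), (600.0, 1200.0, 600000), (1200.0, 1800.0, 1200000)]
    print(plan_chunks(60 * 1024 * 1024, 1800.0))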
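Note: the geometry persistence added to MainWindow is the standard QSettings
round trip: saveGeometry() returns an opaque byte blob that restoreGeometry()
accepts back. A stripped-down PyQt6 sketch of the same pattern (the
organization and application names are placeholders, and unlike the patch it
only saves on close rather than on every resize):

    import sys

    from PyQt6.QtCore import QSettings
    from PyQt6.QtWidgets import QApplication, QMainWindow

    class Window(QMainWindow):
        def __init__(self):
            super().__init__()
            self.settings = QSettings("example-org", "example-app")
            # Restore the last saved geometry, if one exists.
            geometry = self.settings.value("main-window/geometry")
            if geometry is not None:
                self.restoreGeometry(geometry)

        def closeEvent(self, event):
            # Persist the current geometry before the window goes away.
            self.settings.setValue("main-window/geometry", self.saveGeometry())
            super().closeEvent(event)

    if __name__ == "__main__":
        app = QApplication(sys.argv)
        window = Window()
        window.show()
        sys.exit(app.exec())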