Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: openai api transcriber #652

Merged
merged 2 commits into from
Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions buzz/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Key(enum.Enum):
"transcription-tasks-table/column-visibility"
)

MAIN_WINDOW = "main-window"

def set_value(self, key: Key, value: typing.Any) -> None:
    """Store *value* in the underlying settings store, keyed by the enum member's string value."""
    self.settings.setValue(key.value, value)

Expand Down
115 changes: 80 additions & 35 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import json
import logging
import math
import multiprocessing
import os
import subprocess
Expand Down Expand Up @@ -286,56 +287,100 @@
self.task,
)

wav_file = tempfile.mktemp() + ".wav"
mp3_file = tempfile.mktemp() + ".mp3"

# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", self.file_path,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(whisper.audio.SAMPLE_RATE),
wav_file,
]
# fmt: on
cmd = ["ffmpeg", "-i", self.file_path, mp3_file]

try:
subprocess.run(cmd, capture_output=True, check=True)
except subprocess.CalledProcessError as exc:
logging.exception("")
raise Exception(exc.stderr.decode("utf-8"))

# TODO: Check if file size is more than 25MB (2.5 minutes), then chunk
audio_file = open(wav_file, "rb")
# fmt: off
cmd = [
"ffprobe",
"-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
mp3_file,
]
# fmt: on
duration_secs = float(
subprocess.run(cmd, capture_output=True, check=True).stdout.decode("utf-8")
)

total_size = os.path.getsize(mp3_file)
max_chunk_size = 25 * 1024 * 1024

openai.api_key = (
self.transcription_task.transcription_options.openai_access_token
)
language = self.transcription_task.transcription_options.language
response_format = "verbose_json"
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
transcript = openai.Audio.translate(
"whisper-1",
audio_file,
response_format=response_format,
language=language,
)
else:
transcript = openai.Audio.transcribe(
"whisper-1",
audio_file,
response_format=response_format,
language=language,

self.progress.emit((0, 100))

if total_size < max_chunk_size:
return self.get_segments_for_file(mp3_file)

# If the file is larger than 25MB, split into chunks
# and transcribe each chunk separately
num_chunks = math.ceil(total_size / max_chunk_size)
chunk_duration = duration_secs / num_chunks

Check warning on line 328 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L327-L328

Added lines #L327 - L328 were not covered by tests

segments = []

Check warning on line 330 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L330

Added line #L330 was not covered by tests

for i in range(num_chunks):
chunk_start = i * chunk_duration
chunk_end = min((i + 1) * chunk_duration, duration_secs)

Check warning on line 334 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L332-L334

Added lines #L332 - L334 were not covered by tests

chunk_file = tempfile.mktemp() + ".mp3"

Check warning on line 336 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L336

Added line #L336 was not covered by tests

# fmt: off
cmd = [

Check warning on line 339 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L339

Added line #L339 was not covered by tests
"ffmpeg",
"-i", mp3_file,
"-ss", str(chunk_start),
"-to", str(chunk_end),
"-c", "copy",
chunk_file,
]
# fmt: on
subprocess.run(cmd, capture_output=True, check=True)
logging.debug('Created chunk file "%s"', chunk_file)

Check warning on line 349 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L348-L349

Added lines #L348 - L349 were not covered by tests

segments.extend(

Check warning on line 351 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L351

Added line #L351 was not covered by tests
self.get_segments_for_file(
chunk_file, offset_ms=int(chunk_start * 1000)
)
)
os.remove(chunk_file)
self.progress.emit((i + 1, num_chunks))

Check warning on line 357 in buzz/transcriber.py

View check run for this annotation

Codecov / codecov/patch

buzz/transcriber.py#L356-L357

Added lines #L356 - L357 were not covered by tests

segments = [
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"])
for segment in transcript["segments"]
]
return segments

def get_segments_for_file(self, file: str, offset_ms: int = 0):
    """Send one audio file to the OpenAI Whisper API and return its segments.

    Args:
        file: Path of the audio file to upload.
        offset_ms: Milliseconds added to every segment timestamp; used when
            *file* is a chunk cut out of a longer recording.

    Returns:
        A list of ``Segment`` objects with start/end expressed in milliseconds.
    """
    options = self.transcription_task.transcription_options
    with open(file, "rb") as audio_file:
        request = {
            "model": "whisper-1",
            "file": audio_file,
            "response_format": "verbose_json",
            "language": options.language,
        }
        if options.task == Task.TRANSLATE:
            transcript = openai.Audio.translate(**request)
        else:
            transcript = openai.Audio.transcribe(**request)

    segments = []
    for segment in transcript["segments"]:
        segments.append(
            Segment(
                int(segment["start"] * 1000 + offset_ms),
                int(segment["end"] * 1000 + offset_ms),
                segment["text"],
            )
        )
    return segments

def stop(self):
    """No-op: this transcriber does not support cancellation."""

Expand Down
19 changes: 19 additions & 0 deletions buzz/widgets/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@

self.load_tasks_from_cache()

self.load_geometry()

def dragEnterEvent(self, event):
# Accept file drag events
if event.mimeData().hasUrls():
Expand Down Expand Up @@ -314,10 +316,27 @@
self.toolbar.set_shortcuts(shortcuts=self.shortcuts)
self.shortcut_settings.save(shortcuts=self.shortcuts)

def resizeEvent(self, event):
    """Qt resize hook: persist the window geometry on every resize.

    Fix: forward the event to the base-class handler as well — the original
    override swallowed it. This matches ``closeEvent`` in this class, which
    already chains to ``super()``.
    """
    self.save_geometry()
    super().resizeEvent(event)

Check warning on line 320 in buzz/widgets/main_window.py

View check run for this annotation

Codecov / codecov/patch

buzz/widgets/main_window.py#L320

Added line #L320 was not covered by tests

def closeEvent(self, event: QtGui.QCloseEvent) -> None:
    """Persist state and shut down background work before the window closes."""
    # Save window geometry first so it is recorded even if teardown below is slow.
    self.save_geometry()

    # Stop the worker, then quit the thread and block until it has finished —
    # this order avoids tearing down the thread while the worker is still running.
    self.transcriber_worker.stop()
    self.transcriber_thread.quit()
    self.transcriber_thread.wait()
    self.save_tasks_to_cache()
    self.shortcut_settings.save(shortcuts=self.shortcuts)
    super().closeEvent(event)

def save_geometry(self):
    """Write the current window geometry into the main-window settings group."""
    # begin_group/end_group must bracket the setValue so the key is namespaced.
    self.settings.begin_group(Settings.Key.MAIN_WINDOW)
    self.settings.settings.setValue("geometry", self.saveGeometry())
    self.settings.end_group()

def load_geometry(self):
    """Restore the window geometry previously saved in settings, if any."""
    self.settings.begin_group(Settings.Key.MAIN_WINDOW)
    saved_geometry = self.settings.settings.value("geometry")
    self.settings.end_group()
    # First run (or cleared settings) yields no stored value — keep defaults.
    if saved_geometry is not None:
        self.restoreGeometry(saved_geometry)
12 changes: 7 additions & 5 deletions buzz/widgets/transcriber/file_transcriber_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def __init__(
self.setLayout(layout)
self.setFixedSize(self.sizeHint())

self.reset_transcriber_controls()

def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
def on_checkbox_state_changed(state: int):
if state == Qt.CheckState.Checked.value:
Expand All @@ -158,11 +160,6 @@ def on_transcription_options_changed(
self, transcription_options: TranscriptionOptions
):
self.transcription_options = transcription_options
self.word_level_timings_checkbox.setDisabled(
self.transcription_options.model.model_type == ModelType.HUGGING_FACE
or self.transcription_options.model.model_type
== ModelType.OPEN_AI_WHISPER_API
)
if self.transcription_options.openai_access_token != "":
self.openai_access_token_changed.emit(
self.transcription_options.openai_access_token
Expand Down Expand Up @@ -213,6 +210,11 @@ def on_download_model_error(self, error: str):

def reset_transcriber_controls(self):
    """Re-enable the run button and sync the word-level-timings checkbox.

    Word-level timings are unavailable for the Hugging Face and OpenAI
    Whisper API model types, so the checkbox is disabled for those.
    """
    self.run_button.setDisabled(False)
    timings_unsupported = self.transcription_options.model.model_type in (
        ModelType.HUGGING_FACE,
        ModelType.OPEN_AI_WHISPER_API,
    )
    self.word_level_timings_checkbox.setDisabled(timings_unsupported)

def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
Expand Down
33 changes: 29 additions & 4 deletions buzz/widgets/transcription_tasks_table_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TableColDef:
id: str
header: str
column_index: int
value_getter: Callable[..., str]
value_getter: Callable[[FileTranscriptionTask], str]
width: Optional[int] = None
hidden: bool = False
hidden_toggleable: bool = True
Expand All @@ -37,6 +37,8 @@ class Column(enum.Enum):
MODEL = auto()
TASK = auto()
STATUS = auto()
DATE_ADDED = auto()
DATE_COMPLETED = auto()

return_clicked = pyqtSignal()

Expand Down Expand Up @@ -78,7 +80,7 @@ def __init__(self, parent: Optional[QWidget] = None):
header=_("Task"),
column_index=self.Column.TASK.value,
value_getter=lambda task: self.get_task_label(task),
width=180,
width=120,
hidden=True,
),
TableColDef(
Expand All @@ -89,6 +91,28 @@ def __init__(self, parent: Optional[QWidget] = None):
width=180,
hidden_toggleable=False,
),
TableColDef(
id="date_added",
header=_("Date Added"),
column_index=self.Column.DATE_ADDED.value,
value_getter=lambda task: task.queued_at.strftime("%Y-%m-%d %H:%M:%S")
if task.queued_at is not None
else "",
width=180,
hidden=True,
),
TableColDef(
id="date_completed",
header=_("Date Completed"),
column_index=self.Column.DATE_COMPLETED.value,
value_getter=lambda task: task.completed_at.strftime(
"%Y-%m-%d %H:%M:%S"
)
if task.completed_at is not None
else "",
width=180,
hidden=True,
),
]

self.setColumnCount(len(self.column_definitions))
Expand Down Expand Up @@ -155,8 +179,9 @@ def upsert_task(self, task: FileTranscriptionTask):
item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
self.setItem(row_index, definition.column_index, item)
else:
status_widget = self.item(task_row_index, self.Column.STATUS.value)
status_widget.setText(task.status_text())
for definition in self.column_definitions:
item = self.item(task_row_index, definition.column_index)
item.setText(definition.value_getter(task))

@staticmethod
def get_task_label(task: FileTranscriptionTask) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def set_segment_text(self, text: str):
self.task_changed.emit()


# TODO: Fix player duration and add spacer below
class TranscriptionViewerWidget(QWidget):
transcription_task: FileTranscriptionTask
task_changed = pyqtSignal()
Expand Down
32 changes: 32 additions & 0 deletions tests/transcriber_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
whisper_cpp_params,
write_output,
TranscriptionOptions,
OpenAIWhisperAPIFileTranscriber,
)
from buzz.recording_transcriber import RecordingTranscriber
from tests.mock_sounddevice import MockInputStream
Expand Down Expand Up @@ -70,6 +71,37 @@ def test_should_transcribe(self, qtbot):
assert "Bienvenue dans Passe" in text


class TestOpenAIWhisperAPIFileTranscriber:
    def test_transcribe(self):
        """completed() should emit segments with timestamps converted to ms."""
        file_path = "testdata/whisper-french.mp3"
        task = FileTranscriptionTask(
            file_path=file_path,
            transcription_options=TranscriptionOptions(
                openai_access_token=os.getenv("OPENAI_ACCESS_TOKEN"),
            ),
            file_transcription_options=FileTranscriptionOptions(
                file_paths=[file_path]
            ),
            model_path="",
        )
        transcriber = OpenAIWhisperAPIFileTranscriber(task=task)

        on_completed = Mock()
        transcriber.completed.connect(on_completed)

        # Stub out the network call so the test never hits the real API.
        fake_response = {"segments": [{"start": 0, "end": 6.56, "text": "Hello"}]}
        with patch("openai.Audio.transcribe", return_value=fake_response):
            transcriber.run()

        emitted_segments = on_completed.call_args[0][0]

        assert len(emitted_segments) == 1
        assert emitted_segments[0].start == 0
        assert emitted_segments[0].end == 6560
        assert emitted_segments[0].text == "Hello"

class TestWhisperCppFileTranscriber:
@pytest.mark.parametrize(
"word_level_timings,expected_segments",
Expand Down
Loading