Skip to content

Commit

Permalink
refactor asr code
Browse files (browse the repository at this point in the history)
  • Loading branch information
devxpy committed Oct 5, 2023
1 parent 319327f commit 7629659
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 64 deletions.
68 changes: 22 additions & 46 deletions daras_ai_v2/asr.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -126,37 +126,12 @@ def google_translate_languages() -> dict[str, str]:
}


def _get_asr_languages(selected_model: AsrModels) -> set[str]:
forced_lang = forced_asr_languages.get(selected_model)
if forced_lang:
return {forced_lang}

return asr_supported_languages.get(selected_model, set())


def asr_language_selector(
selected_model: AsrModels | list[AsrModels],
selected_models: list[AsrModels],
label="##### Spoken Language",
key="language",
):
if not isinstance(selected_model, list):
selected_model = [selected_model]
languages = set.intersection(
*[
set(
map(lambda l: langcodes.Language.get(l).language, _get_asr_languages(m))
)
for m in selected_model
]
)

if len(languages) < 1:
st.session_state[key] = None
return
elif len(languages) == 1:
st.session_state[key] = languages.pop()
return

languages = set.intersection(*map(_get_asr_languages, selected_models))
options = [None, *languages]

# handle non-canonical language codes
Expand All @@ -178,6 +153,15 @@ def asr_language_selector(
)


def _get_asr_languages(selected_model: AsrModels) -> set[str]:
forced_lang = forced_asr_languages.get(selected_model)
if forced_lang:
languages = {forced_lang}
else:
languages = asr_supported_languages.get(selected_model, set())
return {langcodes.Language.get(lang).language for lang in languages}


def run_google_translate(
texts: list[str],
target_language: str,
Expand Down Expand Up @@ -284,48 +268,40 @@ def get_google_auth_session():

def run_asr(
audio_url: str,
selected_model: str | list[str],
selected_models: list[str],
language: str = None,
output_format: str = "text",
) -> str | AsrOutputJson | list[str | AsrOutputJson]:
) -> list[str] | list[AsrOutputJson]:
"""
Run ASR on audio.
Args:
audio_url (str): url of audio to be transcribed.
selected_model (str): ASR model(s) to use.
selected_models (str): ASR model(s) to use.
language: language of the audio
output_format: format of the output
Returns:
str: Transcribed text.
"""

if not isinstance(selected_model, list):
selected_model = [selected_model]
selected_models = [AsrModels[m] for m in selected_model]
output_format: AsrOutputFormat = AsrOutputFormat[output_format]
selected_models = [AsrModels[m] for m in selected_models]
output_format = AsrOutputFormat[output_format]
is_youtube_url = "youtube" in audio_url or "youtu.be" in audio_url
if is_youtube_url:
audio_url, size = download_youtube_to_wav(audio_url)
else:
audio_url, size = audio_url_to_wav(audio_url)
is_short = size < SHORT_FILE_CUTOFF

outputs = map_parallel(
return map_parallel(
lambda model: _run_asr_one_model(
model,
output_format,
audio_url,
language,
is_short,
selected_model=model,
output_format=output_format,
audio_url=audio_url,
language=language,
is_short=is_short,
),
selected_models,
)

if len(outputs) == 1:
return outputs[0]
else:
return outputs


def _run_asr_one_model(
selected_model: AsrModels,
Expand Down
4 changes: 3 additions & 1 deletion daras_ai_v2/vector_search.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -482,7 +482,9 @@ def doc_url_to_text_pages(
f_url = upload_file_from_bytes(
f_name, f_bytes, content_type=doc_meta.mime_type
)
pages = [run_asr(f_url, selected_model=selected_asr_model, language="en")]
pages = [
run_asr(f_url, [selected_asr_model], language="en")[0],
]
case ".csv" | ".xlsx" | ".tsv" | ".ods":
import pandas as pd

Expand Down
2 changes: 1 addition & 1 deletion recipes/DocExtract.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -382,7 +382,7 @@ def process_source(
or "audio/" in doc_meta.mime_type
):
yield "Transcribing"
transcript = run_asr(content_url, request.selected_asr_model)
transcript = run_asr(content_url, [request.selected_asr_model])[0]
elif "application/pdf" in doc_meta.mime_type:
yield "Extracting PDF"
transcript = str(azure_doc_extract_pages(content_url)[0])
Expand Down
26 changes: 10 additions & 16 deletions recipes/asr.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,7 +19,7 @@
document_uploader,
)
from daras_ai_v2.enum_selector_widget import enum_selector, enum_multiselect
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.functional import flatmap_parallel
from daras_ai_v2.text_output_widget import text_outputs
from recipes.DocSearch import render_documents

Expand All @@ -44,7 +44,7 @@ class RequestModel(BaseModel):

class ResponseModel(BaseModel):
raw_output_text: list[str] | None
output_text: list[str | AsrOutputJson]
output_text: list[str] | list[AsrOutputJson]

def preview_image(self, state: dict) -> str | None:
return DEFAULT_ASR_META_IMG
Expand Down Expand Up @@ -85,10 +85,9 @@ def render_form_v2(self):
)
col1, col2 = st.columns(2, responsive=False)
with col1:
if not isinstance(st.session_state.get("selected_model"), list):
st.session_state["selected_model"] = [
st.session_state["selected_model"]
]
selected_model = st.session_state.get("selected_model")
if isinstance(selected_model, str):
st.session_state["selected_model"] = [selected_model]
selected_model = enum_multiselect(
AsrModels,
label="##### ASR Models",
Expand Down Expand Up @@ -142,26 +141,21 @@ def run(self, state: dict):
request: AsrPage.RequestModel = self.RequestModel.parse_obj(state)

# Run ASR
selected_models: list[str] = request.selected_model or [
AsrModels.whisper_large_v2.name
]
if not isinstance(selected_models, list):
selected_models = request.selected_model
if isinstance(selected_models, str):
selected_models = [selected_models]
yield f"Running {', '.join([AsrModels[m].value for m in selected_models])}..."
asr_output = map_parallel(
asr_output = flatmap_parallel(
lambda audio: run_asr(
audio_url=audio,
selected_model=selected_models,
selected_models=selected_models,
language=request.language,
output_format=request.output_format,
),
request.documents,
)
if len(selected_models) != 1:
# flatten
asr_output = [out for model_out in asr_output for out in model_out]
str_asr_output: list[str] = [
out if not isinstance(out, AsrModels) else out.get("text", "").strip()
out.get("text", "").strip() if isinstance(out, dict) else out
for out in asr_output
]

Expand Down

0 comments on commit 7629659

Please sign in to comment.