From 762965936353132b6dd326e296fe451b553b151f Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 5 Oct 2023 19:36:34 +0530 Subject: [PATCH] refactor asr code --- daras_ai_v2/asr.py | 68 ++++++++++++------------------------ daras_ai_v2/vector_search.py | 4 ++- recipes/DocExtract.py | 2 +- recipes/asr.py | 26 ++++++-------- 4 files changed, 36 insertions(+), 64 deletions(-) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 9870affe1..1d18757a7 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -126,37 +126,12 @@ def google_translate_languages() -> dict[str, str]: } -def _get_asr_languages(selected_model: AsrModels) -> set[str]: - forced_lang = forced_asr_languages.get(selected_model) - if forced_lang: - return {forced_lang} - - return asr_supported_languages.get(selected_model, set()) - - def asr_language_selector( - selected_model: AsrModels | list[AsrModels], + selected_models: list[AsrModels], label="##### Spoken Language", key="language", ): - if not isinstance(selected_model, list): - selected_model = [selected_model] - languages = set.intersection( - *[ - set( - map(lambda l: langcodes.Language.get(l).language, _get_asr_languages(m)) - ) - for m in selected_model - ] - ) - - if len(languages) < 1: - st.session_state[key] = None - return - elif len(languages) == 1: - st.session_state[key] = languages.pop() - return - + languages = set.intersection(*map(_get_asr_languages, selected_models)) options = [None, *languages] # handle non-canonical language codes @@ -178,6 +153,15 @@ def asr_language_selector( ) +def _get_asr_languages(selected_model: AsrModels) -> set[str]: + forced_lang = forced_asr_languages.get(selected_model) + if forced_lang: + languages = {forced_lang} + else: + languages = asr_supported_languages.get(selected_model, set()) + return {langcodes.Language.get(lang).language for lang in languages} + + def run_google_translate( texts: list[str], target_language: str, @@ -284,25 +268,22 @@ def get_google_auth_session(): def 
run_asr( audio_url: str, - selected_model: str | list[str], + selected_models: list[str], language: str = None, output_format: str = "text", -) -> str | AsrOutputJson | list[str | AsrOutputJson]: +) -> list[str] | list[AsrOutputJson]: """ Run ASR on audio. Args: audio_url (str): url of audio to be transcribed. - selected_model (str): ASR model(s) to use. + selected_models (list[str]): ASR model(s) to use. language: language of the audio output_format: format of the output Returns: str: Transcribed text. """ - - if not isinstance(selected_model, list): - selected_model = [selected_model] - selected_models = [AsrModels[m] for m in selected_model] - output_format: AsrOutputFormat = AsrOutputFormat[output_format] + selected_models = [AsrModels[m] for m in selected_models] + output_format = AsrOutputFormat[output_format] is_youtube_url = "youtube" in audio_url or "youtu.be" in audio_url if is_youtube_url: audio_url, size = download_youtube_to_wav(audio_url) @@ -310,22 +291,17 @@ def run_asr( audio_url, size = audio_url_to_wav(audio_url) is_short = size < SHORT_FILE_CUTOFF - outputs = map_parallel( + return map_parallel( lambda model: _run_asr_one_model( - model, - output_format, - audio_url, - language, - is_short, + selected_model=model, + output_format=output_format, + audio_url=audio_url, + language=language, + is_short=is_short, ), selected_models, ) - if len(outputs) == 1: - return outputs[0] - else: - return outputs - def _run_asr_one_model( selected_model: AsrModels, diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index bf6f6e61c..76668717a 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -482,7 +482,9 @@ def doc_url_to_text_pages( f_url = upload_file_from_bytes( f_name, f_bytes, content_type=doc_meta.mime_type ) - pages = [run_asr(f_url, selected_model=selected_asr_model, language="en")] + pages = [ + run_asr(f_url, [selected_asr_model], language="en")[0], + ] case ".csv" | ".xlsx" | ".tsv" | ".ods": import 
pandas as pd diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index e8c1c97ce..cf147feba 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -382,7 +382,7 @@ def process_source( or "audio/" in doc_meta.mime_type ): yield "Transcribing" - transcript = run_asr(content_url, request.selected_asr_model) + transcript = run_asr(content_url, [request.selected_asr_model])[0] elif "application/pdf" in doc_meta.mime_type: yield "Extracting PDF" transcript = str(azure_doc_extract_pages(content_url)[0]) diff --git a/recipes/asr.py b/recipes/asr.py index 289ba91e8..5bd7b88f0 100644 --- a/recipes/asr.py +++ b/recipes/asr.py @@ -19,7 +19,7 @@ document_uploader, ) from daras_ai_v2.enum_selector_widget import enum_selector, enum_multiselect -from daras_ai_v2.functional import map_parallel +from daras_ai_v2.functional import flatmap_parallel from daras_ai_v2.text_output_widget import text_outputs from recipes.DocSearch import render_documents @@ -44,7 +44,7 @@ class RequestModel(BaseModel): class ResponseModel(BaseModel): raw_output_text: list[str] | None - output_text: list[str | AsrOutputJson] + output_text: list[str] | list[AsrOutputJson] def preview_image(self, state: dict) -> str | None: return DEFAULT_ASR_META_IMG @@ -85,10 +85,9 @@ def render_form_v2(self): ) col1, col2 = st.columns(2, responsive=False) with col1: - if not isinstance(st.session_state.get("selected_model"), list): - st.session_state["selected_model"] = [ - st.session_state["selected_model"] - ] + selected_model = st.session_state.get("selected_model") + if isinstance(selected_model, str): + st.session_state["selected_model"] = [selected_model] selected_model = enum_multiselect( AsrModels, label="##### ASR Models", @@ -142,26 +141,21 @@ def run(self, state: dict): request: AsrPage.RequestModel = self.RequestModel.parse_obj(state) # Run ASR - selected_models: list[str] = request.selected_model or [ - AsrModels.whisper_large_v2.name - ] - if not isinstance(selected_models, list): + 
selected_models = request.selected_model + if isinstance(selected_models, str): selected_models = [selected_models] yield f"Running {', '.join([AsrModels[m].value for m in selected_models])}..." - asr_output = map_parallel( + asr_output = flatmap_parallel( lambda audio: run_asr( audio_url=audio, - selected_model=selected_models, + selected_models=selected_models, language=request.language, output_format=request.output_format, ), request.documents, ) - if len(selected_models) != 1: - # flatten - asr_output = [out for model_out in asr_output for out in model_out] str_asr_output: list[str] = [ - out if not isinstance(out, AsrModels) else out.get("text", "").strip() + out.get("text", "").strip() if isinstance(out, dict) else out for out in asr_output ]