diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 18690384f..78df47297 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -419,7 +419,12 @@ def asr_language_selector( options.insert(0, None) # handle non-canonical language codes - st.session_state[key] = get_language_match(st.session_state.get(key), options) + old_lang = st.session_state.get(key) + if old_lang: + try: + st.session_state[key] = normalised_lang_in_collection(old_lang, options) + except UserError: + st.session_state[key] = None return st.selectbox( label=label, @@ -429,30 +434,6 @@ def asr_language_selector( ) -def get_language_match(lang: str | None, languages: list[str]) -> str | None: - import langcodes - - if not lang: - return None - - if lang in languages: - return lang - - try: - lan = langcodes.Language.get(lang).language - except langcodes.LanguageTagError: - return None - - for language in languages: - try: - if language and langcodes.Language.get(language).language == lan: - return language - except langcodes.LanguageTagError: - pass - - return None - - def lang_format_func(l): import langcodes @@ -597,14 +578,27 @@ def run_google_translate( def normalised_lang_in_collection(target: str, collection: typing.Iterable[str]) -> str: import langcodes - for candidate in collection: - if langcodes.get(candidate).language == langcodes.get(target).language: - return candidate - - raise UserError( + ERROR = UserError( f"Unsupported language: {target!r} | must be one of {set(collection)}" ) + if target in collection: + return target + + try: + target_lan = langcodes.Language.get(target).language + except langcodes.LanguageTagError: + raise ERROR + + for candidate in collection: + try: + if candidate and langcodes.Language.get(candidate).language == target_lan: + return candidate + except langcodes.LanguageTagError: + pass + + raise ERROR + def _translate_text( text: str, diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 712eaf3f6..ac1826494 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -5,7 +5,6 @@ from django.db import transaction from django.utils.text import slugify from furl import furl -import langcodes import gooey_ui as st from app_users.models import AppUser diff --git a/routers/twilio_api.py b/routers/twilio_api.py index d80df7a75..d8ce8d2a4 100644 --- a/routers/twilio_api.py +++ b/routers/twilio_api.py @@ -59,22 +59,30 @@ def get_twilio_tts_voice(bi: BotIntegration) -> str: def get_twilio_asr_language(bi: BotIntegration) -> str: - from daras_ai_v2.asr import get_language_match + from daras_ai_v2.asr import normalised_lang_in_collection run = bi.get_active_saved_run() state: dict = run.state - asr_language = get_language_match( - state.get("asr_language"), TWILIO_ASR_SUPPORTED_LANGUAGES - ) + asr_language = state.get("asr_language") if asr_language: - return asr_language + try: + asr_language = normalised_lang_in_collection( + asr_language, TWILIO_ASR_SUPPORTED_LANGUAGES + ) + return asr_language + except: + pass - user_language = get_language_match( - state.get("user_language"), TWILIO_ASR_SUPPORTED_LANGUAGES - ) + user_language = state.get("user_language") if user_language: - return user_language + try: + user_language = normalised_lang_in_collection( + user_language, TWILIO_ASR_SUPPORTED_LANGUAGES + ) + return user_language + except: + pass return DEFAULT_ASR_LANGUAGE