diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py
index 7cc5337b7..057092d75 100644
--- a/daras_ai_v2/asr.py
+++ b/daras_ai_v2/asr.py
@@ -477,26 +477,17 @@ def run_ghana_nlp_translate(
     target_language: str,
     source_language: str,
 ) -> list[str]:
-    import langcodes
-
-    assert (
-        target_language in GHANA_NLP_SUPPORTED
-    ), "Ghana NLP does not support this target language"
-    assert source_language, "Source language is required for Ghana NLP"
-
-    if source_language not in GHANA_NLP_SUPPORTED:
-        src = langcodes.Language.get(source_language).language
-        for lang in GHANA_NLP_SUPPORTED:
-            if src == langcodes.Language.get(lang).language:
-                source_language = lang
-                break
     assert (
-        source_language in GHANA_NLP_SUPPORTED
-    ), "Ghana NLP does not support this source language"
-
+        source_language and target_language
+    ), "Both Source & Target language is required for Ghana NLP"
+    source_language = normalised_lang_in_collection(
+        source_language, GHANA_NLP_SUPPORTED
+    )
+    target_language = normalised_lang_in_collection(
+        target_language, GHANA_NLP_SUPPORTED
+    )
     if source_language == target_language:
         return texts
-
     return map_parallel(
         lambda doc: _call_ghana_nlp_chunked(doc, source_language, target_language),
         texts,
@@ -544,7 +535,7 @@ def run_google_translate(
     """
     from google.cloud import translate_v2 as translate
 
-    supported_languages = google_translate_target_languages().keys()
+    supported_languages = google_translate_target_languages()
     if source_language:
         try:
             source_language = normalised_lang_in_collection(
@@ -591,7 +582,9 @@ def normalised_lang_in_collection(target: str, collection: typing.Iterable[str])
         if langcodes.get(candidate).language == langcodes.get(target).language:
             return candidate
 
-    raise UserError(f"Unsupported language: {target!r} | must be one of {collection}")
+    raise UserError(
+        f"Unsupported language: {target!r} | must be one of {set(collection)}"
+    )
 
 
 def _translate_text(
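
Reviewer note: both call sites now delegate language validation to `normalised_lang_in_collection`. Below is a minimal sketch of the behaviour the patch relies on, assuming the `langcodes` package is installed; the `UserError` class and the `SUPPORTED` mapping are hypothetical stand-ins for illustration only, not the real project exception or the actual `GHANA_NLP_SUPPORTED` / Google Translate language tables:

    import typing

    import langcodes

    class UserError(Exception):
        """Stand-in for the project's UserError exception (assumption)."""

    def normalised_lang_in_collection(target: str, collection: typing.Iterable[str]) -> str:
        # Return the entry in `collection` whose base language matches `target`
        # (e.g. "en-US" matches "en"), mirroring the loop visible in the last hunk.
        for candidate in collection:
            if langcodes.get(candidate).language == langcodes.get(target).language:
                return candidate
        raise UserError(
            f"Unsupported language: {target!r} | must be one of {set(collection)}"
        )

    # Hypothetical language table for illustration only.
    SUPPORTED = {"en": "English", "tw": "Twi"}
    assert normalised_lang_in_collection("en-US", SUPPORTED) == "en"
    assert normalised_lang_in_collection("tw", SUPPORTED) == "tw"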