Skip to content

Commit

Permalink
migrate ghana nlp translation to normalised_lang_in_collection()
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Jun 26, 2024
1 parent 7c0f7c3 commit c605143
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,26 +477,17 @@ def run_ghana_nlp_translate(
target_language: str,
source_language: str,
) -> list[str]:
import langcodes

assert (
target_language in GHANA_NLP_SUPPORTED
), "Ghana NLP does not support this target language"
assert source_language, "Source language is required for Ghana NLP"

if source_language not in GHANA_NLP_SUPPORTED:
src = langcodes.Language.get(source_language).language
for lang in GHANA_NLP_SUPPORTED:
if src == langcodes.Language.get(lang).language:
source_language = lang
break
assert (
source_language in GHANA_NLP_SUPPORTED
), "Ghana NLP does not support this source language"

source_language and target_language
), "Both Source & Target language is required for Ghana NLP"
source_language = normalised_lang_in_collection(
source_language, GHANA_NLP_SUPPORTED
)
target_language = normalised_lang_in_collection(
target_language, GHANA_NLP_SUPPORTED
)
if source_language == target_language:
return texts

return map_parallel(
lambda doc: _call_ghana_nlp_chunked(doc, source_language, target_language),
texts,
Expand Down Expand Up @@ -544,7 +535,7 @@ def run_google_translate(
"""
from google.cloud import translate_v2 as translate

supported_languages = google_translate_target_languages().keys()
supported_languages = google_translate_target_languages()
if source_language:
try:
source_language = normalised_lang_in_collection(
Expand Down Expand Up @@ -591,7 +582,9 @@ def normalised_lang_in_collection(target: str, collection: typing.Iterable[str])
if langcodes.get(candidate).language == langcodes.get(target).language:
return candidate

raise UserError(f"Unsupported language: {target!r} | must be one of {collection}")
raise UserError(
f"Unsupported language: {target!r} | must be one of {set(collection)}"
)


def _translate_text(
Expand Down

0 comments on commit c605143

Please sign in to comment.