fix transliteration tests
use normalised_lang_in_collection() to normalize translation languages
remove MinT translation bypass inside google translate
devxpy committed Jun 26, 2024
1 parent 35e0385 commit 4330466
Showing 5 changed files with 148 additions and 131 deletions.
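The core of this commit is the new `normalised_lang_in_collection()` helper in `daras_ai_v2/asr.py` (shown in the diff below): it matches a user-supplied language tag against the collection of supported codes by comparing base languages instead of exact strings, so mixed ISO-639-1 / BCP-47 inputs such as `hi-IN` and `hi` resolve to the same entry. A minimal standalone sketch of the idea, assuming the `langcodes` package is installed and using an illustrative `UserError` stand-in for the project's own exception:

```python
import typing

import langcodes


class UserError(Exception):
    """Illustrative stand-in for the project's UserError exception."""


def normalised_lang_in_collection(target: str, collection: typing.Iterable[str]) -> str:
    # Compare base languages, so "hi-IN", "hi" and "hi-Latn" all normalise to "hi"
    # and match whichever spelling the collection happens to use.
    for candidate in collection:
        if langcodes.get(candidate).language == langcodes.get(target).language:
            return candidate
    raise UserError(f"Unsupported language: {target!r} | must be one of {set(collection)}")


print(normalised_lang_in_collection("hi-IN", {"hi", "bn", "ta"}))  # -> "hi"
```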
113 changes: 52 additions & 61 deletions .github/workflows/python-tests.yml
@@ -1,14 +1,14 @@
name: Python tests

on: [push, workflow_dispatch]
on: [ push, workflow_dispatch ]

jobs:
test:
runs-on: ubuntu-22.04
strategy:
fail-fast: false
matrix:
python-version: ["3.10.12"]
python-version: [ "3.10.12" ]
poetry-version: [ "1.8.3" ]

# Service containers to run with `test`
services:
@@ -23,10 +23,6 @@ jobs:
POSTGRES_PASSWORD: password
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
--name postgres
ports:
# Maps tcp port 5432 on service container to the host
@@ -38,68 +34,63 @@ jobs:
# Set health checks to wait until redis has started
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-interval 5s
--health-timeout 5s
--health-retries 5
ports:
# Maps tcp port 6379 on service container to the host
- 6379:6379
steps:
- name: Increase max_connections
run: |
docker exec -i postgres bash << EOF
sed -i -e 's/max_connections = 100/max_connections = 10000/' /var/lib/postgresql/data/postgresql.conf
EOF
- name: Restart postgres
run: |
docker restart --time 0 postgres && sleep 5
- uses: actions/checkout@v4
# with:
# submodules: recursive
# https://remarkablemark.org/blog/2022/05/12/github-actions-postgresql-increase-max-connections-and-shared-buffers/
- name: Increase max_connections
run: >-
docker exec -i postgres bash << EOF
sed -i -e 's/max_connections = 100/max_connections = 10000/' /var/lib/postgresql/data/postgresql.conf
EOF
- name: Restart postgres
run: >-
docker restart postgres
&& while ! docker exec postgres pg_isready; do sleep 5; done
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'

- name: Install system dependencies
run: |
sudo apt-get update && sudo apt-get install -y --no-install-recommends \
libpoppler-cpp-dev \
python3-opencv \
postgresql-client \
- name: Install system dependencies
run: >-
sudo apt-get update && sudo apt-get install -y --no-install-recommends
libpoppler-cpp-dev
python3-opencv
postgresql-client
libzbar0
- name: Install python dependencies
run: |
pip install -U poetry pip && poetry install --only main --no-interaction
- uses: actions/checkout@v4

# - name: Load secrets into env
# uses: oNaiPs/secrets-to-env-action@v1
# with:
# secrets: ${{ toJSON(secrets) }}
- name: Setup Python, Poetry and Dependencies
uses: packetcoders/action-setup-cache-python-poetry@main
with:
python-version: ${{matrix.python-version}}
poetry-version: ${{matrix.poetry-version}}
install-args: --only main

- name: Test with pytest
env:
PGHOST: localhost
PGPORT: 5432
PGDATABASE: gooey
PGUSER: postgres
PGPASSWORD: password
REDIS_URL: redis://localhost:6379/0
REDIS_CACHE_URL: redis://localhost:6379/1
APP_BASE_URL: http://localhost:3000
API_BASE_URL: http://localhost:8080
ADMIN_BASE_URL: http://localhost:8000
GOOGLE_APPLICATION_CREDENTIALS_JSON: ${{ secrets.GOOGLE_APPLICATION_CREDENTIALS_JSON }}
GS_BUCKET_NAME: ${{ secrets.GS_BUCKET_NAME }}
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
ELEVEN_LABS_API_KEY: ${{ secrets.ELEVEN_LABS_API_KEY }}
AZURE_SPEECH_REGION: ${{ secrets.AZURE_SPEECH_REGION }}
AZURE_SPEECH_KEY: ${{ secrets.AZURE_SPEECH_KEY }}
AZURE_FORM_RECOGNIZER_ENDPOINT: ${{ secrets.AZURE_FORM_RECOGNIZER_ENDPOINT }}
AZURE_FORM_RECOGNIZER_KEY: ${{ secrets.AZURE_FORM_RECOGNIZER_KEY }}
run: |
poetry run ./scripts/run-tests.sh
- name: Run tests
env:
PGHOST: localhost
PGPORT: 5432
PGDATABASE: gooey
PGUSER: postgres
PGPASSWORD: password
REDIS_URL: redis://localhost:6379/0
REDIS_CACHE_URL: redis://localhost:6379/1
APP_BASE_URL: http://localhost:3000
API_BASE_URL: http://localhost:8080
ADMIN_BASE_URL: http://localhost:8000
GOOGLE_APPLICATION_CREDENTIALS_JSON: ${{ secrets.GOOGLE_APPLICATION_CREDENTIALS_JSON }}
GS_BUCKET_NAME: ${{ secrets.GS_BUCKET_NAME }}
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
ELEVEN_LABS_API_KEY: ${{ secrets.ELEVEN_LABS_API_KEY }}
AZURE_SPEECH_REGION: ${{ secrets.AZURE_SPEECH_REGION }}
AZURE_SPEECH_KEY: ${{ secrets.AZURE_SPEECH_KEY }}
AZURE_FORM_RECOGNIZER_ENDPOINT: ${{ secrets.AZURE_FORM_RECOGNIZER_ENDPOINT }}
AZURE_FORM_RECOGNIZER_KEY: ${{ secrets.AZURE_FORM_RECOGNIZER_KEY }}
TEST_SLACK_TEAM_ID: ${{ secrets.TEST_SLACK_TEAM_ID }}
TEST_SLACK_USER_ID: ${{ secrets.TEST_SLACK_USER_ID }}
TEST_SLACK_AUTH_TOKEN: ${{ secrets.TEST_SLACK_AUTH_TOKEN }}
run: |
poetry run ./scripts/run-tests.sh
3 changes: 2 additions & 1 deletion README.md
@@ -31,7 +31,8 @@
- Cloud Speech Administrator
- Cloud Translation API Admin
- Firebase Authentication Admin
5. Download the `serviceAccountKey.json` and save it to the project root.
- Storage Admin
5. Create and Download a JSON Key for this service account and save it to the project root as `serviceAccountKey.json`.

* Run tests to see if everything is working fine:
```
93 changes: 45 additions & 48 deletions daras_ai_v2/asr.py
@@ -1,6 +1,7 @@
import os.path
import os.path
import tempfile
import typing
from enum import Enum

import requests
@@ -36,6 +37,7 @@

SHORT_FILE_CUTOFF = 5 * 1024 * 1024 # 5 MB

# https://cloud.google.com/translate/docs/languages#roman
TRANSLITERATION_SUPPORTED = {"ar", "bn", "gu", "hi", "ja", "kn", "ru", "ta", "te"}

# https://cloud.google.com/speech-to-text/docs/speech-to-text-supported-languages
@@ -395,15 +397,6 @@ def google_translate_source_languages() -> dict[str, str]:
}


def get_language_in_collection(langcode: str, languages):
import langcodes

for lang in languages:
if langcodes.get(lang).language == langcodes.get(langcode).language:
return langcode
return None


def asr_language_selector(
selected_model: AsrModels,
label="##### Spoken Language",
@@ -484,26 +477,17 @@ def run_ghana_nlp_translate(
target_language: str,
source_language: str,
) -> list[str]:
import langcodes

assert (
target_language in GHANA_NLP_SUPPORTED
), "Ghana NLP does not support this target language"
assert source_language, "Source language is required for Ghana NLP"

if source_language not in GHANA_NLP_SUPPORTED:
src = langcodes.Language.get(source_language).language
for lang in GHANA_NLP_SUPPORTED:
if src == langcodes.Language.get(lang).language:
source_language = lang
break
assert (
source_language in GHANA_NLP_SUPPORTED
), "Ghana NLP does not support this source language"

source_language and target_language
), "Both Source & Target language is required for Ghana NLP"
source_language = normalised_lang_in_collection(
source_language, GHANA_NLP_SUPPORTED
)
target_language = normalised_lang_in_collection(
target_language, GHANA_NLP_SUPPORTED
)
if source_language == target_language:
return texts

return map_parallel(
lambda doc: _call_ghana_nlp_chunked(doc, source_language, target_language),
texts,
@@ -550,50 +534,67 @@ def run_google_translate(
list[str]: Translated text.
"""
from google.cloud import translate_v2 as translate
import langcodes

# convert to BCP-47 format (google handles consistent language codes but sometimes gets confused by a mix of iso2 and iso3 which we have)
supported_languages = google_translate_target_languages()
if source_language:
source_language = langcodes.Language.get(source_language).to_tag()
source_language = get_language_in_collection(
source_language, google_translate_source_languages().keys()
) # this will default to autodetect if language is not found as supported
target_language = langcodes.Language.get(target_language).to_tag()
target_language: str | None = get_language_in_collection(
target_language, google_translate_target_languages().keys()
try:
source_language = normalised_lang_in_collection(
source_language, supported_languages
)
except UserError:
source_language = None # autodetect
target_language = normalised_lang_in_collection(
target_language, supported_languages
)
if not target_language:
raise UserError(f"Unsupported target language: {target_language!r}")

# if the language supports transliteration, we should check if the script is Latin
if source_language and source_language not in TRANSLITERATION_SUPPORTED:
language_codes = [source_language] * len(texts)
detected_source_languages = [source_language] * len(texts)
else:
translate_client = translate.Client()
detections = flatten(
translate_client.detect_language(texts[i : i + TRANSLATE_BATCH_SIZE])
for i in range(0, len(texts), TRANSLATE_BATCH_SIZE)
)
language_codes = [detection["language"] for detection in detections]
detected_source_languages = [detection["language"] for detection in detections]

# fix for when google detects a different language than the one the user provided
if source_language:
detected_source_languages = [
code if source_language in code.split("-")[0] else source_language
for code in detected_source_languages
]

return map_parallel(
lambda text, source: _translate_text(
text, source, target_language, glossary_url
lambda text, src_lang: _translate_text(
text, target_language, src_lang, glossary_url
),
texts,
language_codes,
detected_source_languages,
max_workers=TRANSLATE_BATCH_SIZE,
)


def normalised_lang_in_collection(target: str, collection: typing.Iterable[str]) -> str:
import langcodes

for candidate in collection:
if langcodes.get(candidate).language == langcodes.get(target).language:
return candidate

raise UserError(
f"Unsupported language: {target!r} | must be one of {set(collection)}"
)
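In `run_google_translate()` above, both codes now go through this helper: an unsupported source language is caught and falls back to auto-detection, while an unsupported target language surfaces the `UserError` directly. A usage sketch, assuming `google_translate_target_languages()` returns a mapping keyed by codes such as `"en"` and `"hi"` (iterating the dict yields its keys, which is all the helper needs):

```python
supported = google_translate_target_languages()  # assumed shape: {"en": "English", "hi": "Hindi", ...}

normalised_lang_in_collection("hi-IN", supported)  # -> "hi" (base languages match)

try:
    source_language = normalised_lang_in_collection("tlh", supported)  # assumed unsupported
except UserError:
    source_language = None  # run_google_translate() then lets Google auto-detect the source
```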


def _translate_text(
text: str,
source_language: str,
target_language: str,
source_language: str,
glossary_url: str | None,
) -> str:
is_romanized = source_language.endswith("-Latn")
source_language = source_language.replace("-Latn", "")
source_language = source_language.split("-")[0]
enable_transliteration = (
is_romanized and source_language in TRANSLITERATION_SUPPORTED
)
@@ -602,9 +603,6 @@ def _translate_text(
if not text or source_language == target_language or source_language == "und":
return text

if source_language == "wo-SN" or target_language == "wo-SN":
return _MinT_translate_one_text(text, source_language, target_language)

config = {
"target_language_code": target_language,
"contents": text,
Expand All @@ -614,7 +612,6 @@ def _translate_text(
if source_language != "auto":
config["source_language_code"] = source_language

# glossary does not work with transliteration
if glossary_url and not enable_transliteration:
from glossary_resources.models import GlossaryResource

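The transliteration behaviour this commit's tests exercise hinges on the detected language tag: Google's detector returns tags like `hi-Latn` for romanized text, and `_translate_text()` above enables transliteration only when the tag carries the `-Latn` suffix and the base language is in `TRANSLITERATION_SUPPORTED`. A standalone sketch of that check (the `should_transliterate` helper is illustrative, not part of the diff):

```python
# Subset of Google Cloud Translation's romanization-capable languages, as in the diff above.
TRANSLITERATION_SUPPORTED = {"ar", "bn", "gu", "hi", "ja", "kn", "ru", "ta", "te"}


def should_transliterate(detected_lang: str) -> tuple[str, bool]:
    """Return the base source language and whether transliteration should be enabled."""
    is_romanized = detected_lang.endswith("-Latn")
    base_lang = detected_lang.split("-")[0]
    return base_lang, is_romanized and base_lang in TRANSLITERATION_SUPPORTED


print(should_transliterate("hi-Latn"))  # -> ("hi", True): romanized Hindi gets transliterated
print(should_transliterate("hi"))       # -> ("hi", False): Devanagari Hindi is translated as-is
```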
23 changes: 17 additions & 6 deletions glossary_resources/tests.py
@@ -4,7 +4,7 @@
from daras_ai_v2 import settings
from daras_ai_v2.crypto import get_random_doc_id
from glossary_resources.models import GlossaryResource
from tests.test_translation import _test_run_google_translate_one
from tests.test_translation import google_translate_check

GLOSSARY = [
{
@@ -27,12 +27,18 @@
"pos": "noun",
"description": "well labs agniastra",
},
{
"en-US": "Jalapeño",
"hi-IN": "मिर्ची",
"pos": "noun",
"description": "Jalapeño",
},
]

TRANSLATION_TESTS_GLOSSARY = [
(
"एक एकड़ भूमि के लिए कितनी अग्निअस्त्र की आवश्यकता होती है",
"how many fire extinguishers are required for one acre of land", # default
"एक एकड़ भूमि के लिए कितनी अग्निअस्त्र की आवश्यकता होती है", # source
"how many fire extinguishers are required for one acre of land", # default translation
"how many agniastra are required for one acre of land", # using glossary
),
(
@@ -45,6 +51,11 @@
"What can we do with AI",
"What can we do with Gooey.AI",
),
(
"मेरे मिर्ची पर लाल धब्बे आ गये हैं",
"My chillies have got red spots",
"My Jalapeño have got red spots",
),
]


@@ -65,15 +76,15 @@ def glossary_url():

@pytest.mark.skipif(not settings.GS_BUCKET_NAME, reason="No GCS bucket")
@pytest.mark.django_db
def test_run_google_translate_glossary(glossary_url, threadpool_subtest):
def test_google_translate_glossary(glossary_url, threadpool_subtest):
for text, expected, expected_with_glossary in TRANSLATION_TESTS_GLOSSARY:
threadpool_subtest(
_test_run_google_translate_one,
google_translate_check,
text,
expected,
)
threadpool_subtest(
_test_run_google_translate_one,
google_translate_check,
text,
expected_with_glossary,
glossary_url=glossary_url,
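The renamed tests route their assertions through `google_translate_check` from `tests/test_translation.py`, which is not part of this diff. A rough sketch of what such a helper might look like; the signature of `run_google_translate()` and the comparison logic here are assumptions rather than the project's actual implementation:

```python
from daras_ai_v2.asr import run_google_translate


def google_translate_check(text: str, expected: str, glossary_url: str | None = None):
    # Hypothetical helper: translate a single string to English (source auto-detected)
    # and compare the result case-insensitively against the expected output.
    (translated,) = run_google_translate(
        [text], target_language="en", glossary_url=glossary_url
    )
    assert translated.strip().lower() == expected.strip().lower()
```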
