Skip to content

Commit

Permalink
Merge pull request #159 from dara-network/glossary_new
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy authored Nov 7, 2023
2 parents 781fa7a + cd5ec90 commit 2df9984
Show file tree
Hide file tree
Showing 36 changed files with 708 additions and 88 deletions.
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def threadpool_subtest(subtests, max_workers: int = 8):
ts = []

def submit(fn, *args, **kwargs):
msg = "--".join(map(str, args))
msg = "--".join(map(str, [*args, *kwargs.values()]))

@wraps(fn)
def runner(*args, **kwargs):
Expand Down
15 changes: 10 additions & 5 deletions daras_ai/image_input.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import math
import mimetypes
import os
import re
import uuid
from pathlib import Path

import math
import numpy as np
import requests
from PIL import Image, ImageOps
from furl import furl

from daras_ai_v2 import settings

if False:
from firebase_admin import storage


def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes:
img_cv2 = bytes_to_cv2_img(img_bytes)
Expand Down Expand Up @@ -70,7 +69,9 @@ def storage_blob_for(filename: str) -> "storage.storage.Blob":

filename = safe_filename(filename)
bucket = storage.bucket(settings.GS_BUCKET_NAME)
blob = bucket.blob(f"daras_ai/media/{uuid.uuid1()}/{filename}")
blob = bucket.blob(
os.path.join(settings.GS_MEDIA_PATH, str(uuid.uuid1()), filename)
)
return blob


Expand Down Expand Up @@ -143,3 +144,7 @@ def guess_ext_from_response(response: requests.Response) -> str:
def get_mimetype_from_response(response: requests.Response) -> str:
    """Return the bare media type from the response's Content-Type header.

    Any parameters after the first ";" (e.g. ``charset=utf-8``) are dropped,
    and surrounding whitespace is stripped so that a header such as
    ``"text/html ; charset=utf-8"`` yields ``"text/html"`` rather than
    ``"text/html "``.

    Args:
        response: Object exposing a ``headers`` mapping with a ``get`` method.

    Returns:
        The media type, or ``"application/octet-stream"`` when the header
        is absent.
    """
    content_type = response.headers.get("Content-Type", "application/octet-stream")
    # Optional whitespace is allowed around the ";" separator, so strip it.
    media_type, _, _params = content_type.partition(";")
    return media_type.strip()


def gs_url_to_uri(url: str) -> str:
    """Build a ``gs://`` URI from the path segments of *url*.

    The path segments parsed by ``furl`` are joined with "/" and prefixed
    with ``gs://``; the scheme and host of *url* are discarded.

    NOTE(review): presumably *url* is a storage.googleapis.com URL whose
    first path segment is the bucket name -- confirm against callers.
    """
    segments = furl(url).path.segments
    object_path = "/".join(segments)
    return "gs://" + object_path
74 changes: 48 additions & 26 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import langcodes
import requests
import typing_extensions
from django.db.models import F
from furl import furl

import gooey_ui as st
from daras_ai.image_input import upload_file_from_bytes
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2.gdrive_downloader import (
is_gdrive_url,
gdrive_download,
Expand All @@ -20,7 +21,6 @@
from daras_ai_v2 import settings
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.gpu_server import (
GpuEndpoints,
call_celery_task,
)
from daras_ai_v2.redis_cache import redis_cache_decorator
Expand Down Expand Up @@ -168,14 +168,16 @@ def asr_language_selector(
def run_google_translate(
texts: list[str],
target_language: str,
source_language: str = None,
source_language: str | None = None,
glossary_url: str | None = None,
) -> list[str]:
"""
Translate text using the Google Translate API.
Args:
texts (list[str]): Text to be translated.
target_language (str): Language code to translate to.
source_language (str): Language code to translate from.
glossary_url (str): URL of glossary file.
Returns:
list[str]: Translated text.
"""
Expand All @@ -190,48 +192,69 @@ def run_google_translate(
language_codes = [detection["language"] for detection in detections]

return map_parallel(
lambda text, source: _translate_text(text, source, target_language),
lambda text, source: _translate_text(
text, source, target_language, glossary_url
),
texts,
language_codes,
)


def _translate_text(text: str, source_language: str, target_language: str):
def _translate_text(
text: str,
source_language: str,
target_language: str,
glossary_url: str | None,
) -> str:
is_romanized = source_language.endswith("-Latn")
source_language = source_language.replace("-Latn", "")
enable_transliteration = (
is_romanized and source_language in TRANSLITERATION_SUPPORTED
)

# prevent incorrect API calls
if source_language == target_language or not text:
return text

if source_language == "wo-SN" or target_language == "wo-SN":
return _MinT_translate_one_text(text, source_language, target_language)

config = {
"source_language_code": source_language,
"target_language_code": target_language,
"contents": text,
"mime_type": "text/plain",
"transliteration_config": {"enable_transliteration": enable_transliteration},
}

# glossary does not work with transliteration
if glossary_url and not enable_transliteration:
from glossary_resources.models import GlossaryResource

gr = GlossaryResource.objects.get_or_create_from_url(glossary_url)[0]
GlossaryResource.objects.filter(pk=gr.pk).update(
usage_count=F("usage_count") + 1
)
location = gr.location
config["glossary_config"] = {
"glossary": gr.get_glossary_path(),
"ignoreCase": True,
}
else:
location = "global"

authed_session, project = get_google_auth_session()
res = authed_session.post(
f"https://translation.googleapis.com/v3/projects/{project}/locations/global:translateText",
json.dumps(
{
"source_language_code": source_language,
"target_language_code": target_language,
"contents": text,
"mime_type": "text/plain",
"transliteration_config": {
"enable_transliteration": enable_transliteration
},
}
),
headers={
"Content-Type": "application/json",
},
f"https://translation.googleapis.com/v3/projects/{project}/locations/{location}:translateText",
json=config,
)
res.raise_for_status()
data = res.json()
result = data["translations"][0]

return result["translatedText"].strip()
try:
result = data["glossaryTranslations"][0]["translatedText"]
except (KeyError, IndexError):
result = data["translations"][0]["translatedText"]
return result.strip()


_session = None
Expand Down Expand Up @@ -346,8 +369,7 @@ def run_asr(
)

elif selected_model == AsrModels.usm:
# note: only us-central1 and a few other regions support chirp recognizers (so global can't be used)
location = "us-central1"
location = settings.GCP_REGION

# Create a client
options = ClientOptions(api_endpoint=f"{location}-speech.googleapis.com")
Expand Down Expand Up @@ -379,7 +401,7 @@ def run_asr(
audio_channel_count=1,
)
audio = cloud_speech.BatchRecognizeFileMetadata()
audio.uri = "gs://" + "/".join(furl(audio_url).path.segments)
audio.uri = gs_url_to_uri(audio_url)
# Specify that results should be inlined in the response (only possible for 1 audio file)
output_config = cloud_speech.RecognitionOutputConfig()
output_config.inline_response_config = cloud_speech.InlineOutputConfig()
Expand Down
8 changes: 7 additions & 1 deletion daras_ai_v2/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class BotInterface:
show_feedback_buttons: bool = False
convo: Conversation
recieved_msg_id: str = None
input_glossary: str | None = None
output_glossary: str | None = None

def send_msg(
self,
Expand Down Expand Up @@ -79,6 +81,8 @@ def _unpack_bot_integration(self):
run_id=bi.saved_run.run_id,
uid=bi.saved_run.uid,
)
self.input_glossary = bi.saved_run.state.get("input_glossary_document")
self.output_glossary = bi.saved_run.state.get("output_glossary_document")
else:
self.page_cls = None
self.query_params = {}
Expand Down Expand Up @@ -233,7 +237,9 @@ def _handle_feedback_msg(bot: BotInterface, input_text):
# save the feedback
last_feedback.text = input_text
# translate feedback to english
last_feedback.text_english = " ".join(run_google_translate([input_text], "en"))
last_feedback.text_english = " ".join(
run_google_translate([input_text], "en", glossary_url=bot.input_glossary)
)
last_feedback.save()
# send back a confirmation msg
bot.show_feedback_buttons = False # don't show feedback for this confirmation
Expand Down
33 changes: 26 additions & 7 deletions daras_ai_v2/doc_search_settings_widgets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import typing

import gooey_ui as st
Expand All @@ -7,9 +8,13 @@
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.search_ref import CitationStyles

# URL fragment (host/bucket/media-path) used to recognise user-uploaded media URLs.
# NOTE(review): os.path.join uses the OS path separator and discards earlier
# components if a later one is absolute; this is a URL fragment, not a
# filesystem path -- confirm GS_MEDIA_PATH has no leading "/" and that this
# only runs on POSIX hosts.
_user_media_url_prefix = os.path.join(
"storage.googleapis.com", settings.GS_BUCKET_NAME, settings.GS_MEDIA_PATH
)


def is_user_uploaded_url(url: str) -> bool:
return f"storage.googleapis.com/{settings.GS_BUCKET_NAME}/daras_ai" in url
return _user_media_url_prefix in url


def document_uploader(
Expand All @@ -26,36 +31,50 @@ def document_uploader(
".mp3",
".aac",
),
):
accept_multiple_files=True,
) -> list[str] | str:
st.write(label, className="gui-input")
documents = st.session_state.get(key) or []
if isinstance(documents, str):
documents = [documents]
has_custom_urls = not all(map(is_user_uploaded_url, documents))
custom_key = "__custom_" + key
if st.checkbox("Enter Custom URLs", value=has_custom_urls):
if st.checkbox(
"Enter Custom URLs", key=f"__custom_checkbox_{key}", value=has_custom_urls
):
if not custom_key in st.session_state:
st.session_state[custom_key] = "\n".join(documents)
text_value = st.text_area(
if accept_multiple_files:
widget = st.text_area
kwargs = dict(height=150)
else:
widget = st.text_input
kwargs = {}
text_value = widget(
label,
key=custom_key,
label_visibility="collapsed",
height=150,
style={
"whiteSpace": "pre",
"overflowWrap": "normal",
"overflowX": "scroll",
"fontFamily": "monospace",
"fontSize": "0.9rem",
},
**kwargs,
)
st.session_state[key] = text_value.strip().splitlines()
if accept_multiple_files:
st.session_state[key] = text_value.strip().splitlines()
else:
st.session_state[key] = text_value
else:
st.session_state.pop(custom_key, None)
st.file_uploader(
label,
label_visibility="collapsed",
key=key,
accept=accept,
accept_multiple_files=True,
accept_multiple_files=accept_multiple_files,
)
return st.session_state.get(key, [])

Expand Down
8 changes: 6 additions & 2 deletions daras_ai_v2/facebook_bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def send_msg(
should_translate: bool = False,
) -> str | None:
if should_translate and self.language and self.language != "en":
text = run_google_translate([text], self.language)[0]
text = run_google_translate(
[text], self.language, glossary_url=self.output_glossary
)[0]
return send_wa_msg(
bot_number=self.bot_id,
user_number=self.user_id,
Expand Down Expand Up @@ -340,7 +342,9 @@ def send_msg(
should_translate: bool = False,
) -> str | None:
if should_translate and self.language and self.language != "en":
text = run_google_translate([text], self.language)[0]
text = run_google_translate(
[text], self.language, glossary_url=self.output_glossary
)[0]
return send_fb_msg(
access_token=self._access_token,
bot_id=self.bot_id,
Expand Down
2 changes: 1 addition & 1 deletion daras_ai_v2/gdrive_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def gdrive_metadata(file_id: str) -> dict:
.get(
supportsAllDrives=True,
fileId=file_id,
fields="name,md5Checksum,modifiedTime,mimeType",
fields="name,md5Checksum,modifiedTime,mimeType,size",
)
.execute()
)
Expand Down
Loading

0 comments on commit 2df9984

Please sign in to comment.