Skip to content

Commit

Permalink
store file metadata for glossary files
Browse files Browse the repository at this point in the history
add tests for glossary translation
add lock for redis cache
  • Loading branch information
devxpy committed Oct 29, 2023
1 parent be63867 commit 677131f
Show file tree
Hide file tree
Showing 20 changed files with 512 additions and 378 deletions.
30 changes: 30 additions & 0 deletions app_users/migrations/0010_filemetadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generated by Django 4.2.5 on 2023-10-29 11:10

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("app_users", "0009_alter_appusertransaction_options_and_more"),
]

operations = [
migrations.CreateModel(
name="FileMetadata",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.TextField(blank=True, default="")),
("etag", models.CharField(max_length=255, null=True)),
("mime_type", models.CharField(blank=True, default="", max_length=255)),
("total_bytes", models.PositiveIntegerField(blank=True, default=0)),
],
),
]
10 changes: 10 additions & 0 deletions app_users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,13 @@ class Meta:

def __str__(self):
return f"{self.invoice_id} ({self.amount})"


class FileMetadata(models.Model):
name = models.TextField(default="", blank=True)
etag = models.CharField(max_length=255, null=True)
mime_type = models.CharField(max_length=255, default="", blank=True)
total_bytes = models.PositiveIntegerField(default=0, blank=True)

def __str__(self):
return f"{self.name or self.etag} ({self.mime_type})"
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def threadpool_subtest(subtests, max_workers: int = 8):
ts = []

def submit(fn, *args, **kwargs):
msg = "--".join(map(str, args))
msg = "--".join(map(str, [*args, *kwargs.values()]))

@wraps(fn)
def runner(*args, **kwargs):
Expand Down
15 changes: 10 additions & 5 deletions daras_ai/image_input.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import math
import mimetypes
import os
import re
import uuid
from pathlib import Path

import math
import numpy as np
import requests
from PIL import Image, ImageOps
from furl import furl

from daras_ai_v2 import settings

if False:
from firebase_admin import storage


def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes:
img_cv2 = bytes_to_cv2_img(img_bytes)
Expand Down Expand Up @@ -70,7 +69,9 @@ def storage_blob_for(filename: str) -> "storage.storage.Blob":

filename = safe_filename(filename)
bucket = storage.bucket(settings.GS_BUCKET_NAME)
blob = bucket.blob(f"daras_ai/media/{uuid.uuid1()}/{filename}")
blob = bucket.blob(
os.path.join(settings.GS_MEDIA_PATH, str(uuid.uuid1()), filename)
)
return blob


Expand Down Expand Up @@ -143,3 +144,7 @@ def guess_ext_from_response(response: requests.Response) -> str:
def get_mimetype_from_response(response: requests.Response) -> str:
content_type = response.headers.get("Content-Type", "application/octet-stream")
return content_type.split(";")[0]


def gs_url_to_uri(url: str) -> str:
return "gs://" + "/".join(furl(url).path.segments)
65 changes: 27 additions & 38 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,12 @@
from furl import furl

import gooey_ui as st
from daras_ai.image_input import upload_file_from_bytes
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2 import settings
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.gpu_server import (
GpuEndpoints,
call_celery_task,
)
from daras_ai_v2.glossary import (
DEFAULT_GLOSSARY_URL,
glossary_resource,
supports_language_pair,
)
from daras_ai_v2.redis_cache import redis_cache_decorator

SHORT_FILE_CUTOFF = 5 * 1024 * 1024 # 1 MB
Expand Down Expand Up @@ -168,14 +162,15 @@ def run_google_translate(
texts: list[str],
target_language: str,
source_language: str = None,
glossary_url: str = DEFAULT_GLOSSARY_URL,
glossary_url: str = None,
) -> list[str]:
"""
Translate text using the Google Translate API.
Args:
texts (list[str]): Text to be translated.
target_language (str): Language code to translate to.
source_language (str): Language code to translate from.
glossary_url (str): URL of glossary file.
Returns:
list[str]: Translated text.
"""
Expand All @@ -191,7 +186,7 @@ def run_google_translate(

return map_parallel(
lambda text, source: _translate_text(
text, source, target_language, glossary_url or DEFAULT_GLOSSARY_URL
text, source, target_language, glossary_url
),
texts,
language_codes,
Expand All @@ -202,16 +197,13 @@ def _translate_text(
text: str,
source_language: str,
target_language: str,
glossary_url: str = DEFAULT_GLOSSARY_URL,
glossary_url: str | None,
) -> str:
is_romanized = source_language.endswith("-Latn")
source_language = source_language.replace("-Latn", "")
enable_transliteration = (
is_romanized and source_language in TRANSLITERATION_SUPPORTED
)
glossary_url = (
glossary_url if not enable_transliteration else ""
) # glossary does not work with transliteration

# prevent incorrect API calls
if source_language == target_language or not text:
Expand All @@ -228,30 +220,28 @@ def _translate_text(
"transliteration_config": {"enable_transliteration": enable_transliteration},
}

with glossary_resource(glossary_url) as (uri, location, lang_codes):
if supports_language_pair(lang_codes, target_language, source_language):
config.update(
{
"glossaryConfig": {
"glossary": uri,
"ignoreCase": True,
}
}
)
else:
uri = None
location = "global"
# glossary does not work with transliteration
if glossary_url and not enable_transliteration:
from glossary_resources.models import GlossaryResource

authed_session, project = get_google_auth_session()
res = authed_session.post(
f"https://translation.googleapis.com/v3/projects/{project}/locations/{location}:translateText",
json=config,
)
res.raise_for_status()
data = res.json()
result = data["glossaryTranslations"][0] if uri else data["translations"][0]
gr = GlossaryResource.objects.get_or_create_from_url(glossary_url)[0]
location = gr.location
config["glossary_config"] = {
"glossary": gr.get_glossary_path(),
"ignoreCase": True,
}
else:
location = "global"

return result["translatedText"].strip()
authed_session, project = get_google_auth_session()
res = authed_session.post(
f"https://translation.googleapis.com/v3/projects/{project}/locations/{location}:translateText",
json=config,
)
res.raise_for_status()
data = res.json()
translations = data.get("glossaryTranslations", data["translations"])
return translations[0]["translatedText"].strip()


_session = None
Expand Down Expand Up @@ -355,8 +345,7 @@ def run_asr(
)

elif selected_model == AsrModels.usm:
# note: only us-central1 and a few other regions support chirp recognizers (so global can't be used)
location = "us-central1"
location = settings.GCP_REGION

# Create a client
options = ClientOptions(api_endpoint=f"{location}-speech.googleapis.com")
Expand Down Expand Up @@ -388,7 +377,7 @@ def run_asr(
audio_channel_count=1,
)
audio = cloud_speech.BatchRecognizeFileMetadata()
audio.uri = "gs://" + "/".join(furl(audio_url).path.segments)
audio.uri = gs_url_to_uri(audio_url)
# Specify that results should be inlined in the response (only possible for 1 audio file)
output_config = cloud_speech.RecognitionOutputConfig()
output_config.inline_response_config = cloud_speech.InlineOutputConfig()
Expand Down
7 changes: 6 additions & 1 deletion daras_ai_v2/doc_search_settings_widgets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import typing

import gooey_ui as st
Expand All @@ -7,9 +8,13 @@
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.search_ref import CitationStyles

_user_media_url_prefix = os.path.join(
"storage.googleapis.com", settings.GS_BUCKET_NAME, settings.GS_MEDIA_PATH
)


def is_user_uploaded_url(url: str) -> bool:
return f"storage.googleapis.com/{settings.GS_BUCKET_NAME}/daras_ai" in url
return _user_media_url_prefix in url


def document_uploader(
Expand Down
2 changes: 1 addition & 1 deletion daras_ai_v2/gdrive_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def gdrive_metadata(file_id: str) -> dict:
.get(
supportsAllDrives=True,
fileId=file_id,
fields="name,md5Checksum,modifiedTime,mimeType",
fields="name,md5Checksum,modifiedTime,mimeType,filesize",
)
.execute()
)
Expand Down
Loading

0 comments on commit 677131f

Please sign in to comment.