From 677131f87af2a2543043fbd79d871814a60803eb Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 29 Oct 2023 19:10:27 +0530 Subject: [PATCH] store file metadata for glossary files add tests for glossary translation add lock for redis cache --- app_users/migrations/0010_filemetadata.py | 30 ++ app_users/models.py | 10 + conftest.py | 2 +- daras_ai/image_input.py | 15 +- daras_ai_v2/asr.py | 65 ++-- daras_ai_v2/doc_search_settings_widgets.py | 7 +- daras_ai_v2/gdrive_downloader.py | 2 +- daras_ai_v2/glossary.py | 298 +++++------------- daras_ai_v2/redis_cache.py | 23 +- daras_ai_v2/settings.py | 6 +- daras_ai_v2/vector_search.py | 130 ++++---- glossary_resources/admin.py | 16 +- glossary_resources/apps.py | 5 + glossary_resources/migrations/0001_initial.py | 23 +- glossary_resources/models.py | 115 ++++++- glossary_resources/signals.py | 14 + glossary_resources/tests.py | 114 ++++++- recipes/DocExtract.py | 2 +- recipes/asr.py | 4 +- tests/test_translation.py | 9 +- 20 files changed, 512 insertions(+), 378 deletions(-) create mode 100644 app_users/migrations/0010_filemetadata.py create mode 100644 glossary_resources/signals.py diff --git a/app_users/migrations/0010_filemetadata.py b/app_users/migrations/0010_filemetadata.py new file mode 100644 index 000000000..9c1825efc --- /dev/null +++ b/app_users/migrations/0010_filemetadata.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.5 on 2023-10-29 11:10 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("app_users", "0009_alter_appusertransaction_options_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="FileMetadata", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.TextField(blank=True, default="")), + ("etag", models.CharField(max_length=255, null=True)), + ("mime_type", models.CharField(blank=True, default="", max_length=255)), + ("total_bytes", models.PositiveIntegerField(blank=True, default=0)), + ], + ), + ] diff --git a/app_users/models.py b/app_users/models.py index 6a72eecab..bce1a4485 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -229,3 +229,13 @@ class Meta: def __str__(self): return f"{self.invoice_id} ({self.amount})" + + +class FileMetadata(models.Model): + name = models.TextField(default="", blank=True) + etag = models.CharField(max_length=255, null=True) + mime_type = models.CharField(max_length=255, default="", blank=True) + total_bytes = models.PositiveIntegerField(default=0, blank=True) + + def __str__(self): + return f"{self.name or self.etag} ({self.mime_type})" diff --git a/conftest.py b/conftest.py index ff61671aa..e8534d5f2 100644 --- a/conftest.py +++ b/conftest.py @@ -48,7 +48,7 @@ def threadpool_subtest(subtests, max_workers: int = 8): ts = [] def submit(fn, *args, **kwargs): - msg = "--".join(map(str, args)) + msg = "--".join(map(str, [*args, *kwargs.values()])) @wraps(fn) def runner(*args, **kwargs): diff --git a/daras_ai/image_input.py b/daras_ai/image_input.py index 96bd7734c..2a7a70417 100644 --- a/daras_ai/image_input.py +++ b/daras_ai/image_input.py @@ -1,18 +1,17 @@ -import math import mimetypes +import os import re import uuid from pathlib import Path +import math import numpy as np import requests from PIL import Image, ImageOps +from furl import furl from daras_ai_v2 import settings -if False: - from firebase_admin import storage - def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes: img_cv2 = bytes_to_cv2_img(img_bytes) @@ -70,7 +69,9 @@ def storage_blob_for(filename: str) -> "storage.storage.Blob": filename = safe_filename(filename) bucket = storage.bucket(settings.GS_BUCKET_NAME) - blob = bucket.blob(f"daras_ai/media/{uuid.uuid1()}/{filename}") + blob = bucket.blob( + os.path.join(settings.GS_MEDIA_PATH, str(uuid.uuid1()), filename) + ) return blob @@ -143,3 +144,7 @@ def guess_ext_from_response(response: requests.Response) -> str: def get_mimetype_from_response(response: requests.Response) -> str: content_type = response.headers.get("Content-Type", "application/octet-stream") return content_type.split(";")[0] + + +def gs_url_to_uri(url: str) -> str: + return "gs://" + "/".join(furl(url).path.segments) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index cbfe873be..40f7aa406 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -10,18 +10,12 @@ from furl import furl import gooey_ui as st -from daras_ai.image_input import upload_file_from_bytes +from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings from daras_ai_v2.functional import map_parallel from daras_ai_v2.gpu_server import ( - GpuEndpoints, call_celery_task, ) -from daras_ai_v2.glossary import ( - DEFAULT_GLOSSARY_URL, - glossary_resource, - supports_language_pair, -) from daras_ai_v2.redis_cache import redis_cache_decorator SHORT_FILE_CUTOFF = 5 * 1024 * 1024 # 1 MB @@ -168,7 +162,7 @@ def run_google_translate( texts: list[str], target_language: str, source_language: str = None, - glossary_url: str = DEFAULT_GLOSSARY_URL, + glossary_url: str = None, ) -> list[str]: """ Translate text using the Google Translate API. @@ -176,6 +170,7 @@ def run_google_translate( texts (list[str]): Text to be translated. target_language (str): Language code to translate to. source_language (str): Language code to translate from. + glossary_url (str): URL of glossary file. Returns: list[str]: Translated text. """ @@ -191,7 +186,7 @@ def run_google_translate( return map_parallel( lambda text, source: _translate_text( - text, source, target_language, glossary_url or DEFAULT_GLOSSARY_URL + text, source, target_language, glossary_url ), texts, language_codes, @@ -202,16 +197,13 @@ def _translate_text( text: str, source_language: str, target_language: str, - glossary_url: str = DEFAULT_GLOSSARY_URL, + glossary_url: str | None, ) -> str: is_romanized = source_language.endswith("-Latn") source_language = source_language.replace("-Latn", "") enable_transliteration = ( is_romanized and source_language in TRANSLITERATION_SUPPORTED ) - glossary_url = ( - glossary_url if not enable_transliteration else "" - ) # glossary does not work with transliteration # prevent incorrect API calls if source_language == target_language or not text: @@ -228,30 +220,28 @@ def _translate_text( "transliteration_config": {"enable_transliteration": enable_transliteration}, } - with glossary_resource(glossary_url) as (uri, location, lang_codes): - if supports_language_pair(lang_codes, target_language, source_language): - config.update( - { - "glossaryConfig": { - "glossary": uri, - "ignoreCase": True, - } - } - ) - else: - uri = None - location = "global" + # glossary does not work with transliteration + if glossary_url and not enable_transliteration: + from glossary_resources.models import GlossaryResource - authed_session, project = get_google_auth_session() - res = authed_session.post( - f"https://translation.googleapis.com/v3/projects/{project}/locations/{location}:translateText", - json=config, - ) - res.raise_for_status() - data = res.json() - result = data["glossaryTranslations"][0] if uri else data["translations"][0] + gr = GlossaryResource.objects.get_or_create_from_url(glossary_url)[0] + location = gr.location + config["glossary_config"] = { + "glossary": gr.get_glossary_path(), + "ignoreCase": True, + } + else: + location = "global" - return result["translatedText"].strip() + authed_session, project = get_google_auth_session() + res = authed_session.post( + f"https://translation.googleapis.com/v3/projects/{project}/locations/{location}:translateText", + json=config, + ) + res.raise_for_status() + data = res.json() + translations = data.get("glossaryTranslations", data["translations"]) + return translations[0]["translatedText"].strip() _session = None @@ -355,8 +345,7 @@ def run_asr( ) elif selected_model == AsrModels.usm: - # note: only us-central1 and a few other regions support chirp recognizers (so global can't be used) - location = "us-central1" + location = settings.GCP_REGION # Create a client options = ClientOptions(api_endpoint=f"{location}-speech.googleapis.com") @@ -388,7 +377,7 @@ def run_asr( audio_channel_count=1, ) audio = cloud_speech.BatchRecognizeFileMetadata() - audio.uri = "gs://" + "/".join(furl(audio_url).path.segments) + audio.uri = gs_url_to_uri(audio_url) # Specify that results should be inlined in the response (only possible for 1 audio file) output_config = cloud_speech.RecognitionOutputConfig() output_config.inline_response_config = cloud_speech.InlineOutputConfig() diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 817fc8728..eb8b43df4 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -1,3 +1,4 @@ +import os import typing import gooey_ui as st @@ -7,9 +8,13 @@ from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.search_ref import CitationStyles +_user_media_url_prefix = os.path.join( + "storage.googleapis.com", settings.GS_BUCKET_NAME, settings.GS_MEDIA_PATH +) + def is_user_uploaded_url(url: str) -> bool: - return f"storage.googleapis.com/{settings.GS_BUCKET_NAME}/daras_ai" in url + return _user_media_url_prefix in url def document_uploader( diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 5c1d31e3b..d82c4f4a8 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -83,7 +83,7 @@ def gdrive_metadata(file_id: str) -> dict: .get( supportsAllDrives=True, fileId=file_id, - fields="name,md5Checksum,modifiedTime,mimeType", + fields="name,md5Checksum,modifiedTime,mimeType,filesize", ) .execute() ) diff --git a/daras_ai_v2/glossary.py b/daras_ai_v2/glossary.py index 6831bd498..033b2bb57 100644 --- a/daras_ai_v2/glossary.py +++ b/daras_ai_v2/glossary.py @@ -1,261 +1,115 @@ -import gooey_ui as st -from daras_ai_v2.redis_cache import redis_cache_decorator -from contextlib import contextmanager -from glossary_resources.models import GlossaryResource -from django.db.models import F -import requests -from time import sleep +import uuid -DEFAULT_GLOSSARY_URL = "https://docs.google.com/spreadsheets/d/1IRHKcOC86oZXwMB0hR7eej7YVg5kUHpriZymwYQcQX4/edit?usp=sharing" # only viewing access -BUCKET_NAME = "gooey-server-glossary" # name of bucket -MAX_GLOSSARY_RESOURCES = 10_000 # https://cloud.google.com/translate/quotas +import gooey_ui as gui +from daras_ai_v2.asr import google_translate_languages -# ================================ Glossary UI ================================ def glossary_input( label="##### Glossary\nUpload a google sheet, csv, or xlsx file.", key="glossary_document", ): from daras_ai_v2.doc_search_settings_widgets import document_uploader - st.session_state.setdefault(key, DEFAULT_GLOSSARY_URL) glossary_url = document_uploader( label=label, key=key, accept=[".csv", ".xlsx", ".xls", ".gsheet", ".ods", ".tsv"], accept_multiple_files=False, ) - st.caption( + gui.caption( f"If not specified or invalid, no glossary will be used. Read about the expected format [here](https://docs.google.com/document/d/1TwzAvFmFYekloRKql2PXNPIyqCbsHRL8ZtnWkzAYrh8/edit?usp=sharing)." ) return glossary_url -# ================================ Glossary Logic ================================ -def supports_language_pair( - supported_lang_codes, target_language: str, source_language: str | None -): - return any( - [ - target_language.split("-")[0] == supported_lang.split("-")[0] - for supported_lang in supported_lang_codes - ] - ) and ( - not source_language - or any( - [ - source_language.split("-")[0] == supported_lang.split("-")[0] - for supported_lang in supported_lang_codes - ] - ) - ) - - -@contextmanager -def glossary_resource(f_url: str = DEFAULT_GLOSSARY_URL, max_tries=3): +def create_glossary( + *, + language_codes: list[str], + input_uri: str, + project_id: str, + location: str, + glossary_name: str, + timeout: int = 180, +) -> "translate.Glossary": """ - Obtains a glossary resource for use in translation requests. - """ - from daras_ai_v2.vector_search import doc_url_to_metadata - from google.api_core.exceptions import NotFound - - if not f_url: - yield None, "global", {} - return - - resource, created = GlossaryResource.objects.get_or_create(f_url=f_url) - - # make sure we don't exceed the max number of glossary resources allowed by GCP (we add a safety buffer of 100 for local development) - if created and GlossaryResource.objects.count() > MAX_GLOSSARY_RESOURCES - 100: - for gloss in GlossaryResource.objects.order_by("usage_count", "last_updated")[ - :10 - ]: - try: - _delete_glossary( - glossary_name=gloss.get_clean_name(), - project_id=gloss.project_id, - location=gloss.location, - ) - except NotFound: - pass # glossary already deleted, let's delete the model and move on - finally: - gloss.delete() - - doc_meta = doc_url_to_metadata(f_url) - # create glossary if it doesn't exist, update if it has changed - _update_glossary( - f_url, - doc_meta, - glossary_name=resource.get_clean_name(), - project_id=resource.project_id, - location=resource.location, - ) - path, lang_codes = _get_glossary( - glossary_name=resource.get_clean_name(), - project_id=resource.project_id, - location=resource.location, - ) - - try: - yield path, resource.location, lang_codes - except requests.exceptions.HTTPError as e: - if e.response.status_code == 400 and e.response.json().get("error", {}).get( - "message", "" - ).startswith("Invalid resource name"): - sleep(1) - yield glossary_resource(f_url, max_tries - 1) - else: - raise e - finally: - GlossaryResource.objects.filter(pk=resource.pk).update( - usage_count=F("usage_count") + 1 - ) - - -@redis_cache_decorator -def _update_glossary( - f_url: str, - doc_meta, - glossary_name: str = "glossary", - project_id="dara-c1b52", - location="us-central1", -) -> "pd.DataFrame": - """Goes through the full process of uploading the glossary from the url""" - from daras_ai_v2.vector_search import download_table_doc - from google.api_core.exceptions import NotFound - - df = download_table_doc(f_url, doc_meta) - - _upload_glossary_to_bucket(df, glossary_name=glossary_name) - # delete existing glossary - try: - _delete_glossary( - glossary_name=glossary_name, project_id=project_id, location=location - ) - except NotFound: - pass # glossary already deleted, moving on - # create new glossary - languages = [ - lan_code - for lan_code in df.columns.tolist() - if lan_code not in ["pos", "description"] - ] # "pos" and "description" are not languages but still allowed by the google spec in the glossary csv - _create_glossary( - languages, glossary_name=glossary_name, project_id=project_id, location=location - ) - - return df - - -def _get_glossary( - glossary_name: str = "glossary", project_id="dara-c1b52", location="us-central1" -): - """Get information about the glossary.""" - from google.cloud import translate_v3beta1 - - client = translate_v3beta1.TranslationServiceClient() - - path = client.glossary_path(project_id, location, glossary_name) + From https://cloud.google.com/translate/docs/advanced/glossary#equivalent_term_sets_glossary - response = client.get_glossary(name=path) - print("Glossary name: {}".format(response.name)) - print("Entry count: {}".format(response.entry_count)) - print("Input URI: {}".format(response.input_config.gcs_source.input_uri)) - return path, response.language_codes_set.language_codes - - -def _upload_glossary_to_bucket(df, glossary_name: str = "glossary"): - """Uploads a pandas DataFrame to the bucket.""" - # import gcloud storage - from google.cloud import storage - - csv = df.to_csv(index=False) - - # initialize the storage client and give it the bucket and the blob name - BLOB_NAME, _ = _parse_glossary_name(glossary_name) - storage_client = storage.Client() - bucket = storage_client.bucket(BUCKET_NAME) - blob = bucket.blob(BLOB_NAME) - - # upload the file to the bucket - blob.upload_from_string(csv) - - -def _delete_glossary( - timeout=180, - glossary_name: str = "glossary", - project_id="dara-c1b52", - location="us-central1", -): - """Delete the glossary resource so a new one can be created.""" - from google.cloud import translate_v3beta1 - - client = translate_v3beta1.TranslationServiceClient() - - path = client.glossary_path(project_id, location, glossary_name) - - operation = client.delete_glossary(name=path) - result = operation.result(timeout) - print("Deleted: {}".format(result.name)) - - -def _create_glossary( - languages, - glossary_name: str = "glossary", - project_id="dara-c1b52", - location="us-central1", -): - """Creates a GCP glossary resource.""" - from google.cloud import translate_v3beta1 + Create a equivalent term sets glossary. Glossary can be words or + short phrases (usually fewer than five words). + https://cloud.google.com/translate/docs/advanced/glossary#format-glossary + """ + from google.cloud import translate_v3 as translate from google.api_core.exceptions import AlreadyExists - # Instantiates a client - client = translate_v3beta1.TranslationServiceClient() - - # Set glossary resource name - _, GLOSSARY_URI = _parse_glossary_name(glossary_name) - path = client.glossary_path(project_id, location, glossary_name) + client = translate.TranslationServiceClient() - # Set language codes - language_codes_set = translate_v3beta1.Glossary.LanguageCodesSet( - language_codes=languages + name = client.glossary_path(project_id, location, glossary_name) + language_codes_set = translate.types.Glossary.LanguageCodesSet( + language_codes=language_codes ) - gcs_source = translate_v3beta1.GcsSource(input_uri=GLOSSARY_URI) - - input_config = translate_v3beta1.GlossaryInputConfig(gcs_source=gcs_source) - - # Set glossary resource information - glossary = translate_v3beta1.Glossary( - name=path, language_codes_set=language_codes_set, input_config=input_config + gcs_source = translate.types.GcsSource(input_uri=input_uri) + input_config = translate.types.GlossaryInputConfig(gcs_source=gcs_source) + glossary = translate.types.Glossary( + name=name, language_codes_set=language_codes_set, input_config=input_config ) parent = f"projects/{project_id}/locations/{location}" - - # Create glossary resource - # Handle exception for case in which a glossary - # with glossary_name already exists try: operation = client.create_glossary(parent=parent, glossary=glossary) - operation.result(timeout=90) - print("Created glossary " + glossary_name + ".") + operation.result(timeout) + print("Glossary created:", name) except AlreadyExists: - print( - "The glossary " - + glossary_name - + " already exists. No new glossary was created." - ) + pass -def _parse_glossary_name(glossary_name: str = "glossary"): +def delete_glossary( + *, + project_id: str, + glossary_name: str, + location: str, + timeout: int = 180, +) -> "translate.Glossary": """ - Parses the glossary name into the bucket name and blob name. + From https://cloud.google.com/translate/docs/advanced/glossary#delete-glossary + + Delete a specific glossary based on the glossary ID. + Args: - glossary_name: name of the glossary resource + project_id: The ID of the GCP project that owns the glossary. + glossary_name: The name of the glossary to delete. + location: The location of the glossary. + timeout: The timeout for this request. + Returns: - blob_name: name of the blob - glossary_uri: uri of the glossary uploaded to Cloud Storage + The glossary that was deleted. """ - blob_name = glossary_name + ".csv" - glossary_uri = "gs://" + BUCKET_NAME + "/" + blob_name - return blob_name, glossary_uri + + from google.cloud import translate_v3 as translate + from google.api_core.exceptions import NotFound + + client = translate.TranslationServiceClient() + name = client.glossary_path(project_id, location, glossary_name) + try: + operation = client.delete_glossary(name=name) + operation.result(timeout) + print("Glossary deleted:", name) + except NotFound: + pass + + +def get_langcodes_from_df(df: "pd.DataFrame") -> list[str]: + import langcodes + + supported = { + langcodes.Language.get(code).language for code in google_translate_languages() + } + ret = [] + for col in df.columns: + try: + lang = langcodes.Language.get(col).language + if lang in supported: + ret.append(col) + except langcodes.LanguageTagError: + pass + return ret diff --git a/daras_ai_v2/redis_cache.py b/daras_ai_v2/redis_cache.py index a2e42d8cc..2f700936d 100644 --- a/daras_ai_v2/redis_cache.py +++ b/daras_ai_v2/redis_cache.py @@ -1,4 +1,5 @@ import hashlib +import os.path import pickle import typing from functools import wraps, lru_cache @@ -25,15 +26,17 @@ def wrapper(*args, **kwargs): cache_key = f"gooey/redis-cache-decorator/v1/{fn.__name__}/{args_hash}" # get the redis cache redis_cache = get_redis_cache() - cache_val = redis_cache.get(cache_key) - # if the cache exists, return it - if cache_val: - return pickle.loads(cache_val) - # otherwise, run the function and cache the result - else: - result = fn(*args, **kwargs) - cache_val = pickle.dumps(result) - redis_cache.set(cache_key, cache_val) - return result + # lock the cache key so that only one thread can run the function + with redis_cache.lock(os.path.join(cache_key, "lock")): + cache_val = redis_cache.get(cache_key) + # if the cache exists, return it + if cache_val: + return pickle.loads(cache_val) + # otherwise, run the function and cache the result + else: + result = fn(*args, **kwargs) + cache_val = pickle.dumps(result) + redis_cache.set(cache_key, cache_val) + return result return wrapper diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 7285bf1cf..84f2373dd 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -211,8 +211,12 @@ os.environ["REPLICATE_API_TOKEN"] = config("REPLICATE_API_TOKEN", default="") +GCP_PROJECT = config("GCP_PROJECT", default="dara-c1b52") +GCP_REGION = config("GCP_REGION", default="us-central1") + GS_BUCKET_NAME = config("GS_BUCKET_NAME", default="") -# GOOGLE_CLIENT_ID = config("GOOGLE_CLIENT_ID") +GS_MEDIA_PATH = config("GS_MEDIA_PATH", default="daras_ai/media") + UBERDUCK_KEY = config("UBERDUCK_KEY", None) UBERDUCK_SECRET = config("UBERDUCK_SECRET", None) diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 74d2d1891..a7a42af68 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -237,6 +237,7 @@ class DocMetadata(typing.NamedTuple): name: str etag: str | None mime_type: str | None + total_bytes: int = 0 def doc_url_to_metadata(f_url: str) -> DocMetadata: @@ -265,6 +266,7 @@ def doc_url_to_metadata(f_url: str) -> DocMetadata: name = meta["name"] etag = meta.get("md5Checksum") or meta.get("modifiedTime") mime_type = meta["mimeType"] + total_bytes = int(meta.get("filesize") or 0) else: try: r = requests.head( @@ -275,17 +277,19 @@ def doc_url_to_metadata(f_url: str) -> DocMetadata: r.raise_for_status() except requests.RequestException as e: print(f"ignore error while downloading {f_url}: {e}") + name = None mime_type = None etag = None - name = None + total_bytes = 0 else: - mime_type = get_mimetype_from_response(r) - etag = r.headers.get("etag", r.headers.get("last-modified")) name = ( r.headers.get("content-disposition", "") .split("filename=")[-1] .strip('"') ) + etag = r.headers.get("etag", r.headers.get("last-modified")) + mime_type = get_mimetype_from_response(r) + total_bytes = int(r.headers.get("content-length") or 0) # extract filename from url as a fallback if not name: if is_user_uploaded_url(str(f)): @@ -295,7 +299,7 @@ def doc_url_to_metadata(f_url: str) -> DocMetadata: # guess mimetype from name as a fallback if not mime_type: mime_type = mimetypes.guess_type(name)[0] - return DocMetadata(name, etag, mime_type) + return DocMetadata(name, etag, mime_type, total_bytes) @redis_cache_decorator @@ -417,44 +421,6 @@ def split_sections(sections: str, *, chunk_overlap: int, chunk_size: int): header += f"{role}={content}\n" -def _download_doc_content(f_url: str, doc_meta: DocMetadata): - f = furl(f_url) - f_name = doc_meta.name - if is_gdrive_url(f): - # download from google drive - f_bytes, ext = gdrive_download(f, doc_meta.mime_type) - else: - # download from url - try: - r = requests.get( - f_url, - headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, - timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, - ) - r.raise_for_status() - except requests.RequestException as e: - print(f"ignore error while downloading {f_url}: {e}") - return "", "", b"" - f_bytes = r.content - # if it's a known encoding, standardize to utf-8 - if r.encoding: - try: - codec = codecs.lookup(r.encoding) - except LookupError: - pass - else: - f_bytes = codec.decode(f_bytes)[0].encode() - ext = guess_ext_from_response(r) - return ext, f_name, f_bytes - - -def download_content_bytes(f_url: str, mime_type: str): - ext, _, f_bytes = _download_doc_content( - f_url, DocMetadata(name="", etag="", mime_type=mime_type) - ) - return f_bytes, ext - - @redis_cache_decorator def doc_url_to_text_pages( *, @@ -462,18 +428,74 @@ def doc_url_to_text_pages( doc_meta: DocMetadata, google_translate_target: str | None, selected_asr_model: str | None, -) -> list[str]: +) -> typing.Union[list[str], "pd.DataFrame"]: """ Download document from url and convert to text pages. + Args: f_url: url of document doc_meta: document metadata google_translate_target: target language for google translate selected_asr_model: selected ASR model (used for audio files) + Returns: list of text pages """ - ext, f_name, f_bytes = _download_doc_content(f_url, doc_meta) + f_bytes, ext = download_content_bytes(f_url=f_url, mime_type=doc_meta.mime_type) + if not f_bytes: + return [] + pages = bytes_to_text_pages_or_df( + f_url=f_url, + f_name=doc_meta.name, + f_bytes=f_bytes, + ext=ext, + mime_type=doc_meta.mime_type, + selected_asr_model=selected_asr_model, + ) + # optionally, translate text + if google_translate_target and isinstance(pages, list): + pages = run_google_translate(pages, google_translate_target) + return pages + + +def download_content_bytes(*, f_url: str, mime_type: str) -> tuple[bytes, str]: + f = furl(f_url) + if is_gdrive_url(f): + # download from google drive + return gdrive_download(f, mime_type) + try: + # download from url + r = requests.get( + f_url, + headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, + timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, + ) + r.raise_for_status() + except requests.RequestException as e: + print(f"ignore error while downloading {f_url}: {e}") + return b"", "" + f_bytes = r.content + # if it's a known encoding, standardize to utf-8 + if r.encoding: + try: + codec = codecs.lookup(r.encoding) + except LookupError: + pass + else: + f_bytes = codec.decode(f_bytes)[0].encode() + ext = guess_ext_from_response(r) + return f_bytes, ext + + +def bytes_to_text_pages_or_df( + *, + f_url: str, + f_name: str, + f_bytes: bytes, + ext: str, + mime_type: str, + selected_asr_model: str | None, +) -> typing.Union[list[str], "pd.DataFrame"]: # convert document to text pages match ext: case ".pdf": @@ -488,17 +510,16 @@ def doc_url_to_text_pages( "For transcribing audio/video, please choose an ASR model from the settings!" ) if is_gdrive_url(furl(f_url)): - f_url = upload_file_from_bytes( - f_name, f_bytes, content_type=doc_meta.mime_type - ) + f_url = upload_file_from_bytes(f_name, f_bytes, content_type=mime_type) pages = [run_asr(f_url, selected_model=selected_asr_model, language="en")] case _: - df = bytes_to_df(f_name=f_name, f_bytes=f_bytes, ext=ext).fillna("") + df = bytes_to_df(f_name=f_name, f_bytes=f_bytes, ext=ext) assert ( "snippet" in df.columns or "sections" in df.columns ), f'uploaded spreadsheet must contain a "snippet" or "sections" column - {f_name !r}' return df - return run_google_translate(pages, target_language=google_translate_target) + + return pages def bytes_to_df( @@ -521,13 +542,9 @@ def bytes_to_df( df = pd.read_json(f, dtype=str) case ".xml": df = pd.read_xml(f, dtype=str) - case ".ods": - df = pd.read_excel(f, engine="odf", dtype=str) - case ".gsheet": - df = pd.read_csv(f, dtype=str) case _: raise ValueError(f"Unsupported document format {ext!r} ({f_name})") - return df + return df.fillna("") def pdf_to_text_pages(f: typing.BinaryIO) -> list[str]: @@ -567,11 +584,6 @@ def pandoc_to_text(f_name: str, f_bytes: bytes, to="plain") -> str: return outfile.read() -def download_table_doc(f_url: str, doc_meta: DocMetadata) -> "pd.DataFrame": - ext, f_name, f_bytes = _download_doc_content(f_url, doc_meta) - return bytes_to_df(f_name=f_name, f_bytes=f_bytes, ext=ext).dropna() - - def render_sources_widget(refs: list[SearchReference]): if not refs: return diff --git a/glossary_resources/admin.py b/glossary_resources/admin.py index 14200e4fb..27eafff28 100644 --- a/glossary_resources/admin.py +++ b/glossary_resources/admin.py @@ -5,15 +5,7 @@ # Register your models here. @admin.register(GlossaryResource) class GlossaryAdmin(admin.ModelAdmin): - list_display = ("f_url", "usage_count", "last_updated", "glossary_name") - list_filter = ("usage_count", "last_updated") - search_fields = ("f_url", "glossary_name") - ordering = ["usage_count", "last_updated"] - readonly_fields = ( - "f_url", - "usage_count", - "last_updated", - "glossary_name", - "project_id", - "location", - ) + list_display = ["f_url", "usage_count", "last_updated", "glossary_id"] + list_filter = ["usage_count", "last_updated"] + search_fields = ["f_url", "glossary_name"] + ordering = ["-usage_count", "-last_updated"] diff --git a/glossary_resources/apps.py b/glossary_resources/apps.py index 30c62a5ac..d039c9e32 100644 --- a/glossary_resources/apps.py +++ b/glossary_resources/apps.py @@ -5,3 +5,8 @@ class GlossaryResourcesConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" name = "glossary_resources" verbose_name = "Glossary Resources" + + def ready(self): + import glossary_resources.signals + + assert glossary_resources.signals diff --git a/glossary_resources/migrations/0001_initial.py b/glossary_resources/migrations/0001_initial.py index 91cbd4244..41eb153aa 100644 --- a/glossary_resources/migrations/0001_initial.py +++ b/glossary_resources/migrations/0001_initial.py @@ -1,14 +1,17 @@ -# Generated by Django 4.2.6 on 2023-10-26 20:45 +# Generated by Django 4.2.5 on 2023-10-29 11:10 import bots.custom_fields from django.db import migrations, models +import django.db.models.deletion import uuid class Migration(migrations.Migration): initial = True - dependencies = [] + dependencies = [ + ("app_users", "0010_filemetadata"), + ] operations = [ migrations.CreateModel( @@ -27,14 +30,24 @@ class Migration(migrations.Migration): "f_url", bots.custom_fields.CustomURLField(max_length=2048, unique=True), ), - ("usage_count", models.IntegerField(default=0)), - ("last_updated", models.DateTimeField(auto_now=True)), + ("language_codes", models.JSONField()), + ("glossary_uri", models.TextField()), ( - "glossary_name", + "glossary_id", models.UUIDField(default=uuid.uuid4, editable=False, unique=True), ), ("project_id", models.CharField(default="dara-c1b52", max_length=100)), ("location", models.CharField(default="us-central1", max_length=100)), + ("usage_count", models.IntegerField(default=0)), + ("last_updated", models.DateTimeField(auto_now=True)), + ( + "metadata", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="glossary_resources", + to="app_users.filemetadata", + ), + ), ], options={ "indexes": [ diff --git a/glossary_resources/models.py b/glossary_resources/models.py index 1a9fcbe44..7f17a3d5e 100644 --- a/glossary_resources/models.py +++ b/glossary_resources/models.py @@ -1,28 +1,115 @@ -from django.db import models -from bots.custom_fields import CustomURLField import uuid +from django.db import models, IntegrityError, transaction + +from app_users.models import FileMetadata +from bots.custom_fields import CustomURLField +from daras_ai.image_input import gs_url_to_uri, upload_file_from_bytes +from daras_ai_v2 import settings +from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url +from daras_ai_v2.glossary import ( + get_langcodes_from_df, + create_glossary, +) +from daras_ai_v2.redis_cache import redis_cache_decorator +from daras_ai_v2.vector_search import ( + doc_url_to_metadata, +) +from daras_ai_v2.vector_search import ( + download_content_bytes, + bytes_to_df, + DocMetadata, +) + + +class GlossaryResourceQuerySet(models.QuerySet): + def get_or_create_from_url(self, url: str) -> tuple["GlossaryResource", bool]: + doc_meta = doc_url_to_metadata(url) + try: + return GlossaryResource.objects.get(f_url=url), False + except GlossaryResource.DoesNotExist: + try: + gr = create_glossary_cached(url, doc_meta) + with transaction.atomic(): + try: + gr.id = GlossaryResource.objects.get(f_url=url).id + except GlossaryResource.DoesNotExist: + pass + gr.metadata = FileMetadata.objects.create( + name=doc_meta.name, + etag=doc_meta.etag, + mime_type=doc_meta.mime_type, + ) + gr.save() + return gr, True + except IntegrityError: + try: + return GlossaryResource.objects.get(f_url=url), False + except self.model.DoesNotExist: + pass + raise + + +@redis_cache_decorator +def create_glossary_cached(url: str, doc_meta: DocMetadata) -> "GlossaryResource": + f_bytes, ext = download_content_bytes(f_url=url, mime_type=doc_meta.name) + df = bytes_to_df(f_name=doc_meta.name, f_bytes=f_bytes, ext=ext) + if not is_user_uploaded_url(url): + url = upload_file_from_bytes( + doc_meta.name + ".csv", + df.to_csv(index=False).encode(), + content_type="text/csv", + ) + gr = GlossaryResource( + f_url=url, + language_codes=get_langcodes_from_df(df), + project_id=settings.GCP_PROJECT, + location=settings.GCP_REGION, + glossary_uri=gs_url_to_uri(url), + ) + create_glossary( + language_codes=gr.language_codes, + input_uri=gr.glossary_uri, + project_id=gr.project_id, + location=gr.location, + glossary_name=gr.glossary_name, + ) + return gr + class GlossaryResource(models.Model): f_url = CustomURLField(unique=True) + metadata = models.ForeignKey( + "app_users.FileMetadata", + on_delete=models.CASCADE, + related_name="glossary_resources", + ) + + language_codes = models.JSONField() + glossary_uri = models.TextField() + glossary_id = models.UUIDField(unique=True, default=uuid.uuid4, editable=False) + project_id = models.CharField(max_length=100, default=settings.GCP_PROJECT) + location = models.CharField(max_length=100, default=settings.GCP_REGION) + usage_count = models.IntegerField(default=0) last_updated = models.DateTimeField(auto_now=True) - glossary_name = models.UUIDField(unique=True, default=uuid.uuid4, editable=False) - project_id = models.CharField(max_length=100, default="dara-c1b52") - location = models.CharField(max_length=100, default="us-central1") + + objects = GlossaryResourceQuerySet.as_manager() class Meta: indexes = [ - models.Index( - fields=[ - "usage_count", - "last_updated", - ] - ), + models.Index(fields=["usage_count", "last_updated"]), ] def __str__(self): - return f"{self.f_url} ({self.usage_count} uses)" + return f"{self.metadata.name or self.f_url} ({self.glossary_id})" + + def get_glossary_path(self) -> dict: + from google.cloud import translate_v3 as translate + + client = translate.TranslationServiceClient() + return client.glossary_path(self.project_id, self.location, self.glossary_name) - def get_clean_name(self): - return "glossary-" + str(self.glossary_name).lower() + @property + def glossary_name(self) -> str: + return f"gooey-api--{self.glossary_id}" diff --git a/glossary_resources/signals.py b/glossary_resources/signals.py new file mode 100644 index 000000000..32b07122a --- /dev/null +++ b/glossary_resources/signals.py @@ -0,0 +1,14 @@ +from django.db.models.signals import post_delete +from django.dispatch import receiver + +from daras_ai_v2.glossary import delete_glossary +from glossary_resources.models import GlossaryResource + + +@receiver(post_delete, sender=GlossaryResource) +def on_glossary_resource_deleted(instance: GlossaryResource, **kwargs): + delete_glossary( + project_id=instance.project_id, + location=instance.location, + glossary_name=instance.glossary_name, + ) diff --git a/glossary_resources/tests.py b/glossary_resources/tests.py index 7ce503c2d..06ac9ef7c 100644 --- a/glossary_resources/tests.py +++ b/glossary_resources/tests.py @@ -1,3 +1,113 @@ -from django.test import TestCase +import pytest -# Create your tests here. +from daras_ai.image_input import storage_blob_for +from daras_ai_v2.crypto import get_random_doc_id +from glossary_resources.models import GlossaryResource +from tests.test_translation import test_run_google_translate_one + +GLOSSARY = [ + { + "en-US": "Gooey.AI", + "hi-IN": "गूई.एआई", + "pos": "noun", + "description": "name of the Gooey.AI", + "random": "random", + }, + { + "en-US": "Gooey.AI", + "hi-IN": "गुई डॉट ए आई", + "pos": "noun", + "description": "name of the Gooey.AI", + "random": get_random_doc_id(), + }, + { + "en-US": "Gooey.AI", + "hi-IN": "गुई ए आई", + "pos": "noun", + "description": "name of the Gooey.AI", + "random": "random", + }, + { + "en-US": "chilli", + "hi-IN": "मिर्ची", + "pos": "noun", + "description": "the spicy thing", + }, + { + "en-US": "chilli", + "hi-IN": "मिर्च", + "pos": "noun", + "description": "the spicy thing", + }, +] + +TRANSLATION_TESTS_GLOSSARY = [ + ( + "मिर्च में बीज उपचार कैसे करें", # source + "en", + "How to Treat Seeds in Peppers", # no glossary + "How to Treat Seeds in Chilli", # with glosssary + ), + ( + "मिर्ची में बीज उपचार कैसे करें", + "en", + "How to Treat Seeds in Peppers", + "How to Treat Seeds in Chilli", + ), + ( + "गुई डॉट ए आई से हम क्या कर सकते हैं", + "en", + "What can we do with Gui.AI", + "What can we do with Gooey.AI", + ), + ( + "गुई ए आई से हम क्या कर सकते हैं", + "en", + "What can we do with AI", + "What can we do with Gooey.AI", + ), + ( + "Who is the founder of Gooey.AI?", + "hi", + "gooeyai के संस्थापक कौन हैं?", + "गूई.एआई के संस्थापक कौन हैं?", + ), +] + + +@pytest.fixture +def glossary_url(): + import pandas as pd + + df = pd.DataFrame.from_records(GLOSSARY) + blob = storage_blob_for("test glossary.csv") + blob.upload_from_string(df.to_csv(index=False).encode(), content_type="text/csv") + + try: + yield blob.public_url + finally: + blob.delete() + GlossaryResource.objects.all().delete() + + +@pytest.mark.django_db +def test_run_google_translate_glossary(glossary_url, threadpool_subtest): + for ( + text, + target_lang, + expected, + expected_with_glossary, + ) in TRANSLATION_TESTS_GLOSSARY: + threadpool_subtest( + test_run_google_translate_one, + text, + expected, + target_lang=target_lang, + ) + threadpool_subtest( + test_run_google_translate_one, + text, + expected_with_glossary, + target_lang=target_lang, + glossary_url=glossary_url, + ) diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index d3bc05d48..cc94df768 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -403,7 +403,7 @@ def process_source( texts=[transcript], target_language=request.google_translate_target, # source_language=request.language, - glossary_url=request.glossary_document or "", + glossary_url=request.glossary_document, )[0] update_cell(spreadsheet_id, row, Columns.translation.value, translation) else: diff --git a/recipes/asr.py b/recipes/asr.py index da73a271f..11964613b 100644 --- a/recipes/asr.py +++ b/recipes/asr.py @@ -151,9 +151,7 @@ def run(self, state: dict): source_language=forced_asr_languages.get( selected_model, request.language ), - glossary_url=request.glossary_document - if request.glossary_document - else "", + glossary_url=request.glossary_document, ) else: # Save the raw ASR text for details view diff --git a/tests/test_translation.py b/tests/test_translation.py index 95cf2a539..5c94efa99 100644 --- a/tests/test_translation.py +++ b/tests/test_translation.py @@ -1,5 +1,6 @@ from daras_ai_v2.asr import run_google_translate + TRANSLATION_TESTS = [ # hindi romanized ( @@ -44,11 +45,13 @@ def test_run_google_translate(threadpool_subtest): for text, expected in TRANSLATION_TESTS: - threadpool_subtest(_test_run_google_translate, text, expected) + threadpool_subtest(test_run_google_translate_one, text, expected) -def _test_run_google_translate(text: str, expected: str): - actual = run_google_translate([text], "en")[0] +def test_run_google_translate_one( + text: str, expected: str, glossary_url=None, target_lang="en" +): + actual = run_google_translate([text], target_lang, glossary_url=glossary_url)[0] assert ( actual.replace(".", "").replace(",", "").strip().lower() == expected.replace(".", "").replace(",", "").strip().lower()