diff --git a/README.md b/README.md
index a2fbe6a7a..322e83718 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,9 @@ pg_restore --no-privileges --no-owner -d $PGDATABASE $fname
 cid=$(docker ps | grep gooey-api-prod | cut -d " " -f 1 | head -1)
 # exec the script to create the fixture
 docker exec -it $cid poetry run ./manage.py runscript create_fixture
+```
+
+```bash
 # copy the fixture outside container
 docker cp $cid:/app/fixture.json .
 # print the absolute path
diff --git a/app_users/admin.py b/app_users/admin.py
index 191197eb6..ffe8ccdc6 100644
--- a/app_users/admin.py
+++ b/app_users/admin.py
@@ -41,6 +41,7 @@ class AppUserAdmin(admin.ModelAdmin):
         "view_transactions",
         "open_in_firebase",
         "open_in_stripe",
+        "low_balance_email_sent_at",
     ]

     @admin.display(description="User Runs")
diff --git a/app_users/migrations/0012_appuser_low_balance_email_sent_at.py b/app_users/migrations/0012_appuser_low_balance_email_sent_at.py
new file mode 100644
index 000000000..efc2beaf0
--- /dev/null
+++ b/app_users/migrations/0012_appuser_low_balance_email_sent_at.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.2.7 on 2024-02-14 07:23
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('app_users', '0011_appusertransaction_charged_amount_and_more'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='appuser',
+            name='low_balance_email_sent_at',
+            field=models.DateTimeField(blank=True, null=True),
+        ),
+    ]
diff --git a/app_users/migrations/0013_appusertransaction_app_users_a_user_id_9b2e8d_idx_and_more.py b/app_users/migrations/0013_appusertransaction_app_users_a_user_id_9b2e8d_idx_and_more.py
new file mode 100644
index 000000000..b992887bb
--- /dev/null
+++ b/app_users/migrations/0013_appusertransaction_app_users_a_user_id_9b2e8d_idx_and_more.py
@@ -0,0 +1,21 @@
+# Generated by Django 4.2.7 on 2024-02-28 14:16
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('app_users', '0012_appuser_low_balance_email_sent_at'),
+    ]
+
+    operations = [
+        migrations.AddIndex(
+            model_name='appusertransaction',
+            index=models.Index(fields=['user', 'amount', '-created_at'], name='app_users_a_user_id_9b2e8d_idx'),
+        ),
+        migrations.AddIndex(
+            model_name='appusertransaction',
+            index=models.Index(fields=['-created_at'], name='app_users_a_created_3c27fe_idx'),
+        ),
+    ]
diff --git a/app_users/models.py b/app_users/models.py
index 576ea0390..9299fba47 100644
--- a/app_users/models.py
+++ b/app_users/models.py
@@ -89,6 +89,8 @@ class AppUser(models.Model):
     stripe_customer_id = models.CharField(max_length=255, default="", blank=True)
     is_paying = models.BooleanField("paid", default=False)

+    low_balance_email_sent_at = models.DateTimeField(null=True, blank=True)
+
     created_at = models.DateTimeField(
         "created", editable=False, blank=True, default=timezone.now
     )
@@ -207,7 +209,11 @@ def search_stripe_customer(self) -> stripe.Customer | None:
         if not self.uid:
             return None
         if self.stripe_customer_id:
-            return stripe.Customer.retrieve(self.stripe_customer_id)
+            try:
+                return stripe.Customer.retrieve(self.stripe_customer_id)
+            except stripe.error.InvalidRequestError as e:
+                if e.http_status != 404:
+                    raise
         try:
             customer = stripe.Customer.search(
                 query=f'metadata["uid"]:"{self.uid}"'
@@ -263,6 +269,10 @@ class AppUserTransaction(models.Model):

     class Meta:
         verbose_name = "Transaction"
+        indexes = [
+            models.Index(fields=["user", "amount", "-created_at"]),
+            models.Index(fields=["-created_at"]),
+        ]
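The composite index above lines up with the "last purchase" lookup that `run_low_balance_email_check` (added in `celeryapp/tasks.py` below) performs. A minimal sketch of the query shape it serves, assuming the models from this patch:

```python
from app_users.models import AppUserTransaction

def last_purchase(user):
    # filter on (user, amount), newest first: the exact shape covered by
    # the new Index(fields=["user", "amount", "-created_at"])
    return (
        AppUserTransaction.objects.filter(user=user, amount__gt=0)
        .order_by("-created_at")
        .first()
    )
```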

     def __str__(self):
         return f"{self.invoice_id} ({self.amount})"
diff --git a/bots/admin.py b/bots/admin.py
index 524a1c708..23d0a9dd5 100644
--- a/bots/admin.py
+++ b/bots/admin.py
@@ -294,16 +294,25 @@ class SavedRunAdmin(admin.ModelAdmin):
         django.db.models.JSONField: {"widget": JSONEditorWidget},
     }

+    def get_queryset(self, request):
+        return (
+            super()
+            .get_queryset(request)
+            .prefetch_related(
+                "parent_version",
+                "parent_version__published_run",
+                "parent_version__published_run__saved_run",
+            )
+        )
+
     def lookup_allowed(self, key, value):
         if key in ["parent_version__published_run__id__exact"]:
             return True
         return super().lookup_allowed(key, value)

     def view_user(self, saved_run: SavedRun):
-        return change_obj_url(
-            AppUser.objects.get(uid=saved_run.uid),
-            label=f"{saved_run.uid}",
-        )
+        user = AppUser.objects.get(uid=saved_run.uid)
+        return change_obj_url(user)

     view_user.short_description = "View User"
diff --git a/bots/migrations/0060_conversation_reset_at.py b/bots/migrations/0060_conversation_reset_at.py
new file mode 100644
index 000000000..10cd847b6
--- /dev/null
+++ b/bots/migrations/0060_conversation_reset_at.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.2.7 on 2024-02-20 16:49
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('bots', '0059_savedrun_is_api_call'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='conversation',
+            name='reset_at',
+            field=models.DateTimeField(blank=True, default=None, null=True),
+        ),
+    ]
diff --git a/bots/models.py b/bots/models.py
index 51a969a68..6354556b3 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -860,6 +860,7 @@ class Conversation(models.Model):
     )

     created_at = models.DateTimeField(auto_now_add=True)
+    reset_at = models.DateTimeField(null=True, blank=True, default=None)

     objects = ConversationQuerySet.as_manager()

@@ -1013,7 +1014,11 @@ def to_df_analysis_format(
         )
         return df

-    def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]:
+    def as_llm_context(
+        self, limit: int = 50, reset_at: datetime.datetime = None
+    ) -> list["ConversationEntry"]:
+        if reset_at:
+            self = self.filter(created_at__gt=reset_at)
         msgs = self.order_by("-created_at").prefetch_related("attachments")[:limit]
         entries = [None] * len(msgs)
         for i, msg in enumerate(reversed(msgs)):
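With `reset_at` in place, a bot reset becomes a soft cutoff instead of deleting history. A rough sketch of the intended flow, using the model and queryset methods from this patch (`convo` stands in for a `Conversation` instance):

```python
from django.utils import timezone

def reset(convo):
    # soft reset: history is kept, but excluded from LLM context
    convo.reset_at = timezone.now()
    convo.save()

def llm_context(convo):
    # as_llm_context() drops messages created at or before reset_at
    return convo.messages.all().as_llm_context(limit=50, reset_at=convo.reset_at)
```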
diff --git a/bots/tasks.py b/bots/tasks.py
index 3978b429d..ff0d9d2fb 100644
--- a/bots/tasks.py
+++ b/bots/tasks.py
@@ -1,4 +1,5 @@
 import json
+from json import JSONDecodeError

 from celery import shared_task
 from django.db.models import QuerySet
@@ -65,8 +66,15 @@ def msg_analysis(msg_id: int):
         raise RuntimeError(sr.error_msg)

     # save the result as json
+    output_text = flatten(sr.state["output_text"].values())[0]
+    try:
+        analysis_result = json.loads(output_text)
+    except JSONDecodeError:
+        analysis_result = {
+            "error": "Failed to parse the analysis result. Please check your script.",
+        }
     Message.objects.filter(id=msg_id).update(
-        analysis_result=json.loads(flatten(sr.state["output_text"].values())[0]),
+        analysis_result=analysis_result,
     )
diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py
index 8d13a6f7d..a4a6ffe07 100644
--- a/celeryapp/tasks.py
+++ b/celeryapp/tasks.py
@@ -1,18 +1,26 @@
+import datetime
+import html
 import traceback
 import typing
 from time import time
 from types import SimpleNamespace

+import requests
 import sentry_sdk
+from django.db.models import Sum
+from django.utils import timezone
+from fastapi import HTTPException

 import gooey_ui as st
-from app_users.models import AppUser
+from app_users.models import AppUser, AppUserTransaction
 from bots.models import SavedRun
 from celeryapp.celeryconfig import app
 from daras_ai.image_input import truncate_text_words
 from daras_ai_v2 import settings
-from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage
+from daras_ai_v2.base import StateKeys, BasePage
+from daras_ai_v2.exceptions import UserError
 from daras_ai_v2.send_email import send_email_via_postmark
+from daras_ai_v2.send_email import send_low_balance_email
 from daras_ai_v2.settings import templates
 from gooey_ui.pubsub import realtime_push
 from gooey_ui.state import set_query_params
@@ -32,6 +40,17 @@ def gui_runner(
     is_api_call: bool = False,
 ):
     page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id)))
+
+    def event_processor(event, hint):
+        event["request"] = {
+            "method": "POST",
+            "url": page.app_url(query_params=query_params),
+            "data": state,
+        }
+        return event
+
+    page.setup_sentry(event_processor=event_processor)
+
     sr = page.run_doc_sr(run_id, uid)
     sr.is_api_call = is_api_call
@@ -95,8 +114,25 @@ def save(done=False):
         # render errors nicely
         except Exception as e:
             run_time += time() - start_time
-            traceback.print_exc()
-            sentry_sdk.capture_exception(e)
+
+            if isinstance(e, HTTPException) and e.status_code == 402:
+                error_msg = page.generate_credit_error_message(
+                    example_id=query_params.get("example_id"),
+                    run_id=run_id,
+                    uid=uid,
+                )
+                try:
+                    raise UserError(error_msg) from e
+                except UserError as e:
+                    sentry_sdk.capture_exception(e, level=e.sentry_level)
+                break
+
+            if isinstance(e, UserError):
+                sentry_level = e.sentry_level
+            else:
+                sentry_level = "error"
+                traceback.print_exc()
+            sentry_sdk.capture_exception(e, level=sentry_level)
             error_msg = err_msg_for_exc(e)
             break
         finally:
@@ -105,6 +141,69 @@
     save(done=True)
     if not is_api_call:
         send_email_on_completion(page, sr)
+    run_low_balance_email_check(uid)
+
+
+def err_msg_for_exc(e: Exception):
+    if isinstance(e, requests.HTTPError):
+        response: requests.Response = e.response
+        try:
+            err_body = response.json()
+        except requests.JSONDecodeError:
+            err_str = response.text
+        else:
+            format_exc = err_body.get("format_exc")
+            if format_exc:
+                print("⚡️ " + format_exc)
+            err_type = err_body.get("type")
+            err_str = err_body.get("str")
+            if err_type and err_str:
+                return f"(GPU) {err_type}: {err_str}"
+            err_str = str(err_body)
+        return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}"
+    elif isinstance(e, HTTPException):
+        return f"(HTTP {e.status_code}) {e.detail}"
+    elif isinstance(e, UserError):
+        return e.message
+    else:
+        return f"{type(e).__name__}: {e}"
+
+
+def run_low_balance_email_check(uid: str):
+    # don't send email if feature is disabled
+    if not settings.LOW_BALANCE_EMAIL_ENABLED:
+        return
+    user = AppUser.objects.get(uid=uid)
+    # don't send email if user is not paying or has enough balance
+    if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS:
+        return
+    last_purchase = (
+        AppUserTransaction.objects.filter(user=user, amount__gt=0)
+        .order_by("-created_at")
+        .first()
+    )
+    email_date_cutoff = timezone.now() - datetime.timedelta(
+        days=settings.LOW_BALANCE_EMAIL_DAYS
+    )
+    # send email if user has not been sent email in last X days or last purchase was after last email sent
+    if (
+        # user has not been sent any email
+        not user.low_balance_email_sent_at
+        # user was sent email before X days
+        or (user.low_balance_email_sent_at < email_date_cutoff)
+        # user has made a purchase after last email sent
+        or (last_purchase and last_purchase.created_at > user.low_balance_email_sent_at)
+    ):
+        # calculate total credits consumed in last X days
+        total_credits_consumed = abs(
+            AppUserTransaction.objects.filter(
+                user=user, amount__lt=0, created_at__gte=email_date_cutoff
+            ).aggregate(Sum("amount"))["amount__sum"]
+            or 0
+        )
+        send_low_balance_email(user=user, total_credits_consumed=total_credits_consumed)
+        user.low_balance_email_sent_at = timezone.now()
+        user.save(update_fields=["low_balance_email_sent_at"])


 def send_email_on_completion(page: BasePage, sr: SavedRun):
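A hypothetical test sketch for the new check, relying on the `pytest_outbox` capture added in `daras_ai_v2/send_email.py` and `conftest.py` below (the fixture values are made up, and `AppUser` field defaults may differ):

```python
import pytest
from app_users.models import AppUser
from celeryapp.tasks import run_low_balance_email_check
from daras_ai_v2.send_email import pytest_outbox

@pytest.mark.django_db
def test_low_balance_email():
    # a paying user below LOW_BALANCE_EMAIL_CREDITS (default 200)
    user = AppUser.objects.create(uid="test-uid", is_paying=True, balance=10)
    run_low_balance_email_check(user.uid)
    user.refresh_from_db()
    assert user.low_balance_email_sent_at is not None
    assert any("credit balance is low" in m["subject"] for m in pytest_outbox)
```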
diff --git a/conftest.py b/conftest.py
index dba333885..f2003dc0c 100644
--- a/conftest.py
+++ b/conftest.py
@@ -9,6 +9,22 @@
 from auth import auth_backend
 from celeryapp import app
 from daras_ai_v2.base import BasePage
+from daras_ai_v2.send_email import pytest_outbox
+
+
+def flaky(fn):
+    max_tries = 5
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        for i in range(max_tries):
+            try:
+                return fn(*args, **kwargs)
+            except Exception:
+                if i == max_tries - 1:
+                    raise
+
+    return wrapper


 @pytest.fixture(scope="session")
@@ -44,7 +60,7 @@ def _mock_gui_runner(


 @pytest.fixture
-def threadpool_subtest(subtests, max_workers: int = 8):
+def threadpool_subtest(subtests, max_workers: int = 128):
     ts = []

     def submit(fn, *args, msg=None, **kwargs):
@@ -68,6 +84,11 @@ def runner(*args, **kwargs):
         t.join()


+@pytest.fixture(autouse=True)
+def clear_pytest_outbox():
+    pytest_outbox.clear()
+
+
 # class DummyDatabaseBlocker(pytest_django.plugin._DatabaseBlocker):
 #     class _dj_db_wrapper:
 #         def ensure_connection(self):
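The `flaky` helper retries a failing test body up to 5 times before re-raising. A usage sketch with a self-contained, intentionally unstable test (note `flaky` lives in `conftest.py`, so the import here is an assumption about how it would be pulled in):

```python
import random
from conftest import flaky

@flaky
def test_unstable_integration():
    # fails roughly half the time; the wrapper re-runs it up to max_tries
    assert random.random() < 0.5
```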
diff --git a/daras_ai/extract_face.py b/daras_ai/extract_face.py
index e81b35d6e..aeb107a45 100644
--- a/daras_ai/extract_face.py
+++ b/daras_ai/extract_face.py
@@ -1,5 +1,7 @@
 import numpy as np

+from daras_ai_v2.exceptions import UserError
+

 def extract_and_reposition_face_cv2(
     orig_img,
@@ -118,7 +120,7 @@ def face_oval_hull_generator(image_cv2):
         results = face_mesh.process(cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB))

         if not results.multi_face_landmarks:
-            raise ValueError("Face not found")
+            raise UserError("Face not found")

         for landmark_list in results.multi_face_landmarks:
             idx_to_coordinates = build_idx_to_coordinates_dict(
diff --git a/daras_ai/image_input.py b/daras_ai/image_input.py
index 0d9530f22..5e61fcba9 100644
--- a/daras_ai/image_input.py
+++ b/daras_ai/image_input.py
@@ -11,6 +11,7 @@
 from furl import furl

 from daras_ai_v2 import settings
+from daras_ai_v2.exceptions import UserError


 def resize_img_pad(img_bytes: bytes, size: tuple[int, int]) -> bytes:
@@ -90,7 +91,7 @@ def bytes_to_cv2_img(img_bytes: bytes, greyscale=False) -> np.ndarray:
         flags = cv2.IMREAD_COLOR
     img_cv2 = cv2.imdecode(np.frombuffer(img_bytes, dtype=np.uint8), flags=flags)
     if not img_exists(img_cv2):
-        raise ValueError("Bad Image")
+        raise UserError("Bad Image")
     return img_cv2
diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py
index 81562a49f..4403dd32b 100644
--- a/daras_ai_v2/asr.py
+++ b/daras_ai_v2/asr.py
@@ -1,9 +1,7 @@
-import json
 import os.path
-import subprocess
+import os.path
 import tempfile
 from enum import Enum
-from time import sleep

 import langcodes
 import requests
@@ -14,7 +12,14 @@
 import gooey_ui as st
 from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
 from daras_ai_v2 import settings
-from daras_ai_v2.exceptions import raise_for_status
+from daras_ai_v2.azure_asr import azure_asr
+from daras_ai_v2.exceptions import (
+    raise_for_status,
+    UserError,
+    ffmpeg,
+    call_cmd,
+    ffprobe,
+)
 from daras_ai_v2.functional import map_parallel
 from daras_ai_v2.gdrive_downloader import (
     is_gdrive_url,
@@ -22,25 +27,75 @@
     gdrive_metadata,
     url_to_gdrive_file_id,
 )
+from daras_ai_v2.google_asr import gcp_asr_v1
 from daras_ai_v2.gpu_server import call_celery_task
 from daras_ai_v2.redis_cache import redis_cache_decorator

 SHORT_FILE_CUTOFF = 5 * 1024 * 1024  # 5 MB
-
 TRANSLITERATION_SUPPORTED = {"ar", "bn", "gu", "hi", "ja", "kn", "ru", "ta", "te"}  # fmt: skip

-# below CHIRP list was found experimentally since the supported languages list by google is actually wrong:
-CHIRP_SUPPORTED = {"af-ZA", "sq-AL", "am-ET", "ar-EG", "hy-AM", "as-IN", "ast-ES", "az-AZ", "eu-ES", "be-BY", "bs-BA", "bg-BG", "my-MM", "ca-ES", "ceb-PH", "ckb-IQ", "zh-Hans-CN", "yue-Hant-HK", "hr-HR", "cs-CZ", "da-DK", "nl-NL", "en-AU", "en-IN", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-CA", "fr-FR", "gl-ES", "ka-GE", "de-DE", "el-GR", "gu-IN", "ha-NG", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID", "it-IT", "ja-JP", "jv-ID", "kea-CV", "kam-KE", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "ky-KG", "lo-LA", "lv-LV", "ln-CD", "lt-LT", "luo-KE", "lb-LU", "mk-MK", "ms-MY", "ml-IN", "mt-MT", "mi-NZ", "mr-IN", "mn-MN", "ne-NP", "ny-MW", "oc-FR", "ps-AF", "fa-IR", "pl-PL", "pt-BR", "pa-Guru-IN", "ro-RO", "ru-RU", "nso-ZA", "sr-RS", "sn-ZW", "sd-IN", "si-LK", "sk-SK", "sl-SI", "so-SO", "es-ES", "es-US", "su-ID", "sw", "sv-SE", "tg-TJ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", "uz-UZ", "vi-VN", "cy-GB", "wo-SN", "yo-NG", "zu-ZA"} # fmt: skip

-WHISPER_SUPPORTED = {"af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi", "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy"} # fmt: skip
+# https://cloud.google.com/speech-to-text/docs/speech-to-text-supported-languages
+GCP_V1_SUPPORTED = {
+    "af-ZA", "sq-AL", "am-ET", "ar-DZ", "ar-BH", "ar-EG", "ar-IQ", "ar-IL", "ar-JO", "ar-KW", "ar-LB", "ar-MR", "ar-MA",
+    "ar-OM", "ar-QA", "ar-SA", "ar-PS", "ar-SY", "ar-TN", "ar-AE", "ar-YE", "hy-AM", "az-AZ", "eu-ES", "bn-BD", "bn-IN",
+    "bs-BA", "bg-BG", "my-MM", "ca-ES", "yue-Hant-HK", "zh", "zh-TW", "hr-HR", "cs-CZ",
+    "da-DK", "nl-BE", "nl-NL", "en-AU", "en-CA", "en-GH", "en-HK", "en-IN", "en-IE", "en-KE", "en-NZ", "en-NG", "en-PK",
+    "en-PH", "en-SG", "en-ZA", "en-TZ", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-BE", "fr-CA", "fr-FR",
+    "fr-CH", "gl-ES", "ka-GE", "de-AT", "de-DE", "de-CH", "el-GR", "gu-IN", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID",
+    "it-IT", "it-CH", "ja-JP", "jv-ID", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "lo-LA", "lv-LV", "lt-LT", "mk-MK", "ms-MY",
+    "ml-IN", "mr-IN", "mn-MN", "ne-NP", "no-NO", "fa-IR", "pl-PL", "pt-BR", "pt-PT", "pa-Guru-IN", "ro-RO", "ru-RU",
"sr-RS", "si-LK", "sk-SK", "sl-SI", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-DO", "es-EC", "es-SV", "es-GT", + "es-HN", "es-MX", "es-NI", "es-PA", "es-PY", "es-PE", "es-PR", "es-ES", "es-US", "es-UY", "es-VE", "su-ID", "sw-KE", + "sw-TZ", "sv-SE", "ta-IN", "ta-MY", "ta-SG", "ta-LK", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "ur-PK", "uz-UZ", + "vi-VN", "zu-ZA", +} # fmt: skip + +# https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages +CHIRP_SUPPORTED = { + "af-ZA", "sq-AL", "am-ET", "ar-EG", "hy-AM", "as-IN", "ast-ES", "az-AZ", "eu-ES", "be-BY", "bs-BA", "bg-BG", + "my-MM", "ca-ES", "ceb-PH", "ckb-IQ", "yue-Hant-HK", "zh-TW", "hr-HR", "cs-CZ", "da-DK", "nl-NL", + "en-AU", "en-IN", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-CA", "fr-FR", "gl-ES", "ka-GE", "de-DE", + "el-GR", "gu-IN", "ha-NG", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID", "it-IT", "ja-JP", "jv-ID", "kea-CV", + "kam-KE", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "ky-KG", "lo-LA", "lv-LV", "ln-CD", "lt-LT", "luo-KE", "lb-LU", + "mk-MK", "ms-MY", "ml-IN", "mt-MT", "mi-NZ", "mr-IN", "mn-MN", "ne-NP", "no-NO", "ny-MW", "oc-FR", "ps-AF", "fa-IR", + "pl-PL", "pt-BR", "pa-Guru-IN", "ro-RO", "ru-RU", "nso-ZA", "sr-RS", "sn-ZW", "sd-IN", "si-LK", "sk-SK", "sl-SI", + "so-SO", "es-ES", "es-US", "su-ID", "sw", "sv-SE", "tg-TJ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", + "uz-UZ", "vi-VN", "cy-GB", "wo-SN", "yo-NG", "zu-ZA" +} # fmt: skip + +WHISPER_SUPPORTED = { + "af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "gl", "de", + "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi", "ne", "no", + "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy" +} # fmt: skip # See page 14 of https://scontent-sea1-1.xx.fbcdn.net/v/t39.2365-6/369747868_602316515432698_2401716319310287708_n.pdf?_nc_cat=106&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=_5cpNOcftdYAX8rCrVo&_nc_ht=scontent-sea1-1.xx&oh=00_AfDVkx7XubifELxmB_Un-yEYMJavBHFzPnvTbTlalbd_1Q&oe=65141B39 # For now, below are listed the languages that support ASR. Note that Seamless only accepts ISO 639-3 codes. 
-SEAMLESS_SUPPORTED = {"afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", "cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khk", "khm", "kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", "nno", "nob", "npi", "nya", "oci", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa", "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", "xho", "yor", "yue", "zlm", "zul"} # fmt: skip
-
-AZURE_SUPPORTED = {"af-ZA", "am-ET", "ar-AE", "ar-BH", "ar-DZ", "ar-EG", "ar-IL", "ar-IQ", "ar-JO", "ar-KW", "ar-LB", "ar-LY", "ar-MA", "ar-OM", "ar-PS", "ar-QA", "ar-SA", "ar-SY", "ar-TN", "ar-YE", "az-AZ", "bg-BG", "bn-IN", "bs-BA", "ca-ES", "cs-CZ", "cy-GB", "da-DK", "de-AT", "de-CH", "de-DE", "el-GR", "en-AU", "en-CA", "en-GB", "en-GH", "en-HK", "en-IE", "en-IN", "en-KE", "en-NG", "en-NZ", "en-PH", "en-SG", "en-TZ", "en-US", "en-ZA", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-CU", "es-DO", "es-EC", "es-ES", "es-GQ", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PE", "es-PR", "es-PY", "es-SV", "es-US", "es-UY", "es-VE", "et-EE", "eu-ES", "fa-IR", "fi-FI", "fil-PH", "fr-BE", "fr-CA", "fr-CH", "fr-FR", "ga-IE", "gl-ES", "gu-IN", "he-IL", "hi-IN", "hr-HR", "hu-HU", "hy-AM", "id-ID", "is-IS", "it-CH", "it-IT", "ja-JP", "jv-ID", "ka-GE", "kk-KZ", "km-KH", "kn-IN", "ko-KR", "lo-LA", "lt-LT", "lv-LV", "mk-MK", "ml-IN", "mn-MN", "mr-IN", "ms-MY", "mt-MT", "my-MM", "nb-NO", "ne-NP", "nl-BE", "nl-NL", "pa-IN", "pl-PL", "ps-AF", "pt-BR", "pt-PT", "ro-RO", "ru-RU", "si-LK", "sk-SK", "sl-SI", "so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", "zh-CN-sichuan", "zh-HK", "zh-TW", "zu-ZA"} # fmt: skip
-MAX_POLLS = 100
+SEAMLESS_SUPPORTED = {
+    "afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb",
+    "cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb", "hin",
+    "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khk", "khm",
+    "kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld",
+    "nno", "nob", "npi", "nya", "oci", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna",
+    "snd", "som", "spa", "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie",
+    "xho", "yor", "yue", "zlm", "zul"
+} # fmt: skip
+
+AZURE_SUPPORTED = {
+    "af-ZA", "am-ET", "ar-AE", "ar-BH", "ar-DZ", "ar-EG", "ar-IL", "ar-IQ", "ar-JO", "ar-KW", "ar-LB", "ar-LY", "ar-MA",
+    "ar-OM", "ar-PS", "ar-QA", "ar-SA", "ar-SY", "ar-TN", "ar-YE", "az-AZ", "bg-BG", "bn-IN", "bs-BA", "ca-ES", "cs-CZ",
+    "cy-GB", "da-DK", "de-AT", "de-CH", "de-DE", "el-GR", "en-AU", "en-CA", "en-GB", "en-GH", "en-HK", "en-IE", "en-IN",
+    "en-KE", "en-NG", "en-NZ", "en-PH", "en-SG", "en-TZ", "en-US", "en-ZA", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR",
+    "es-CU", "es-DO", "es-EC", "es-ES", "es-GQ", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PE", "es-PR", "es-PY",
"fil-PH", "fr-BE", "fr-CA", "fr-CH", + "fr-FR", "ga-IE", "gl-ES", "gu-IN", "he-IL", "hi-IN", "hr-HR", "hu-HU", "hy-AM", "id-ID", "is-IS", "it-CH", "it-IT", + "ja-JP", "jv-ID", "ka-GE", "kk-KZ", "km-KH", "kn-IN", "ko-KR", "lo-LA", "lt-LT", "lv-LV", "mk-MK", "ml-IN", "mn-MN", + "mr-IN", "ms-MY", "mt-MT", "my-MM", "nb-NO", "ne-NP", "nl-BE", "nl-NL", "pa-IN", "pl-PL", "ps-AF", "pt-BR", "pt-PT", + "ro-RO", "ru-RU", "si-LK", "sk-SK", "sl-SI", "so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", + "th-TH", "tr-TR", "uk-UA", "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", + "zh-CN-sichuan", "zh-HK", "zh-TW", "zu-ZA" +} # fmt: skip # https://deepgram.com/product/languages for the "general" model: # DEEPGRAM_SUPPORTED = {"nl","en","en-AU","en-US","en-GB","en-NZ","en-IN","fr","fr-CA","de","hi","hi-Latn","id","it","ja","ko","cmn-Hans-CN","cmn-Hant-TW","no","pl","pt","pt-PT","pt-BR","ru","es","es-419","sv","tr","uk"} # fmt: skip @@ -56,15 +111,14 @@ class AsrModels(Enum): nemo_english = "Conformer English (ai4bharat.org)" nemo_hindi = "Conformer Hindi (ai4bharat.org)" vakyansh_bhojpuri = "Vakyansh Bhojpuri (Open-Speech-EkStep)" - usm = "Chirp / USM (Google)" + gcp_v1 = "Google Cloud V1" + usm = "Chirp / USM (Google V2)" deepgram = "Deepgram" azure = "Azure Speech" seamless_m4t = "Seamless M4T (Facebook Research)" def supports_auto_detect(self) -> bool: - return self not in { - self.azure, - } + return self not in {self.azure, self.gcp_v1} asr_model_ids = { @@ -89,6 +143,7 @@ def supports_auto_detect(self) -> bool: asr_supported_languages = { AsrModels.whisper_large_v3: WHISPER_SUPPORTED, AsrModels.whisper_large_v2: WHISPER_SUPPORTED, + AsrModels.gcp_v1: GCP_V1_SUPPORTED, AsrModels.usm: CHIRP_SUPPORTED, AsrModels.deepgram: DEEPGRAM_SUPPORTED, AsrModels.seamless_m4t: SEAMLESS_SUPPORTED, @@ -128,7 +183,7 @@ def google_translate_language_selector( label: the label to display key: the key to save the selected language to in the session state """ - languages = google_translate_languages() + languages = google_translate_target_languages() options = list(languages.keys()) if allow_none: options.insert(0, None) @@ -141,8 +196,8 @@ def google_translate_language_selector( ) -@redis_cache_decorator -def google_translate_languages() -> dict[str, str]: +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def google_translate_target_languages() -> dict[str, str]: """ Get list of supported languages for Google Translate. :return: Dictionary of language codes and display names. @@ -162,8 +217,8 @@ def google_translate_languages() -> dict[str, str]: } -@redis_cache_decorator -def google_translate_input_languages() -> dict[str, str]: +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def google_translate_source_languages() -> dict[str, str]: """ Get list of supported languages for Google Translate. :return: Dictionary of language codes and display names. 
@@ -246,11 +301,11 @@ def run_google_translate(
     if source_language:
         source_language = langcodes.Language.get(source_language).to_tag()
         source_language = get_language_in_collection(
-            source_language, google_translate_input_languages().keys()
+            source_language, google_translate_source_languages().keys()
         )  # this will default to autodetect if language is not found as supported
     target_language = langcodes.Language.get(target_language).to_tag()
     target_language: str | None = get_language_in_collection(
-        target_language, google_translate_languages().keys()
+        target_language, google_translate_target_languages().keys()
     )
     if not target_language:
         raise ValueError(f"Unsupported target language: {target_language!r}")
@@ -467,6 +522,8 @@ def run_asr(
                 src_lang=language,
             ),
         )
+    elif selected_model == AsrModels.gcp_v1:
+        return gcp_asr_v1(audio_url, language)
     elif selected_model == AsrModels.usm:
         location = settings.GCP_REGION
@@ -575,7 +632,7 @@ def run_asr(
             assert data.get("chunks"), f"{selected_model.value} can't generate VTT"
             return generate_vtt(data["chunks"])
         case _:
-            raise ValueError(f"Invalid output format: {output_format}")
+            raise UserError(f"Invalid output format: {output_format}")


 def _get_or_create_recognizer(
@@ -611,64 +668,6 @@ def _get_or_create_recognizer(
     return recognizer


-def azure_asr(audio_url: str, language: str):
-    # transcription from audio url only supported via rest api or cli
-    # Start by initializing a request
-    payload = {
-        "contentUrls": [
-            audio_url,
-        ],
-        "displayName": "Gooey Transcription",
-        "model": None,
-        "properties": {
-            "wordLevelTimestampsEnabled": False,
-        },
-        "locale": language or "en-US",
-    }
-    r = requests.post(
-        str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"),
-        headers={
-            "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY,
-            "Content-Type": "application/json",
-        },
-        json=payload,
-    )
-    raise_for_status(r)
-    uri = r.json()["self"]
-
-    # poll for results
-    for _ in range(MAX_POLLS):
-        r = requests.get(
-            uri,
-            headers={
-                "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY,
-            },
-        )
-        if not r.ok or not r.json()["status"] == "Succeeded":
-            sleep(5)
-            continue
-        r = requests.get(
-            r.json()["links"]["files"],
-            headers={
-                "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY,
-            },
-        )
-        raise_for_status(r)
-        transcriptions = []
-        for value in r.json()["values"]:
-            if value["kind"] != "Transcription":
-                continue
-            r = requests.get(
-                value["links"]["contentUrl"],
-                headers={"Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY},
-            )
-            raise_for_status(r)
-            combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}]
-            transcriptions += [combined_phrases[0].get("display", "")]
-        return "\n".join(transcriptions)
-    assert False, "Max polls exceeded, Azure speech did not yield a response"
-
-
 # 16kHz, 16-bit, mono
 FFMPEG_WAV_ARGS = ["-vn", "-acodec", "pcm_s16le", "-ac", "1", "-ar", "16000"]

@@ -683,7 +682,7 @@ def download_youtube_to_wav(youtube_url: str) -> tuple[str, int]:
         infile = os.path.join(tmpdir, "infile")
         outfile = os.path.join(tmpdir, "outfile.wav")
         # run yt-dlp to download audio
-        args = [
+        call_cmd(
             "yt-dlp",
             "--no-playlist",
             "--format",
@@ -691,13 +690,9 @@
             "--output",
             infile,
             youtube_url,
-        ]
-        print("\t$ " + " ".join(args))
-        subprocess.check_call(args)
+        )
         # convert audio to single channel wav
-        args = ["ffmpeg", "-y", "-i", infile, *FFMPEG_WAV_ARGS, outfile]
-        print("\t$ " + " ".join(args))
-        subprocess.check_call(args)
ffmpeg("-i", infile, *FFMPEG_WAV_ARGS, outfile) # read wav file into memory with open(outfile, "rb") as f: wavdata = f.read() @@ -728,43 +723,12 @@ def audio_bytes_to_wav(audio_bytes: bytes) -> tuple[bytes | None, int]: with tempfile.NamedTemporaryFile(suffix=".wav") as outfile: # convert audio to single channel wav - args = [ - "ffmpeg", - "-y", - "-i", - infile.name, - *FFMPEG_WAV_ARGS, - outfile.name, - ] - print("\t$ " + " ".join(args)) - try: - subprocess.check_output(args, stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - ffmpeg_output_error = ValueError(e.output, e) - raise ValueError( - "Invalid audio file. Could not convert audio to wav format. Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')" - ) from ffmpeg_output_error + ffmpeg("-i", infile.name, *FFMPEG_WAV_ARGS, outfile.name) return outfile.read(), os.path.getsize(outfile.name) def check_wav_audio_format(filename: str) -> bool: - args = [ - "ffprobe", - "-v", - "quiet", - "-print_format", - "json", - "-show_streams", - filename, - ] - print("\t$ " + " ".join(args)) - try: - data = json.loads(subprocess.check_output(args, stderr=subprocess.STDOUT)) - except subprocess.CalledProcessError as e: - ffmpeg_output_error = ValueError(e.output, e) - raise ValueError( - "Invalid audio file. Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')" - ) from ffmpeg_output_error + data = ffprobe(filename) return ( len(data["streams"]) == 1 and data["streams"][0]["codec_name"] == "pcm_s16le" diff --git a/daras_ai_v2/azure_asr.py b/daras_ai_v2/azure_asr.py new file mode 100644 index 000000000..aed873b03 --- /dev/null +++ b/daras_ai_v2/azure_asr.py @@ -0,0 +1,96 @@ +import datetime +from time import sleep + +import requests +from furl import furl + +from daras_ai_v2 import settings +from daras_ai_v2.exceptions import ( + raise_for_status, +) +from daras_ai_v2.redis_cache import redis_cache_decorator + +# 20 mins timeout +MAX_POLLS = 200 +POLL_INTERVAL = 6 + + +def azure_asr(audio_url: str, language: str): + # Start by initializing a request + # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Transcriptions_Create + language = language or "en-US" + r = requests.post( + str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"), + headers=azure_auth_header(), + json={ + "contentUrls": [audio_url], + "displayName": f"Gooey Transcription {datetime.datetime.now().isoformat()} {language=} {audio_url=}", + "model": azure_get_latest_model(language), + "properties": { + "wordLevelTimestampsEnabled": False, + # "displayFormWordLevelTimestampsEnabled": True, + # "diarizationEnabled": False, + # "punctuationMode": "DictatedAndAutomatic", + # "profanityFilterMode": "Masked", + }, + "locale": language, + }, + ) + raise_for_status(r) + uri = r.json()["self"] + + # poll for results + for _ in range(MAX_POLLS): + r = requests.get(uri, headers=azure_auth_header()) + if not r.ok or not r.json()["status"] == "Succeeded": + sleep(POLL_INTERVAL) + continue + r = requests.get(r.json()["links"]["files"], headers=azure_auth_header()) + raise_for_status(r) + transcriptions = [] + for value in r.json()["values"]: + if value["kind"] != "Transcription": + continue + r = requests.get(value["links"]["contentUrl"], headers=azure_auth_header()) + raise_for_status(r) + combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}] + transcriptions += 
diff --git a/daras_ai_v2/azure_asr.py b/daras_ai_v2/azure_asr.py
new file mode 100644
index 000000000..aed873b03
--- /dev/null
+++ b/daras_ai_v2/azure_asr.py
@@ -0,0 +1,96 @@
+import datetime
+from time import sleep
+
+import requests
+from furl import furl
+
+from daras_ai_v2 import settings
+from daras_ai_v2.exceptions import (
+    raise_for_status,
+)
+from daras_ai_v2.redis_cache import redis_cache_decorator
+
+# 20 mins timeout
+MAX_POLLS = 200
+POLL_INTERVAL = 6
+
+
+def azure_asr(audio_url: str, language: str):
+    # Start by initializing a request
+    # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Transcriptions_Create
+    language = language or "en-US"
+    r = requests.post(
+        str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"),
+        headers=azure_auth_header(),
+        json={
+            "contentUrls": [audio_url],
+            "displayName": f"Gooey Transcription {datetime.datetime.now().isoformat()} {language=} {audio_url=}",
+            "model": azure_get_latest_model(language),
+            "properties": {
+                "wordLevelTimestampsEnabled": False,
+                # "displayFormWordLevelTimestampsEnabled": True,
+                # "diarizationEnabled": False,
+                # "punctuationMode": "DictatedAndAutomatic",
+                # "profanityFilterMode": "Masked",
+            },
+            "locale": language,
+        },
+    )
+    raise_for_status(r)
+    uri = r.json()["self"]
+
+    # poll for results
+    for _ in range(MAX_POLLS):
+        r = requests.get(uri, headers=azure_auth_header())
+        if not r.ok or not r.json()["status"] == "Succeeded":
+            sleep(POLL_INTERVAL)
+            continue
+        r = requests.get(r.json()["links"]["files"], headers=azure_auth_header())
+        raise_for_status(r)
+        transcriptions = []
+        for value in r.json()["values"]:
+            if value["kind"] != "Transcription":
+                continue
+            r = requests.get(value["links"]["contentUrl"], headers=azure_auth_header())
+            raise_for_status(r)
+            combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}]
+            transcriptions += [combined_phrases[0].get("display", "")]
+        return "\n".join(transcriptions)
+
+    raise RuntimeError("Max polls exceeded, Azure speech did not yield a response")
+
+
+@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
+def azure_get_latest_model(language: str) -> dict | None:
+    # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Models_ListBaseModels
+    r = requests.get(
+        str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/models/base"),
+        headers=azure_auth_header(),
+        params={"filter": f"locale eq '{language}'"},
+    )
+    raise_for_status(r)
+    data = r.json()["values"]
+    try:
+        models = sorted(
+            data,
+            key=lambda m: datetime.datetime.strptime(
+                m["createdDateTime"], "%Y-%m-%dT%H:%M:%SZ"
+            ),
+            reverse=True,
+        )
+    # ignore date parsing errors
+    except ValueError:
+        models = data
+        models.reverse()
+    for model in models:
+        if "whisper" in model["displayName"].lower():
+            # whisper is pretty slow on azure, so we ignore it
+            continue
+        # return the latest model
+        return {"self": model["self"]}
+
+
+def azure_auth_header():
+    return {
+        "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY,
+    }
diff --git a/daras_ai_v2/azure_doc_extract.py b/daras_ai_v2/azure_doc_extract.py
index 173bb080d..878dc5733 100644
--- a/daras_ai_v2/azure_doc_extract.py
+++ b/daras_ai_v2/azure_doc_extract.py
@@ -26,7 +26,7 @@ def azure_doc_extract_pages(
     ]


-@redis_cache_decorator
+@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
 def azure_form_recognizer_models() -> dict[str, str]:
     r = requests.get(
         str(
@@ -40,7 +40,7 @@
     return {value["modelId"]: value["description"] for value in r.json()["value"]}


-@redis_cache_decorator
+@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
 def azure_form_recognizer(url: str, model_id: str, params: dict = None):
     r = requests.post(
         str(
diff --git a/daras_ai_v2/azure_image_moderation.py b/daras_ai_v2/azure_image_moderation.py
index 30da7a561..c9afc871a 100644
--- a/daras_ai_v2/azure_image_moderation.py
+++ b/daras_ai_v2/azure_image_moderation.py
@@ -1,7 +1,5 @@
-from typing import Any
-
-from furl import furl
 import requests
+from furl import furl

 from daras_ai_v2 import settings
 from daras_ai_v2.exceptions import raise_for_status
@@ -11,7 +9,7 @@ def get_auth_headers():
     return {"Ocp-Apim-Subscription-Key": settings.AZURE_IMAGE_MODERATION_KEY}


-def run_moderator(image_url: str, cache: bool) -> dict[str, Any]:
+def is_image_nsfw(image_url: str, cache: bool = False) -> bool:
     url = str(
         furl(settings.AZURE_IMAGE_MODERATION_ENDPOINT)
         / "contentmoderator/moderate/v1.0/ProcessImage/Evaluate"
@@ -22,10 +20,9 @@
         headers=get_auth_headers(),
         json={"DataRepresentation": "URL", "Value": image_url},
     )
+    if r.status_code == 400 and (
+        b"Image Size Error" in r.content or b"Image Error" in r.content
+    ):
+        return False
     raise_for_status(r)
-    return r.json()
-
-
-def is_image_nsfw(image_url: str, cache: bool = False) -> bool:
-    response = run_moderator(image_url=image_url, cache=cache)
-    return response["IsImageAdultClassified"]
+    return r.json().get("IsImageAdultClassified", False)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index c013cc134..c52d6b5a6 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -13,7 +13,6 @@
 from time import sleep
 from types import SimpleNamespace

-import requests
 import sentry_sdk
 from django.utils import timezone
 from django.utils.text import slugify
@@ -183,12 +182,40 @@ def get_tab_url(self, tab: str, query_params: dict = {}) -> str:
             query_params=query_params,
         )

-    def setup_render(self):
+    def setup_sentry(self, event_processor: typing.Callable = None):
+        def add_user_to_event(event, hint):
+            user = self.request and self.request.user
+            if not user:
+                return event
+            event["user"] = {
+                "id": user.id,
+                "name": user.display_name,
+                "email": user.email,
+                "data": {
+                    field: getattr(user, field)
+                    for field in [
+                        "uid",
+                        "phone_number",
+                        "photo_url",
+                        "balance",
+                        "is_paying",
+                        "is_anonymous",
+                        "is_disabled",
+                        "disable_safety_checker",
+                        "created_at",
+                    ]
+                },
+            }
+            return event
+
         with sentry_sdk.configure_scope() as scope:
             scope.set_extra("base_url", self.app_url())
             scope.set_transaction_name(
                 "/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE
             )
+            scope.add_event_processor(add_user_to_event)
+            if event_processor:
+                scope.add_event_processor(event_processor)

     def refresh_state(self):
         _, run_id, uid = extract_query_params(gooey_get_query_params())
@@ -198,7 +225,7 @@ def refresh_state(self):
         st.session_state.update(output)

     def render(self):
-        self.setup_render()
+        self.setup_sentry()

         if self.get_run_state(st.session_state) == RecipeRunState.running:
             self.refresh_state()
@@ -343,7 +370,7 @@ def _render_social_buttons(self, show_button_text: bool = False):
         copy_to_clipboard_button(
             f'{button_text}',
-            value=self._get_current_app_url(),
+            value=self.get_tab_url(self.tab),
             type="secondary",
             className="mb-0 ms-lg-2",
         )
@@ -1324,7 +1351,6 @@ def _render_input_col(self):
         self.render_form_v2()
         with st.expander("⚙️ Settings"):
             self.render_settings()
-            st.write("---")
         submitted = self.render_submit_button()
         with st.div(style={"textAlign": "right"}):
             st.caption(
@@ -1820,8 +1846,8 @@ def run_as_api_tab(self):
         as_async = st.checkbox("##### Run Async")
         as_form_data = st.checkbox("##### Upload Files via Form Data")

-        request_body = get_example_request_body(
-            self.RequestModel, st.session_state, include_all=include_all
+        request_body = self.get_example_request_body(
+            st.session_state, include_all=include_all
         )
         response_body = self.get_example_response_body(
             st.session_state, as_async=as_async, include_all=include_all
@@ -1867,7 +1893,27 @@ def get_price_roundoff(self, state: dict) -> int:
         return max(1, math.ceil(self.get_raw_price(state)))

     def get_raw_price(self, state: dict) -> float:
-        return self.price
+        return self.price * state.get("num_outputs", 1)
+
+    @classmethod
+    def get_example_preferred_fields(cls, state: dict) -> list[str]:
+        """
+        Fields that are not required, but are preferred to be shown in the example.
+ """ + return [] + + @classmethod + def get_example_request_body( + cls, + state: dict, + include_all: bool = False, + ) -> dict: + return extract_model_fields( + cls.RequestModel, + state, + include_all=include_all, + preferred_fields=cls.get_example_preferred_fields(state), + ) def get_example_response_body( self, @@ -1883,6 +1929,7 @@ def get_example_response_body( run_id=run_id, uid=self.request.user and self.request.user.uid, ) + output = extract_model_fields(self.ResponseModel, state, include_all=True) if as_async: return dict( run_id=run_id, @@ -1890,18 +1937,14 @@ def get_example_response_body( created_at=created_at, run_time_sec=st.session_state.get(StateKeys.run_time, 0), status="completed", - output=get_example_request_body( - self.ResponseModel, state, include_all=include_all - ), + output=output, ) else: return dict( id=run_id, url=web_url, created_at=created_at, - output=get_example_request_body( - self.ResponseModel, state, include_all=include_all - ), + output=output, ) def additional_notes(self) -> str | None: @@ -1949,15 +1992,21 @@ def render_output_caption(): st.caption(caption, unsafe_allow_html=True) -def get_example_request_body( - request_model: typing.Type[BaseModel], +def extract_model_fields( + model: typing.Type[BaseModel], state: dict, include_all: bool = False, + preferred_fields: list[str] = None, ) -> dict: + """Only returns required fields unless include_all is set to True.""" return { field_name: state.get(field_name) - for field_name, field in request_model.__fields__.items() - if include_all or field.required + for field_name, field in model.__fields__.items() + if ( + include_all + or field.required + or (preferred_fields and field_name in preferred_fields) + ) } @@ -1977,27 +2026,6 @@ def extract_nested_str(obj) -> str: return "" -def err_msg_for_exc(e): - if isinstance(e, requests.HTTPError): - response: requests.Response = e.response - try: - err_body = response.json() - except requests.JSONDecodeError: - err_str = response.text - else: - format_exc = err_body.get("format_exc") - if format_exc: - print("⚡️ " + format_exc) - err_type = err_body.get("type") - err_str = err_body.get("str") - if err_type and err_str: - return f"(GPU) {err_type}: {err_str}" - err_str = str(err_body) - return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}" - else: - return f"{type(e).__name__}: {e}" - - def force_redirect(url: str): # note: assumes sanitized URLs st.html( diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index ee977b294..7075e3b2d 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -254,8 +254,8 @@ def _on_msg(bot: BotInterface): return # handle reset keyword if input_text.lower() == RESET_KEYWORD: - # clear saved messages - bot.convo.messages.all().delete() + # record the reset time so we don't send context + bot.convo.reset_at = timezone.now() # reset convo state bot.convo.state = ConvoState.INITIAL bot.convo.save() @@ -317,8 +317,8 @@ def _process_and_send_msg( recieved_time: datetime, speech_run: str | None, ): - # get latest messages for context (upto 100) - saved_msgs = bot.convo.messages.all().as_llm_context() + # get latest messages for context + saved_msgs = bot.convo.messages.all().as_llm_context(reset_at=bot.convo.reset_at) # # mock testing # result = _mock_api_output(input_text) @@ -569,8 +569,12 @@ def _handle_audio_msg(billing_account_user, bot: BotInterface): selected_model = AsrModels.whisper_telugu_large_v2.name case "bho": selected_model = AsrModels.vakyansh_bhojpuri.name - case "en": - selected_model = 
diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py
index ee977b294..7075e3b2d 100644
--- a/daras_ai_v2/bots.py
+++ b/daras_ai_v2/bots.py
@@ -254,8 +254,8 @@ def _on_msg(bot: BotInterface):
         return
     # handle reset keyword
     if input_text.lower() == RESET_KEYWORD:
-        # clear saved messages
-        bot.convo.messages.all().delete()
+        # record the reset time so we don't send context
+        bot.convo.reset_at = timezone.now()
         # reset convo state
         bot.convo.state = ConvoState.INITIAL
         bot.convo.save()
@@ -317,8 +317,8 @@ def _process_and_send_msg(
     recieved_time: datetime,
     speech_run: str | None,
 ):
-    # get latest messages for context (upto 100)
-    saved_msgs = bot.convo.messages.all().as_llm_context()
+    # get latest messages for context
+    saved_msgs = bot.convo.messages.all().as_llm_context(reset_at=bot.convo.reset_at)

     #     # mock testing
     #     result = _mock_api_output(input_text)
@@ -569,8 +569,12 @@ def _handle_audio_msg(billing_account_user, bot: BotInterface):
             selected_model = AsrModels.whisper_telugu_large_v2.name
         case "bho":
             selected_model = AsrModels.vakyansh_bhojpuri.name
-        case "en":
-            selected_model = AsrModels.usm.name
+        case "sw":
+            selected_model = AsrModels.seamless_m4t.name
+            language = "swh"
+        # case "en":
+        #     selected_model = AsrModels.usm.name
+        #     language = "am-et"
         case _:
             selected_model = AsrModels.whisper_large_v2.name
diff --git a/daras_ai_v2/db.py b/daras_ai_v2/db.py
index a4d4c287f..ffb10bd18 100644
--- a/daras_ai_v2/db.py
+++ b/daras_ai_v2/db.py
@@ -1,3 +1,9 @@
+import typing
+
+if typing.TYPE_CHECKING:
+    from google.cloud import firestore
+
+
 FIREBASE_SESSION_COOKIE = "firebase_session"
 ANONYMOUS_USER_COOKIE = "anonymous_user"
diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py
index fb69f22a9..4c0464475 100644
--- a/daras_ai_v2/doc_search_settings_widgets.py
+++ b/daras_ai_v2/doc_search_settings_widgets.py
@@ -7,6 +7,7 @@
 import gooey_ui as st
 from daras_ai_v2 import settings
 from daras_ai_v2.asr import AsrModels, google_translate_language_selector
+from daras_ai_v2.prompt_vars import prompt_vars_widget
 from daras_ai_v2.enum_selector_widget import enum_selector
 from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder
 from daras_ai_v2.search_ref import CitationStyles
@@ -123,6 +124,9 @@ def doc_search_settings(
         key="query_instructions",
         height=300,
     )
+    prompt_vars_widget(
+        "query_instructions",
+    )
     if keyword_instructions_allowed:
         st.text_area(
             """
@@ -133,6 +137,9 @@
             key="keyword_instructions",
             height=300,
         )
+        prompt_vars_widget(
+            "keyword_instructions",
+        )

     dense_weight_ = DocSearchRequest.__fields__["dense_weight"]
     st.slider(
diff --git a/daras_ai_v2/exceptions.py b/daras_ai_v2/exceptions.py
index c2cef5109..e97ac7190 100644
--- a/daras_ai_v2/exceptions.py
+++ b/daras_ai_v2/exceptions.py
@@ -1,12 +1,10 @@
-from logging import getLogger
+import json
+import subprocess

 import requests
+from loguru import logger
 from requests import HTTPError

-from daras_ai.image_input import truncate_filename
-
-logger = getLogger(__name__)
-

 def raise_for_status(resp: requests.Response):
     """Raises :class:`HTTPError`, if one occurred."""
@@ -35,4 +33,54 @@


 def _response_preview(resp: requests.Response) -> bytes:
+    from daras_ai.image_input import truncate_filename
+
     return truncate_filename(resp.content, 500, sep=b"...")
+
+
+class UserError(Exception):
+    def __init__(self, message: str, sentry_level: str = "info"):
+        self.message = message
+        self.sentry_level = sentry_level
+        super().__init__(message)
+
+
+class GPUError(UserError):
+    pass
+
+
+FFMPEG_ERR_MSG = (
+    "Unsupported File Format\n\n"
+    "We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. "
+    "You can find a list of supported formats at [FFmpeg Formats](https://ffmpeg.org/general.html#File-Formats)."
+)
+
+
+def ffmpeg(*args) -> str:
+    return call_cmd("ffmpeg", "-hide_banner", "-y", *args, err_msg=FFMPEG_ERR_MSG)
+
+
+def ffprobe(filename: str) -> dict:
+    text = call_cmd(
+        "ffprobe",
+        "-v",
+        "quiet",
+        "-print_format",
+        "json",
+        "-show_streams",
+        filename,
+        err_msg=FFMPEG_ERR_MSG,
+    )
+    return json.loads(text)
+
+
+def call_cmd(*args, err_msg: str = "") -> str:
+    logger.info("$ " + " ".join(map(str, args)))
+    try:
+        return subprocess.check_output(args, stderr=subprocess.STDOUT, text=True)
+    except subprocess.CalledProcessError as e:
+        err_msg = err_msg or f"{str(args[0]).capitalize()} Error"
+        try:
+            raise subprocess.SubprocessError(e.output) from e
+        except subprocess.SubprocessError as e:
+            raise UserError(err_msg) from e
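These helpers wrap subprocess calls and convert failures into a user-facing `UserError`. A usage sketch mirroring what `asr.py` now does (the file names are placeholders):

```python
from daras_ai_v2.exceptions import ffmpeg, ffprobe, UserError

try:
    # re-encode to 16 kHz 16-bit mono wav (same args as FFMPEG_WAV_ARGS in asr.py)
    ffmpeg("-i", "in.mp3", "-vn", "-acodec", "pcm_s16le", "-ac", "1", "-ar", "16000", "out.wav")
    streams = ffprobe("out.wav")["streams"]
    assert streams[0]["codec_name"] == "pcm_s16le"
except UserError as e:
    print(e.message)  # friendly error text (FFMPEG_ERR_MSG), safe to show
```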
diff --git a/daras_ai_v2/field_render.py b/daras_ai_v2/field_render.py
index 1b79c1673..728e90b42 100644
--- a/daras_ai_v2/field_render.py
+++ b/daras_ai_v2/field_render.py
@@ -4,10 +4,14 @@


 def field_title_desc(model: typing.Type[BaseModel], name: str) -> str:
+    return "\n".join(filter(None, [field_title(model, name), field_desc(model, name)]))
+
+
+def field_title(model: typing.Type[BaseModel], name: str) -> str:
+    field = model.__fields__[name]
+    return field.field_info.title
+
+
+def field_desc(model: typing.Type[BaseModel], name: str) -> str:
     field = model.__fields__[name]
-    return "\n".join(
-        filter(
-            None,
-            [field.field_info.title, field.field_info.description or ""],
-        )
-    )
+    return field.field_info.description or ""
diff --git a/daras_ai_v2/glossary.py b/daras_ai_v2/glossary.py
index 2e56352da..5444135c1 100644
--- a/daras_ai_v2/glossary.py
+++ b/daras_ai_v2/glossary.py
@@ -1,4 +1,4 @@
-from daras_ai_v2.asr import google_translate_languages
+from daras_ai_v2.asr import google_translate_target_languages
 from daras_ai_v2.doc_search_settings_widgets import document_uploader
@@ -125,7 +125,8 @@ def get_langcodes_from_df(df: "pd.DataFrame") -> list[str]:
     import langcodes

     supported = {
-        langcodes.Language.get(code).language for code in google_translate_languages()
+        langcodes.Language.get(code).language
+        for code in google_translate_target_languages()
     }
     ret = []
     for col in df.columns:
diff --git a/daras_ai_v2/google_asr.py b/daras_ai_v2/google_asr.py
new file mode 100644
index 000000000..222205f6e
--- /dev/null
+++ b/daras_ai_v2/google_asr.py
@@ -0,0 +1,25 @@
+from daras_ai.image_input import gs_url_to_uri
+
+
+def gcp_asr_v1(audio_url: str, language: str) -> str:
+    from google.cloud import speech
+
+    client = speech.SpeechClient()
+    audio = speech.RecognitionAudio(uri=gs_url_to_uri(audio_url))
+    config = speech.RecognitionConfig(
+        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
+        sample_rate_hertz=16000,
+        audio_channel_count=1,
+        language_code=language,
+        model="default",
+    )
+    # Make the request
+    operation = client.long_running_recognize(config=config, audio=audio)
+    # Wait for operation to complete
+    response = operation.result(timeout=600)
+    print(response)
+    return "\n\n".join(
+        result.alternatives[0].transcript
+        for result in response.results
+        if result.alternatives
+    )
diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py
index 329a33fbf..33b6ad935 100644
--- a/daras_ai_v2/gpu_server.py
+++ b/daras_ai_v2/gpu_server.py
@@ -9,7 +9,7 @@
 from daras_ai.image_input import storage_blob_for
 from daras_ai_v2 import settings
-from daras_ai_v2.exceptions import raise_for_status
+from daras_ai_v2.exceptions import raise_for_status, GPUError
 from gooeysite.bg_db_conn import get_celery_result_db_safe
@@ -160,7 +160,11 @@ def call_celery_task(
         task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue
     )
     s = time()
-    ret = get_celery_result_db_safe(result)
+    ret = get_celery_result_db_safe(result, propagate=False)
+    try:
+        result.maybe_throw()
+    except Exception as e:
+        raise GPUError(f"Error in GPU Task {queue}:{task_name} - {e}") from e
     record_cost_auto(
         model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000)
     )
diff --git a/daras_ai_v2/language_model_settings_widgets.py b/daras_ai_v2/language_model_settings_widgets.py
index 80f785bd4..4083ece31 100644
--- a/daras_ai_v2/language_model_settings_widgets.py
+++ b/daras_ai_v2/language_model_settings_widgets.py
@@ -1,14 +1,10 @@
 import gooey_ui as st
-from daras_ai_v2.azure_doc_extract import azure_form_recognizer_models
 from daras_ai_v2.enum_selector_widget import enum_selector
-from daras_ai_v2.field_render import field_title_desc
 from daras_ai_v2.language_model import LargeLanguageModels


-def language_model_settings(show_selector=True, show_document_model=False):
-    from recipes.VideoBots import VideoBotsPage
-
+def language_model_settings(show_selector=True):
     st.write("##### 🔠 Language Model Settings")

     if show_selector:
@@ -18,16 +14,6 @@ def language_model_settings(show_selector=True):
             key="selected_model",
             use_selectbox=True,
         )
-        if show_document_model:
-            doc_model_descriptions = azure_form_recognizer_models()
-            st.selectbox(
-                f"###### {field_title_desc(VideoBotsPage.RequestModel, 'document_model')}",
-                key="document_model",
-                options=[None, *doc_model_descriptions],
-                format_func=lambda x: (
-                    f"{doc_model_descriptions[x]} ({x})" if x else "———"
-                ),
-            )

     st.checkbox("Avoid Repetition", key="avoid_repetition")
diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py
index 2f4e6e7cc..f8d7a879b 100644
--- a/daras_ai_v2/lipsync_api.py
+++ b/daras_ai_v2/lipsync_api.py
@@ -1,21 +1,27 @@
+from daras_ai_v2.exceptions import UserError
 from daras_ai_v2.gpu_server import call_celery_task_outfile


 def wav2lip(*, face: str, audio: str, pads: (int, int, int, int)) -> bytes:
-    return call_celery_task_outfile(
-        "wav2lip",
-        pipeline=dict(
-            model_id="wav2lip_gan.pth",
-        ),
-        inputs=dict(
-            face=face,
-            audio=audio,
-            pads=pads,
-            batch_size=256,
-            # "out_height": 480,
-            # "smooth": True,
-            # "fps": 25,
-        ),
-        content_type="video/mp4",
-        filename=f"gooey.ai lipsync.mp4",
-    )[0]
+    try:
+        return call_celery_task_outfile(
+            "wav2lip",
+            pipeline=dict(
+                model_id="wav2lip_gan.pth",
+            ),
+            inputs=dict(
+                face=face,
+                audio=audio,
+                pads=pads,
+                batch_size=256,
+                # "out_height": 480,
+                # "smooth": True,
+                # "fps": 25,
+            ),
+            content_type="video/mp4",
+            filename=f"gooey.ai lipsync.mp4",
+        )[0]
+    except ValueError as e:
+        msg = "\n\n".join(e.args).lower()
+        if "unsupported" in msg:
+            raise UserError(msg) from e
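With `maybe_throw()` re-raised as `GPUError`, GPU task failures now carry `UserError` semantics: a message that is safe to show, logged to Sentry at `"info"` level by default. A minimal sketch of what callers can rely on; the task name and error text are invented:

```python
from daras_ai_v2.exceptions import GPUError, UserError

def run_gpu_step():
    # stand-in for call_celery_task(...) failing on the worker
    raise GPUError("Error in GPU Task gpu/a100:wav2lip - CUDA out of memory")

try:
    run_gpu_step()
except UserError as e:  # GPUError is a UserError subclass
    print(e.message)       # shown to the user
    print(e.sentry_level)  # "info", so not an error-level Sentry event
```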
""" diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py index fabb3928c..b6c2b4795 100644 --- a/daras_ai_v2/prompt_vars.py +++ b/daras_ai_v2/prompt_vars.py @@ -28,7 +28,7 @@ def prompt_vars_widget(*keys: str, variables_key: str = "variables"): if not (template_vars or err): return - st.write("#### ⌥ Variables") + st.write("###### ⌥ Variables") old_state = st.session_state.get(variables_key, {}) new_state = {} for name in sorted(template_vars): diff --git a/daras_ai_v2/redis_cache.py b/daras_ai_v2/redis_cache.py index 4930e5427..f339bec80 100644 --- a/daras_ai_v2/redis_cache.py +++ b/daras_ai_v2/redis_cache.py @@ -8,7 +8,6 @@ from daras_ai_v2 import settings - LOCK_TIMEOUT_SEC = 10 * 60 @@ -20,38 +19,44 @@ def get_redis_cache(): F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any]) -def redis_cache_decorator(fn: F) -> F: - @wraps(fn) - def wrapper(*args, **kwargs): - # hash the args and kwargs so they are not too long - args_hash = hashlib.sha256(f"{args}{kwargs}".encode()).hexdigest() - # create a readable cache key - cache_key = f"gooey/redis-cache-decorator/v1/{fn.__name__}/{args_hash}" - # get the redis cache - redis_cache = get_redis_cache() - # lock the cache key so that only one thread can run the function - lock = redis_cache.lock( - name=os.path.join(cache_key, "lock"), timeout=LOCK_TIMEOUT_SEC - ) - try: - lock.acquire() - except redis.exceptions.LockError: - pass - try: - cache_val = redis_cache.get(cache_key) - # if the cache exists, return it - if cache_val: - return pickle.loads(cache_val) - # otherwise, run the function and cache the result - else: - result = fn(*args, **kwargs) - cache_val = pickle.dumps(result) - redis_cache.set(cache_key, cache_val) - return result - finally: +def redis_cache_decorator(fn: F = None, ex=None) -> F: + def decorator(fn: F) -> F: + @wraps(fn) + def wrapper(*args, **kwargs): + # hash the args and kwargs so they are not too long + args_hash = hashlib.sha256(f"{args}{kwargs}".encode()).hexdigest() + # create a readable cache key + cache_key = f"gooey/redis-cache-decorator/v1/{fn.__name__}/{args_hash}" + # get the redis cache + redis_cache = get_redis_cache() + # lock the cache key so that only one thread can run the function + lock = redis_cache.lock( + name=os.path.join(cache_key, "lock"), timeout=LOCK_TIMEOUT_SEC + ) try: - lock.release() + lock.acquire() except redis.exceptions.LockError: pass - - return wrapper + try: + cache_val = redis_cache.get(cache_key) + # if the cache exists, return it + if cache_val: + return pickle.loads(cache_val) + # otherwise, run the function and cache the result + else: + result = fn(*args, **kwargs) + cache_val = pickle.dumps(result) + redis_cache.set(cache_key, cache_val, ex=ex) + return result + finally: + try: + lock.release() + except redis.exceptions.LockError: + pass + + return wrapper + + if fn is None: + return decorator + else: + return decorator(fn) diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 3f6859164..8ae962cc4 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -1,10 +1,15 @@ +from contextlib import contextmanager + from app_users.models import AppUser +from daras_ai_v2 import settings from daras_ai_v2.azure_image_moderation import is_image_nsfw +from daras_ai_v2.exceptions import UserError from daras_ai_v2.functional import flatten -from daras_ai_v2 import settings from gooeysite.bg_db_conn import get_celery_result_db_safe from recipes.CompareLLM import CompareLLMPage +SAFETY_CHECKER_MSG = "Your request 
diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py
index 3f6859164..8ae962cc4 100644
--- a/daras_ai_v2/safety_checker.py
+++ b/daras_ai_v2/safety_checker.py
@@ -1,10 +1,15 @@
+from contextlib import contextmanager
+
 from app_users.models import AppUser
+from daras_ai_v2 import settings
 from daras_ai_v2.azure_image_moderation import is_image_nsfw
+from daras_ai_v2.exceptions import UserError
 from daras_ai_v2.functional import flatten
-from daras_ai_v2 import settings
 from gooeysite.bg_db_conn import get_celery_result_db_safe
 from recipes.CompareLLM import CompareLLMPage

+SAFETY_CHECKER_MSG = "Your request was rejected as a result of our safety system. Your input image may contain contents that are not allowed by our safety system."
+

 def safety_checker(*, text: str | None = None, image: str | None = None):
     assert text or image, "safety_checker: at least one of text, image is required"
@@ -44,15 +49,21 @@ def safety_checker_text(text_input: str):
         if not lines:
             continue
         if lines[-1].upper().endswith("FLAGGED"):
-            raise ValueError(
-                "Your request was rejected as a result of our safety system. Your prompt may contain text that is not allowed by our safety system."
-            )
+            raise UserError(SAFETY_CHECKER_MSG)


 def safety_checker_image(image_url: str, cache: bool = False) -> None:
     if is_image_nsfw(image_url=image_url, cache=cache):
-        raise ValueError(
-            "Your request was rejected as a result of our safety system. "
-            "Your input image may contain contents that are not allowed "
-            "by our safety system."
-        )
+        raise UserError(SAFETY_CHECKER_MSG)
+
+
+@contextmanager
+def capture_openai_content_policy_violation():
+    import openai
+
+    try:
+        yield
+    except openai.BadRequestError as e:
+        if e.response.status_code == 400 and "content_policy_violation" in e.message:
+            raise UserError(SAFETY_CHECKER_MSG) from e
+        raise
diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py
index 6c3f0bccb..aaf1eb59b 100644
--- a/daras_ai_v2/send_email.py
+++ b/daras_ai_v2/send_email.py
@@ -1,4 +1,5 @@
 import smtplib
+import sys
 import typing
 from email.mime.application import MIMEApplication
 from email.mime.multipart import MIMEMultipart
@@ -12,6 +13,7 @@
 from daras_ai_v2 import settings
 from daras_ai_v2.settings import templates
 from gooey_ui import UploadedFile
+from routers.billing import account_url


 def send_reported_run_email(
@@ -43,6 +45,31 @@
     )


+def send_low_balance_email(
+    *,
+    user: AppUser,
+    total_credits_consumed: int,
+):
+    recipients = "support@gooey.ai, devs@gooey.ai"
+    html_body = templates.get_template("low_balance_email.html").render(
+        user=user,
+        url=account_url,
+        total_credits_consumed=total_credits_consumed,
+        settings=settings,
+    )
+    send_email_via_postmark(
+        from_address=settings.SUPPORT_EMAIL,
+        to_address=user.email or recipients,
+        bcc=recipients,
+        subject="Your Gooey.AI credit balance is low",
+        html_body=html_body,
+    )
+
+
+is_running_pytest = "pytest" in sys.modules
+pytest_outbox = []
+
+
 def send_email_via_postmark(
     *,
     from_address: str,
@@ -56,6 +83,21 @@
     message_stream: typing.Literal[
         "outbound", "gooey-ai-workflows", "announcements"
     ] = "outbound",
 ):
+    if is_running_pytest:
+        pytest_outbox.append(
+            dict(
+                from_address=from_address,
+                to_address=to_address,
+                cc=cc,
+                bcc=bcc,
+                subject=subject,
+                html_body=html_body,
+                text_body=text_body,
+                message_stream=message_stream,
+            ),
+        )
+        return
+
     r = requests.post(
         "https://api.postmarkapp.com/email",
         headers={
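Because `is_running_pytest` flips as soon as pytest is loaded, any test can assert on outgoing mail instead of hitting Postmark. A sketch, assuming the keyword-only signature above (default values for the remaining parameters are not shown in this hunk):

```python
from daras_ai_v2.send_email import pytest_outbox, send_email_via_postmark

def test_email_capture():
    send_email_via_postmark(
        from_address="support@gooey.ai",
        to_address="user@example.com",
        subject="hello",
        html_body="<b>hi</b>",
    )
    assert pytest_outbox[-1]["subject"] == "hello"
```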
"django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", + # "debug_toolbar.middleware.DebugToolbarMiddleware", ] ROOT_URLCONF = "gooeysite.urls" @@ -266,6 +268,10 @@ ANON_USER_FREE_CREDITS = config("ANON_USER_FREE_CREDITS", 25, cast=int) LOGIN_USER_FREE_CREDITS = config("LOGIN_USER_FREE_CREDITS", 1000, cast=int) +LOW_BALANCE_EMAIL_CREDITS = config("LOW_BALANCE_EMAIL_CREDITS", 200, cast=int) +LOW_BALANCE_EMAIL_DAYS = config("LOW_BALANCE_EMAIL_DAYS", 7, cast=int) +LOW_BALANCE_EMAIL_ENABLED = config("LOW_BALANCE_EMAIL_ENABLED", True, cast=bool) + stripe.api_key = config("STRIPE_SECRET_KEY", None) STRIPE_ENDPOINT_SECRET = config("STRIPE_ENDPOINT_SECRET", None) @@ -295,6 +301,8 @@ REDIS_CACHE_URL = config("REDIS_CACHE_URL", "redis://localhost:6379") TWITTER_BEARER_TOKEN = config("TWITTER_BEARER_TOKEN", None) +REDIS_MODELS_CACHE_EXPIRY = 60 * 60 * 24 * 7 + GPU_CELERY_BROKER_URL = config("GPU_CELERY_BROKER_URL", "amqp://localhost:5674") GPU_CELERY_RESULT_BACKEND = config( "GPU_CELERY_RESULT_BACKEND", "redis://localhost:6374" diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index ce6333068..f68fed2ec 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -13,12 +13,15 @@ resize_img_fit, get_downscale_factor, ) -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import ( + raise_for_status, +) from daras_ai_v2.extract_face import rgb_img_to_rgba from daras_ai_v2.gpu_server import ( b64_img_decode, call_sd_multi, ) +from daras_ai_v2.safety_checker import capture_openai_content_policy_violation SD_IMG_MAX_SIZE = (768, 768) @@ -283,27 +286,29 @@ def text2img( client = OpenAI() width, height = _get_dall_e_3_img_size(width, height) - response = client.images.generate( - model=text2img_model_ids[Text2ImgModels[selected_model]], - n=1, # num_outputs, not supported yet - prompt=prompt, - response_format="b64_json", - quality=dall_e_3_quality, - style=dall_e_3_style, - size=f"{width}x{height}", - ) + with capture_openai_content_policy_violation(): + response = client.images.generate( + model=text2img_model_ids[Text2ImgModels[selected_model]], + n=1, # num_outputs, not supported yet + prompt=prompt, + response_format="b64_json", + quality=dall_e_3_quality, + style=dall_e_3_style, + size=f"{width}x{height}", + ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] case Text2ImgModels.dall_e.name: from openai import OpenAI edge = _get_dall_e_img_size(width, height) client = OpenAI() - response = client.images.generate( - n=num_outputs, - prompt=prompt, - size=f"{edge}x{edge}", - response_format="b64_json", - ) + with capture_openai_content_policy_violation(): + response = client.images.generate( + n=num_outputs, + prompt=prompt, + size=f"{edge}x{edge}", + response_format="b64_json", + ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] case _: prompt = add_prompt_prefix(prompt, selected_model) @@ -379,12 +384,13 @@ def img2img( image = resize_img_pad(init_image_bytes, (edge, edge)) client = OpenAI() - response = client.images.create_variation( - image=image, - n=num_outputs, - size=f"{edge}x{edge}", - response_format="b64_json", - ) + with capture_openai_content_policy_violation(): + response = client.images.create_variation( + image=image, + n=num_outputs, + size=f"{edge}x{edge}", + response_format="b64_json", + ) out_imgs = [ 
resize_img_fit(b64_img_decode(part.b64_json), (width, height)) @@ -503,13 +509,14 @@ def inpainting( image = rgb_img_to_rgba(edit_image_bytes, mask_bytes) client = OpenAI() - response = client.images.edit( - prompt=prompt, - image=image, - n=num_outputs, - size=f"{edge}x{edge}", - response_format="b64_json", - ) + with capture_openai_content_policy_violation(): + response = client.images.edit( + prompt=prompt, + image=image, + n=num_outputs, + size=f"{edge}x{edge}", + response_format="b64_json", + ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] case InpaintingModels.sd_2.name | InpaintingModels.runway_ml.name: diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index 499b502d2..dad825955 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -4,6 +4,7 @@ from google.cloud import texttospeech import gooey_ui as st +from daras_ai_v2 import settings from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.redis_cache import redis_cache_decorator @@ -26,9 +27,9 @@ class TextToSpeechProviders(Enum): - GOOGLE_TTS = "Google Cloud Text-to-Speech" + GOOGLE_TTS = "Google Text-to-Speech" ELEVEN_LABS = "Eleven Labs" - UBERDUCK = "uberduck.ai" + UBERDUCK = "Uberduck.ai" BARK = "Bark (suno-ai)" @@ -141,246 +142,253 @@ class TextToSpeechProviders(Enum): } -def text_to_speech_settings(page): - st.write( - """ - ##### 🗣️ Voice Settings - """ - ) - +def text_to_speech_provider_selector(page): col1, col2 = st.columns(2) with col1: tts_provider = enum_selector( TextToSpeechProviders, "###### Speech Provider", key="tts_provider", + use_selectbox=True, ) - + with col2: + match tts_provider: + case TextToSpeechProviders.BARK.name: + bark_selector() + case TextToSpeechProviders.GOOGLE_TTS.name: + google_tts_selector() + case TextToSpeechProviders.UBERDUCK.name: + uberduck_selector() + case TextToSpeechProviders.ELEVEN_LABS.name: + elevenlabs_selector(page) + return tts_provider + + +def text_to_speech_settings(page, tts_provider): match tts_provider: case TextToSpeechProviders.BARK.name: - with col2: - st.selectbox( - label=""" - ###### Bark History Prompt - """, - key="bark_history_prompt", - format_func=BARK_ALLOWED_PROMPTS.__getitem__, - options=BARK_ALLOWED_PROMPTS.keys(), - ) - + pass case TextToSpeechProviders.GOOGLE_TTS.name: - with col2: - voices = google_tts_voices() - st.selectbox( - label=""" - ###### Voice name (Google TTS) - """, - key="google_voice_name", - format_func=voices.__getitem__, - options=voices.keys(), - ) - st.caption( - "*Please refer to the list of voice names [here](https://cloud.google.com/text-to-speech/docs/voices)*" - ) - - col1, col2 = st.columns(2) - with col1: - st.slider( - """ - ###### Speaking rate - *`1.0` is the normal native speed of the speaker* - """, - min_value=0.3, - max_value=4.0, - step=0.1, - key="google_speaking_rate", - ) - with col2: - st.slider( - """ - ###### Pitch - *Increase/Decrease semitones from the original pitch* - """, - min_value=-20.0, - max_value=20.0, - step=0.25, - key="google_pitch", - ) - + google_tts_settings() case TextToSpeechProviders.UBERDUCK.name: - with col2: - st.selectbox( - label=""" - ###### Voice name (Uberduck) - """, - key="uberduck_voice_name", - format_func=lambda option: f"{option}", - options=UBERDUCK_VOICES.keys(), - ) - - col1, col2 = st.columns(2) - with col1: - st.slider( - """ - ###### Speaking rate - *`1.0` is 
the normal native speed of the speaker* - """, - min_value=0.5, - max_value=3.0, - step=0.25, - key="uberduck_speaking_rate", - ) - + uberduck_settings() case TextToSpeechProviders.ELEVEN_LABS.name: - with col2: - if not st.session_state.get("elevenlabs_api_key"): - st.session_state["elevenlabs_api_key"] = page.request.session.get( - SESSION_ELEVENLABS_API_KEY - ) - - elevenlabs_use_custom_key = st.checkbox( - "Use custom API key + Voice ID", - value=bool(st.session_state.get("elevenlabs_api_key")), - ) - if elevenlabs_use_custom_key: - st.session_state["elevenlabs_voice_name"] = None - elevenlabs_api_key = st.text_input( - """ - ###### Your ElevenLabs API key - *Read this - to know how to obtain an API key from - ElevenLabs.* - """, - key="elevenlabs_api_key", - ) - - selected_voice_id = st.session_state.get("elevenlabs_voice_id") - elevenlabs_voices = ( - {selected_voice_id: selected_voice_id} - if selected_voice_id - else {} - ) - - if elevenlabs_api_key: - try: - elevenlabs_voices = fetch_elevenlabs_voices( - elevenlabs_api_key - ) - except requests.exceptions.HTTPError as e: - st.error( - f"Invalid ElevenLabs API key. Failed to fetch voices: {e}" - ) - - st.selectbox( - """ - ###### Voice ID (ElevenLabs) - """, - key="elevenlabs_voice_id", - options=elevenlabs_voices.keys(), - format_func=elevenlabs_voices.__getitem__, - ) - else: - st.session_state["elevenlabs_api_key"] = None - st.session_state["elevenlabs_voice_id"] = None - if not ( - page - and ( - page.is_current_user_paying() - or page.is_current_user_admin() - ) - ): - st.caption( - """ - Note: Please purchase Gooey.AI credits to use ElevenLabs voices [here](/account). - Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above. - """ - ) - - st.session_state.update( - elevenlabs_api_key=None, elevenlabs_voice_id=None - ) - st.selectbox( - """ - ###### Voice Name (ElevenLabs) - """, - key="elevenlabs_voice_name", - format_func=str, - options=ELEVEN_LABS_VOICES.keys(), - ) - - page.request.session[SESSION_ELEVENLABS_API_KEY] = st.session_state.get( - "elevenlabs_api_key" - ) - - st.selectbox( - """ - ###### Voice Model - """, - key="elevenlabs_model", - format_func=ELEVEN_LABS_MODELS.__getitem__, - options=ELEVEN_LABS_MODELS.keys(), - ) - - col1, col2 = st.columns(2) - with col1: - st.slider( - """ - ###### Stability - *A lower stability provides a broader emotional range. - A value lower than 0.3 can lead to too much instability. - [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#stability).* - """, - min_value=0, - max_value=1.0, - step=0.05, - key="elevenlabs_stability", - ) - with col2: - st.slider( - """ - ###### Similarity Boost - *Dictates how hard the model should try to replicate the original voice. 
- [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#similarity).* - """, - min_value=0, - max_value=1.0, - step=0.05, - key="elevenlabs_similarity_boost", - ) - - if st.session_state.get("elevenlabs_model") == "eleven_multilingual_v2": - col1, col2 = st.columns(2) - with col1: - st.slider( - """ - ###### Style Exaggeration - """, - min_value=0, - max_value=1.0, - step=0.05, - key="elevenlabs_style", - value=0.0, - ) - with col2: - st.checkbox( - "Speaker Boost", - key="elevenlabs_speaker_boost", - value=True, - ) - - with st.expander( - "Eleven Labs Supported Languages", - style={"fontSize": "0.9rem", "textDecoration": "underline"}, - ): - st.caption( - "With Multilingual V2 voice model", style={"fontSize": "0.8rem"} - ) - st.caption( - ", ".join(ELEVEN_LABS_SUPPORTED_LANGS), style={"fontSize": "0.8rem"} - ) - - -@redis_cache_decorator + elevenlabs_settings() + + +def bark_selector(): + st.selectbox( + label=""" + ###### Bark History Prompt + """, + key="bark_history_prompt", + format_func=BARK_ALLOWED_PROMPTS.__getitem__, + options=BARK_ALLOWED_PROMPTS.keys(), + ) + + +def google_tts_selector(): + voices = google_tts_voices() + st.selectbox( + label=""" + ###### Voice name (Google TTS) + """, + key="google_voice_name", + format_func=voices.__getitem__, + options=voices.keys(), + ) + st.caption( + "*Please refer to the list of voice names [here](https://cloud.google.com/text-to-speech/docs/voices)*" + ) + + +def google_tts_settings(): + st.write(f"##### 🗣️ {TextToSpeechProviders.GOOGLE_TTS.value} Settings") + col1, col2 = st.columns(2) + with col1: + st.slider( + """ + ###### Speaking rate + *`1.0` is the normal native speed of the speaker* + """, + min_value=0.3, + max_value=4.0, + step=0.1, + key="google_speaking_rate", + ) + with col2: + st.slider( + """ + ###### Pitch + *Increase/Decrease semitones from the original pitch* + """, + min_value=-20.0, + max_value=20.0, + step=0.25, + key="google_pitch", + ) + + +def uberduck_selector(): + st.selectbox( + label=""" + ###### Voice name (Uberduck) + """, + key="uberduck_voice_name", + format_func=lambda option: f"{option}", + options=UBERDUCK_VOICES.keys(), + ) + + +def uberduck_settings(): + st.write(f"##### 🗣️ {TextToSpeechProviders.UBERDUCK.value} Settings") + st.slider( + """ + ###### Speaking rate + *`1.0` is the normal native speed of the speaker* + """, + min_value=0.5, + max_value=3.0, + step=0.25, + key="uberduck_speaking_rate", + ) + + +def elevenlabs_selector(page): + if not st.session_state.get("elevenlabs_api_key"): + st.session_state["elevenlabs_api_key"] = page.request.session.get( + SESSION_ELEVENLABS_API_KEY + ) + + elevenlabs_use_custom_key = st.checkbox( + "Use custom API key + Voice ID", + value=bool(st.session_state.get("elevenlabs_api_key")), + ) + if elevenlabs_use_custom_key: + st.session_state["elevenlabs_voice_name"] = None + elevenlabs_api_key = st.text_input( + """ + ###### Your ElevenLabs API key + *Read this + to know how to obtain an API key from + ElevenLabs.* + """, + key="elevenlabs_api_key", + ) + + selected_voice_id = st.session_state.get("elevenlabs_voice_id") + elevenlabs_voices = ( + {selected_voice_id: selected_voice_id} if selected_voice_id else {} + ) + + if elevenlabs_api_key: + try: + elevenlabs_voices = fetch_elevenlabs_voices(elevenlabs_api_key) + except requests.exceptions.HTTPError as e: + st.error(f"Invalid ElevenLabs API key. 
Failed to fetch voices: {e}") + + st.selectbox( + """ + ###### Voice ID (ElevenLabs) + """, + key="elevenlabs_voice_id", + options=elevenlabs_voices.keys(), + format_func=elevenlabs_voices.__getitem__, + ) + else: + st.session_state["elevenlabs_api_key"] = None + st.session_state["elevenlabs_voice_id"] = None + if not ( + page and (page.is_current_user_paying() or page.is_current_user_admin()) + ): + st.caption( + """ + Note: Please purchase Gooey.AI credits to use ElevenLabs voices [here](/account). + Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above. + """ + ) + + st.session_state.update(elevenlabs_api_key=None, elevenlabs_voice_id=None) + st.selectbox( + """ + ###### Voice Name (ElevenLabs) + """, + key="elevenlabs_voice_name", + format_func=str, + options=ELEVEN_LABS_VOICES.keys(), + ) + + page.request.session[SESSION_ELEVENLABS_API_KEY] = st.session_state.get( + "elevenlabs_api_key" + ) + + st.selectbox( + """ + ###### Voice Model + """, + key="elevenlabs_model", + format_func=ELEVEN_LABS_MODELS.__getitem__, + options=ELEVEN_LABS_MODELS.keys(), + ) + + +def elevenlabs_settings(): + col1, col2 = st.columns(2) + with col1: + st.slider( + """ + ###### Stability + *A lower stability provides a broader emotional range. + A value lower than 0.3 can lead to too much instability. + [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#stability).* + """, + min_value=0, + max_value=1.0, + step=0.05, + key="elevenlabs_stability", + ) + with col2: + st.slider( + """ + ###### Similarity Boost + *Dictates how hard the model should try to replicate the original voice. + [Read more](https://docs.elevenlabs.io/speech-synthesis/voice-settings#similarity).* + """, + min_value=0, + max_value=1.0, + step=0.05, + key="elevenlabs_similarity_boost", + ) + + if st.session_state.get("elevenlabs_model") == "eleven_multilingual_v2": + col1, col2 = st.columns(2) + with col1: + st.slider( + """ + ###### Style Exaggeration + """, + min_value=0, + max_value=1.0, + step=0.05, + key="elevenlabs_style", + value=0.0, + ) + with col2: + st.checkbox( + "Speaker Boost", + key="elevenlabs_speaker_boost", + value=True, + ) + + with st.expander( + "Eleven Labs Supported Languages", + style={"fontSize": "0.9rem", "textDecoration": "underline"}, + ): + st.caption("With Multilingual V2 voice model", style={"fontSize": "0.8rem"}) + st.caption(", ".join(ELEVEN_LABS_SUPPORTED_LANGS), style={"fontSize": "0.8rem"}) + + +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) def google_tts_voices() -> dict[str, str]: voices: list[texttospeech.Voice] = ( texttospeech.TextToSpeechClient().list_voices().voices diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 4a510d150..6080131ee 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -31,7 +31,7 @@ from daras_ai_v2.doc_search_settings_widgets import ( is_user_uploaded_url, ) -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import flatmap_parallel, map_parallel from daras_ai_v2.gdrive_downloader import ( @@ -258,7 +258,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: meta = gdrive_metadata(url_to_gdrive_file_id(f)) except HttpError as e: if e.status_code == 404: - raise FileNotFoundError( + raise UserError( f"Could not download the google doc at {f_url} " f"Please make sure to make the document 
public for viewing." ) from e @@ -630,17 +630,9 @@ def pandoc_to_text(f_name: str, f_bytes: bytes, to="plain") -> str: tempfile.NamedTemporaryFile("r") as outfile, ): infile.write(f_bytes) - args = [ - "pandoc", - "--standalone", - infile.name, - "--to", - to, - "--output", - outfile.name, - ] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) + call_cmd( + "pandoc", "--standalone", infile.name, "--to", to, "--output", outfile.name + ) return outfile.read() diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index a65bcab88..41d76d7ce 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -416,7 +416,7 @@ def multiselect( def selectbox( label: str, - options: typing.Sequence[T], + options: typing.Iterable[T], format_func: typing.Callable[[T], typing.Any] = _default_format, key: str = None, help: str = None, diff --git a/gooeysite/bg_db_conn.py b/gooeysite/bg_db_conn.py index 9c7680df9..0c36daca8 100644 --- a/gooeysite/bg_db_conn.py +++ b/gooeysite/bg_db_conn.py @@ -31,5 +31,7 @@ def wrapper(*args, **kwargs): @db_middleware -def get_celery_result_db_safe(result: "celery.result.AsyncResult") -> typing.Any: - return result.get(disable_sync_subtasks=False) +def get_celery_result_db_safe( + result: "celery.result.AsyncResult", **kwargs +) -> typing.Any: + return result.get(disable_sync_subtasks=False, **kwargs) diff --git a/gooeysite/urls.py b/gooeysite/urls.py index 6c3809436..fba775500 100644 --- a/gooeysite/urls.py +++ b/gooeysite/urls.py @@ -16,8 +16,9 @@ """ from django.contrib import admin -from django.urls import path +from django.urls import path, include urlpatterns = [ + # path("__debug__/", include("debug_toolbar.urls")), path("", admin.site.urls), ] diff --git a/pull_request_template.md b/pull_request_template.md index a41f54e12..38cff8c3d 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,10 +1,12 @@ ### Q/A checklist -- [ ] Do a code review of the changes -- [ ] Add any new dependencies to poetry & export to requirementst.txt (`poetry export -o requirements.txt`) +- [ ] Run tests after placing [fixture.json](https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ca0f13b8-d6ed-11ee-870b-8e93953183bb/fixture.json) in your project root +```bash +ulimit -n unlimited && pytest +``` +- [ ] Do a self code review of the changes - [ ] Carefully think about the stuff that might break because of this change - [ ] The relevant pages still run when you press submit -- [ ] If you added new settings / knobs, the values get saved if you save it on the UI - [ ] The API for those pages still work (API tab) - [ ] The public API interface doesn't change if you didn't want it to (check API tab > docs page) - [ ] Do your UI changes (if applicable) look acceptable on mobile? diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 2101aece5..7e709a18d 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -59,6 +59,10 @@ def preview_image(self, state: dict) -> str | None: def preview_description(self, state: dict) -> str: return "Which language model works best your prompt? Compare your text generations across multiple large language models (LLMs) like OpenAI's evolving and latest ChatGPT engines and others like Curie, Ada, Babbage." 
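Both `pandoc_to_text` above and `get_pdf_num_pages` in `recipes/DocExtract.py` below now shell out through `call_cmd`, which this diff imports from `daras_ai_v2.exceptions` but never shows. A minimal sketch of what such a helper might look like, assuming it keeps the old inline behaviour of logging the command and surfacing subprocess failures as a user-facing error:

```python
import subprocess


class UserError(Exception):
    """Stand-in for daras_ai_v2.exceptions.UserError, which this diff imports but never defines."""


def call_cmd(*args: str) -> str:
    # log the command the same way the old inline code did
    print("\t$ " + " ".join(args))
    try:
        # capture stdout + stderr so callers like get_pdf_num_pages can parse the output
        return subprocess.check_output(args, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # surface the failure as a user-facing error instead of a bare traceback
        raise UserError(f"{args[0]} failed: {e.output}") from e
```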
+ @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["input_prompt", "selected_models"] + def render_form_v2(self): st.text_area( """ diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index dc5ea1ae2..7b1f421de 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -77,6 +77,10 @@ class ResponseModel(BaseModel): typing.Literal[tuple(e.name for e in Text2ImgModels)], list[str] ] + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["selected_models"] + def preview_image(self, state: dict) -> str | None: return DEFAULT_COMPARE_TEXT2IMG_META_IMG @@ -264,4 +268,4 @@ def get_raw_price(self, state: dict) -> int: total += 15 case _: total += 2 - return total + return total * state.get("num_outputs", 1) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index 35a71a12c..d732e7b0c 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -1,6 +1,5 @@ import typing import uuid -from datetime import datetime, timedelta from django.db.models import TextChoices from pydantic import BaseModel @@ -10,10 +9,10 @@ from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.exceptions import UserError from daras_ai_v2.gpu_server import call_celery_task_outfile from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.safety_checker import safety_checker -from daras_ai_v2.tabs_widget import MenuTabs DEFAULT_DEFORUMSD_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7dc25196-93fe-11ee-9e3a-02420a0001ce/AI%20Animation%20generator.jpg.png" @@ -455,27 +454,33 @@ def run(self, state: dict): if not self.request.user.disable_safety_checker: safety_checker(text=self.preview_input(state)) - state["output_video"] = call_celery_task_outfile( - "deforum", - pipeline=dict( - model_id=AnimationModels[request.selected_model].value, - seed=request.seed, - ), - inputs=dict( - animation_mode=request.animation_mode, - animation_prompts={ - fp["frame"]: fp["prompt"] for fp in request.animation_prompts - }, - max_frames=request.max_frames, - zoom=request.zoom, - translation_x=request.translation_x, - translation_y=request.translation_y, - rotation_3d_x=request.rotation_3d_x, - rotation_3d_y=request.rotation_3d_y, - rotation_3d_z=request.rotation_3d_z, - translation_z="0:(0)", - fps=request.fps, - ), - content_type="video/mp4", - filename=f"gooey.ai animation {request.animation_prompts}.mp4", - )[0] + try: + state["output_video"] = call_celery_task_outfile( + "deforum", + pipeline=dict( + model_id=AnimationModels[request.selected_model].value, + seed=request.seed, + ), + inputs=dict( + animation_mode=request.animation_mode, + animation_prompts={ + fp["frame"]: fp["prompt"] for fp in request.animation_prompts + }, + max_frames=request.max_frames, + zoom=request.zoom, + translation_x=request.translation_x, + translation_y=request.translation_y, + rotation_3d_x=request.rotation_3d_x, + rotation_3d_y=request.rotation_3d_y, + rotation_3d_z=request.rotation_3d_z, + translation_z="0:(0)", + fps=request.fps, + ), + content_type="video/mp4", + filename=f"gooey.ai animation {request.animation_prompts}.mp4", + )[0] + except RuntimeError as e: + msg = "\n\n".join(e.args).lower() + if "key frame string not correctly formatted" in msg: + raise UserError(str(e)) from e + raise diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 9cca97806..e9eaad6ae 100644 --- 
a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -27,7 +27,7 @@ from daras_ai_v2.base import BasePage from daras_ai_v2.doc_search_settings_widgets import document_uploader from daras_ai_v2.enum_selector_widget import enum_selector -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, call_cmd from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import ( apply_parallel, @@ -277,8 +277,11 @@ def extract_info(url: str) -> list[dict | None]: params = dict(ignoreerrors=True, check_formats=False) with yt_dlp.YoutubeDL(params) as ydl: data = ydl.extract_info(url, download=False) - entries = data.get("entries", [data]) - return [e for e in entries if e] + if data: + entries = data.get("entries", [data]) + return [e for e in entries if e] + else: + return [{"webpage_url": url, "title": "YouTube Video"}] else: # assume it's a direct link doc_meta = doc_url_to_metadata(url) @@ -326,13 +329,7 @@ def extract_info(url: str) -> list[dict | None]: def get_pdf_num_pages(f_bytes: bytes) -> int: with tempfile.NamedTemporaryFile() as infile: infile.write(f_bytes) - args = ["pdfinfo", infile.name] - print("\t$ " + " ".join(args)) - try: - output = subprocess.check_output(args, stderr=subprocess.STDOUT, text=True) - except subprocess.CalledProcessError as e: - raise ValueError(f"PDF Error: {e.output}") - output = output.lower() + output = call_cmd("pdfinfo", infile.name).lower() for line in output.splitlines(): if not line.startswith("pages:"): continue diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 6d7977c25..322cafbd9 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -81,6 +81,10 @@ class ResponseModel(BaseModel): final_prompt: str final_search_query: str | None + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["documents"] + def render_form_v2(self): st.text_area("#### Search Query", key="search_query") document_uploader("#### Documents") @@ -205,9 +209,10 @@ def run_v2( def get_raw_price(self, state: dict) -> float: name = state.get("selected_model") try: - return llm_price[LargeLanguageModels[name]] * 2 + unit_price = llm_price[LargeLanguageModels[name]] * 2 except KeyError: - return 10 + unit_price = 10 + return unit_price * state.get("num_outputs", 1) def render_documents(state, label="**Documents**", *, key="documents"): diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 476bd5da6..8955d5e59 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -80,6 +80,10 @@ class ResponseModel(BaseModel): prompt_tree: PromptTree | None final_prompt: str + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["task_instructions", "merge_instructions"] + def preview_image(self, state: dict) -> str | None: return DEFAULT_DOC_SUMMARY_META_IMG diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index 4b8b50d1e..385596f6d 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -88,6 +88,10 @@ class ResponseModel(BaseModel): output_images: list[str] email_sent: bool = False + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["email_address"] + def preview_image(self, state: dict) -> str | None: return DEFAULT_EMAIL_FACE_INPAINTING_META_IMG diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index 6d787bba6..2f1b6a0b0 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ 
-336,6 +336,8 @@ def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") match selected_model: case InpaintingModels.dall_e.name: - return 20 + unit_price = 20 case _: - return 5 + unit_price = 5 + + return unit_price * state.get("num_outputs", 1) diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index e2713aaa2..41a1139e4 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -72,6 +72,10 @@ class RequestModel(BaseModel): class ResponseModel(BaseModel): output_images: list[str] + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["text_prompt"] + def preview_image(self, state: dict) -> str | None: return DEFAULT_IMG2IMG_META_IMG @@ -202,6 +206,8 @@ def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") match selected_model: case Img2ImgModels.dall_e.name: - return 20 + unit_price = 20 case _: - return 5 + unit_price = 5 + + return unit_price * state.get("num_outputs", 1) diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 1fc0f6c64..4c16261ae 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -4,6 +4,7 @@ import gooey_ui as st from bots.models import Workflow +from daras_ai_v2.text_to_speech_settings_widgets import text_to_speech_provider_selector from recipes.Lipsync import LipsyncPage from recipes.TextToSpeech import TextToSpeechPage, TextToSpeechProviders from daras_ai_v2.safety_checker import safety_checker @@ -86,6 +87,7 @@ def render_form_v2(self): """, key="text_prompt", ) + text_to_speech_provider_selector(self) def validate_form_v2(self): assert st.session_state.get( diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index a1e0c2449..db12760bf 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -314,6 +314,8 @@ def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") match selected_model: case InpaintingModels.dall_e.name: - return 20 + unit_price = 20 case _: - return 5 + unit_price = 5 + + return unit_price * state.get("num_outputs", 1) diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 1512e4523..2e3a30002 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -20,7 +20,7 @@ ) from daras_ai_v2.base import BasePage from daras_ai_v2.descriptions import prompting101 -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, UserError from daras_ai_v2.img_model_settings_widgets import ( output_resolution_setting, img_model_settings, @@ -141,6 +141,15 @@ def related_workflows(self) -> list: EmailFaceInpaintingPage, ] + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + if state.get("qr_code_file"): + return ["qr_code_file"] + elif state.get("qr_code_input_image"): + return ["qr_code_input_image"] + else: + return ["qr_code_data"] + def render_form_v2(self): st.text_area( """ @@ -687,7 +696,7 @@ def generate_and_upload_qr_code( if isinstance(qr_code_data, str): qr_code_data = qr_code_data.strip() if not qr_code_data: - raise ValueError("Please provide QR Code URL, text content, or an image") + raise UserError("Please provide QR Code URL, text content, or an image") using_shortened_url = request.use_url_shortener and is_url(qr_code_data) if using_shortened_url: qr_code_data = ShortenedURL.objects.get_or_create_for_workflow( @@ -735,7 +744,7 @@ def extract_qr_code_data(img: np.ndarray) -> str: return info -class InvalidQRCode(AssertionError): 
+class InvalidQRCode(UserError): pass diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 7494c992d..7b8ed1fc4 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -12,7 +12,7 @@ from daras_ai.image_input import upload_file_from_bytes, storage_blob_for from daras_ai_v2 import settings from daras_ai_v2.base import BasePage -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, UserError from daras_ai_v2.gpu_server import GpuEndpoints, call_celery_task_outfile from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.text_to_speech_settings_widgets import ( @@ -21,6 +21,7 @@ ELEVEN_LABS_MODELS, text_to_speech_settings, TextToSpeechProviders, + text_to_speech_provider_selector, ) DEFAULT_TTS_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a73181ce-9457-11ee-8edd-02420a0001c7/Voice%20generators.jpg.png" @@ -81,6 +82,10 @@ class ResponseModel(BaseModel): def fallback_preivew_image(self) -> str | None: return DEFAULT_TTS_META_IMG + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["tts_provider"] + def preview_description(self, state: dict) -> str: return "Input your text, pick a voice & a Text-to-Speech AI engine to create audio. Compare the best voice generators from Google, UberDuck.ai & more to add automated voices to your podcast, YouTube videos, website, or app." @@ -105,6 +110,7 @@ def render_form_v2(self): """, key="text_prompt", ) + text_to_speech_provider_selector(self) def fields_to_save(self): fields = super().fields_to_save() @@ -113,10 +119,11 @@ def fields_to_save(self): return fields def validate_form_v2(self): - assert st.session_state["text_prompt"], "Text input cannot be empty" + assert st.session_state.get("text_prompt"), "Text input cannot be empty" + assert st.session_state.get("tts_provider"), "Please select a TTS provider" def render_settings(self): - text_to_speech_settings(page=self) + text_to_speech_settings(self, st.session_state.get("tts_provider")) def get_raw_price(self, state: dict): tts_provider = self._get_tts_provider(state) @@ -254,13 +261,16 @@ def run(self, state: dict): case TextToSpeechProviders.ELEVEN_LABS: xi_api_key, is_custom_key = self._get_elevenlabs_api_key(state) - assert ( + if not ( is_custom_key or self.is_current_user_paying() or self.is_current_user_admin() - ), """ - Please purchase Gooey.AI credits to use ElevenLabs voices here. - """ + ): + raise UserError( + """ + Please purchase Gooey.AI credits to use ElevenLabs voices here. 
+ """ + ) voice_model = self._get_elevenlabs_voice_model(state) voice_id = self._get_elevenlabs_voice_id(state) @@ -303,9 +313,10 @@ def _get_elevenlabs_voice_model(self, state: dict[str, str]): def _get_elevenlabs_voice_id(self, state: dict[str, str]): if state.get("elevenlabs_voice_id"): - assert state.get( - "elevenlabs_api_key" - ), "ElevenLabs API key is required to use a custom voice_id" + if not state.get("elevenlabs_api_key"): + raise UserError( + "ElevenLabs API key is required to use a custom voice_id" + ) return state["elevenlabs_voice_id"] else: # default to first in the mapping diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 2caa2a730..e426f31f5 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -20,6 +20,7 @@ ) from daras_ai_v2.azure_doc_extract import ( azure_form_recognizer, + azure_form_recognizer_models, ) from daras_ai_v2.base import BasePage, MenuTabs from daras_ai_v2.bot_integration_widgets import ( @@ -33,7 +34,9 @@ document_uploader, ) from daras_ai_v2.enum_selector_widget import enum_multiselect -from daras_ai_v2.field_render import field_title_desc +from daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.exceptions import UserError +from daras_ai_v2.field_render import field_title_desc, field_desc from daras_ai_v2.functions import LLMTools from daras_ai_v2.glossary import glossary_input, validate_glossary_document from daras_ai_v2.language_model import ( @@ -70,6 +73,7 @@ from daras_ai_v2.text_to_speech_settings_widgets import ( TextToSpeechProviders, text_to_speech_settings, + text_to_speech_provider_selector, ) from daras_ai_v2.vector_search import DocSearchRequest from recipes.DocSearch import ( @@ -310,26 +314,81 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - #### 📝 Prompt - High-level system instructions. + #### 📝 Instructions """, key="bot_script", height=300, ) + prompt_vars_widget( + "bot_script", + ) + + enum_selector( + LargeLanguageModels, + label="#### 🧠 Language Model", + key="selected_model", + use_selectbox=True, + ) document_uploader( """ -#### 📄 Documents (*optional*) -Upload documents or enter URLs to give your copilot a knowledge base. With each incoming user message, we'll search your documents via a vector DB query. -""" + #### 📄 Knowledge + Upload documents or enter URLs to give your copilot a knowledge base. With each incoming user message, we'll search your documents via a vector DB query. 
+ """ ) - prompt_vars_widget( - "bot_script", - "task_instructions", - "query_instructions", - "keyword_instructions", - ) + st.markdown("#### Capabilities") + if st.checkbox( + "##### 🗣️ Speak Responses", + value=bool(st.session_state.get("tts_provider")), + ): + text_to_speech_provider_selector(self) + st.write("---") + enable_video = st.checkbox( + "##### 🫦 Add Lipsync Video", + value=bool(st.session_state.get("input_face")), + ) + else: + st.session_state["tts_provider"] = None + enable_video = False + if enable_video: + st.file_uploader( + """ + ###### 👩‍🦰 Input Face + Upload a video/image that contains faces to use + *Recommended - mp4 / mov / png / jpg / gif* + """, + key="input_face", + ) + st.write("---") + else: + st.session_state["input_face"] = None + + if st.checkbox( + "##### 🔠 Translation", + value=bool(st.session_state.get("user_language")), + ): + google_translate_language_selector( + f"{field_desc(self.RequestModel, 'user_language')}", + key="user_language", + ) + st.write("---") + + if st.checkbox( + "##### 🩻 Photo & Document Intelligence", + value=bool( + st.session_state.get("document_model"), + ), + ): + doc_model_descriptions = azure_form_recognizer_models() + st.selectbox( + f"{field_desc(self.RequestModel, 'document_model')}", + key="document_model", + options=[None, *doc_model_descriptions], + format_func=lambda x: ( + f"{doc_model_descriptions[x]} ({x})" if x else "———" + ), + ) def validate_form_v2(self): input_glossary = st.session_state.get("input_glossary_document", "") @@ -343,9 +402,43 @@ def render_usage_guide(self): youtube_video("-j2su1r8pEg") def render_settings(self): - if st.session_state.get("documents") or st.session_state.get( - "__documents_files" - ): + tts_provider = st.session_state.get("tts_provider") + if tts_provider: + text_to_speech_settings(self, tts_provider) + + input_face = st.session_state.get("__enable_video") + if input_face: + lipsync_settings() + + if st.session_state.get("user_language"): + st.markdown("##### 🔠 Translation Settings") + enable_glossary = st.checkbox( + "📖 Add Glossary", + value=bool( + st.session_state.get("input_glossary_document") + or st.session_state.get("output_glossary_document") + ), + ) + if enable_glossary: + st.caption( + """ + Provide a glossary to customize translation and improve accuracy of domain-specific terms. + If not specified or invalid, no glossary will be used. Read about the expected format [here](https://docs.google.com/document/d/1TwzAvFmFYekloRKql2PXNPIyqCbsHRL8ZtnWkzAYrh8/edit?usp=sharing). 
+ """ + ) + glossary_input( + f"##### {field_title_desc(self.RequestModel, 'input_glossary_document')}", + key="input_glossary_document", + ) + glossary_input( + f"##### {field_title_desc(self.RequestModel, 'output_glossary_document')}", + key="output_glossary_document", + ) + else: + st.session_state["input_glossary_document"] = None + st.session_state["output_glossary_document"] = None + + if st.session_state.get("documents"): st.text_area( """ ##### 👩‍🏫 Document Search Results Instructions @@ -354,6 +447,9 @@ def render_settings(self): key="task_instructions", height=300, ) + prompt_vars_widget( + "task_instructions", + ) st.write("---") st.checkbox("🔗 Shorten Citation URLs", key="use_url_shortener") @@ -364,71 +460,8 @@ def render_settings(self): doc_search_settings(keyword_instructions_allowed=True) st.write("---") - language_model_settings(show_document_model=True) - - st.write("---") - google_translate_language_selector( - f"##### {field_title_desc(self.RequestModel, 'user_language')}", - key="user_language", - ) - enable_glossary = st.checkbox( - "📖 Customize with Glossary", - value=bool( - st.session_state.get("input_glossary_document") - or st.session_state.get("output_glossary_document") - ), - ) - st.markdown( - """ - Provide a glossary to customize translation and improve accuracy of domain-specific terms. - If not specified or invalid, no glossary will be used. Read about the expected format [here](https://docs.google.com/document/d/1TwzAvFmFYekloRKql2PXNPIyqCbsHRL8ZtnWkzAYrh8/edit?usp=sharing). - """ - ) - if enable_glossary: - glossary_input( - f"##### {field_title_desc(self.RequestModel, 'input_glossary_document')}", - key="input_glossary_document", - ) - glossary_input( - f"##### {field_title_desc(self.RequestModel, 'output_glossary_document')}", - key="output_glossary_document", - ) - else: - st.session_state["input_glossary_document"] = None - st.session_state["output_glossary_document"] = None - st.write("---") - - if not "__enable_audio" in st.session_state: - st.session_state["__enable_audio"] = bool( - st.session_state.get("tts_provider") - ) - enable_audio = st.checkbox("Enable Audio Output?", key="__enable_audio") - if not enable_audio: - st.write("---") - st.session_state["tts_provider"] = None - else: - text_to_speech_settings(page=self) - - st.write("---") - if not "__enable_video" in st.session_state: - st.session_state["__enable_video"] = bool( - st.session_state.get("input_face") - ) - enable_video = st.checkbox("Enable Video Output?", key="__enable_video") - if not enable_video: - st.session_state["input_face"] = None - else: - st.file_uploader( - """ - #### 👩‍🦰 Input Face - Upload a video/image that contains faces to use - *Recommended - mp4 / mov / png / jpg / gif* - """, - key="input_face", - ) - lipsync_settings() + language_model_settings(show_selector=False) - st.write("---") enum_multiselect( enum_cls=LLMTools, label="##### " + field_title_desc(self.RequestModel, "tools"), @@ -598,32 +631,42 @@ def get_raw_price(self, state: dict): "raw_tts_text", state.get("raw_output_text", []) ) tts_state = {"text_prompt": "".join(output_text_list)} - return super().get_raw_price(state) + TextToSpeechPage().get_raw_price( + total = super().get_raw_price(state) + TextToSpeechPage().get_raw_price( tts_state ) case _: - return super().get_raw_price(state) + total = super().get_raw_price(state) + + return total * state.get("num_outputs", 1) def additional_notes(self): tts_provider = st.session_state.get("tts_provider") match tts_provider: case 
TextToSpeechProviders.ELEVEN_LABS.name: - return f""" - - *Base cost = {super().get_raw_price(st.session_state)} credits* - - *Additional {TextToSpeechPage().additional_notes()}* - """ + return ( + f" \\\n" + f"*Base cost = {super().get_raw_price(st.session_state)} credits*" + f" | " + f"*Additional {TextToSpeechPage().get_cost_note()}*" + ) case _: return "" def run(self, state: dict) -> typing.Iterator[str | None]: request: VideoBotsPage.RequestModel = self.RequestModel.parse_obj(state) - if state.get("tts_provider") == TextToSpeechProviders.ELEVEN_LABS.name: - assert ( - self.is_current_user_paying() or self.is_current_user_admin() - ), """ + if state.get("tts_provider") == TextToSpeechProviders.ELEVEN_LABS.name and not ( + self.is_current_user_paying() or self.is_current_user_admin() + ): + raise UserError( + """ Please purchase Gooey.AI credits to use ElevenLabs voices here. """ + ) + + state.update( + dict(final_prompt=[], output_text=[], output_audio=[], output_video=[]) + ) user_input = request.input_prompt.strip() if not (user_input or request.input_images or request.input_documents): @@ -805,7 +848,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) max_allowed_tokens = min(max_allowed_tokens, request.max_tokens) if max_allowed_tokens < 0: - raise ValueError("Input Script is too long! Please reduce the script size.") + raise UserError("Input Script is too long! Please reduce the script size.") yield f"Summarizing with {model.value}..." if is_chat_model: @@ -889,9 +932,6 @@ def run(self, state: dict) -> typing.Iterator[str | None]: else: yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." - state["output_audio"] = [] - state["output_video"] = [] - if not request.tts_provider: return tts_state = dict(state) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index a8fd4e541..4781d4b40 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -33,7 +33,7 @@ TruncYear, Concat, ) -from django.db.models import Count, Avg +from django.db.models import Count, Avg, Q ID_COLUMNS = [ "conversation__fb_page_id", @@ -543,6 +543,7 @@ def calculate_stats_binned_by_time( df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] df.fillna(0, inplace=True) df = df.round(0).astype("int32", errors="ignore") + df = df.sort_values(by=["date"], ascending=True).reset_index() return df def plot_graphs(self, view, df): @@ -806,8 +807,6 @@ def get_tabular_data( neg_feedbacks: FeedbackQuerySet = Feedback.objects.filter( message__conversation__bot_integration=bi, rating=Feedback.Rating.RATING_THUMBS_DOWN, - created_at__date__gte=start_date, - created_at__date__lte=end_date, ) # type: ignore if start_date and end_date: neg_feedbacks = neg_feedbacks.filter( @@ -818,10 +817,9 @@ def get_tabular_data( df["Bot"] = bi.name elif details == "Answered Successfully": successful_messages: MessageQuerySet = Message.objects.filter( + Q(analysis_result__contains={"Answered": True}) + | Q(analysis_result__contains={"assistant": {"answer": "Found"}}), conversation__bot_integration=bi, - analysis_result__contains={"Answered": True}, - created_at__date__gte=start_date, - created_at__date__lte=end_date, ) # type: ignore if start_date and end_date: successful_messages = successful_messages.filter( @@ -832,10 +830,9 @@ def get_tabular_data( df["Bot"] = bi.name elif details == "Answered Unsuccessfully": unsuccessful_messages: MessageQuerySet = Message.objects.filter( + Q(analysis_result__contains={"Answered": False}) + | Q(analysis_result__contains={"assistant": 
{"answer": "Missing"}}), conversation__bot_integration=bi, - analysis_result__contains={"Answered": False}, - created_at__date__gte=start_date, - created_at__date__lte=end_date, ) # type: ignore if start_date and end_date: unsuccessful_messages = unsuccessful_messages.filter( diff --git a/recipes/asr.py b/recipes/asr.py index 7dcea4249..23d6b51f4 100644 --- a/recipes/asr.py +++ b/recipes/asr.py @@ -48,6 +48,10 @@ class ResponseModel(BaseModel): raw_output_text: list[str] | None output_text: list[str | AsrOutputJson] + @classmethod + def get_example_preferred_fields(cls, state: dict) -> list[str]: + return ["selected_model", "language", "google_translate_target"] + def preview_image(self, state: dict) -> str | None: return DEFAULT_ASR_META_IMG @@ -139,6 +143,7 @@ def run(self, state: dict): output_format=request.output_format, ), request.documents, + max_workers=4, ) # Run Translation diff --git a/routers/billing.py b/routers/billing.py index 5f2f2cb06..39dfa61fb 100644 --- a/routers/billing.py +++ b/routers/billing.py @@ -114,12 +114,14 @@ def account(request: Request): is_admin = request.user.email in settings.ADMIN_EMAILS context = { + "title": "Account • Gooey.AI", "request": request, "settings": settings, "available_subscriptions": available_subscriptions, "user_credits": request.user.balance, "subscription": get_user_subscription(request.user), "is_admin": is_admin, + "canonical_url": account_url, } return templates.TemplateResponse("account.html", context) diff --git a/routers/facebook_api.py b/routers/facebook_api.py index 41679b72b..93109d62b 100644 --- a/routers/facebook_api.py +++ b/routers/facebook_api.py @@ -136,7 +136,9 @@ def fb_connect_redirect(request: Request): ) user_access_token = _get_access_token_from_code(code, fb_connect_redirect_url) - db.get_user_doc_ref(request.user.uid).update({"fb_access_token": user_access_token}) + db.get_user_doc_ref(request.user.uid).set( + {"fb_access_token": user_access_token}, merge=True + ) fb_pages = get_currently_connected_fb_pages(user_access_token) if not fb_pages: diff --git a/routers/root.py b/routers/root.py index e5f27edc6..27cb62934 100644 --- a/routers/root.py +++ b/routers/root.py @@ -1,6 +1,5 @@ import datetime import os.path -import subprocess import tempfile import typing from time import time @@ -27,13 +26,11 @@ from daras_ai_v2.all_pages import all_api_pages, normalize_slug, page_slug_map from daras_ai_v2.api_examples_widget import api_example_generator from daras_ai_v2.asr import FFMPEG_WAV_ARGS, check_wav_audio_format -from daras_ai_v2.base import ( - RedirectException, - get_example_request_body, -) +from daras_ai_v2.base import RedirectException from daras_ai_v2.bots import request_json from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_scripts from daras_ai_v2.db import FIREBASE_SESSION_COOKIE +from daras_ai_v2.exceptions import ffmpeg, UserError from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags from daras_ai_v2.meta_preview_url import meta_preview_url @@ -186,22 +183,16 @@ def file_upload(request: Request, form_data: FormData = Depends(request_form_fil ) as infile: infile.write(data) infile.flush() - if not check_wav_audio_format(infile.name): - with tempfile.NamedTemporaryFile(suffix=".wav") as outfile: - args = [ - "ffmpeg", - "-y", - "-i", - infile.name, - *FFMPEG_WAV_ARGS, - outfile.name, - ] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) - - filename += ".wav" - content_type = 
"audio/wav" - data = outfile.read() + try: + if not check_wav_audio_format(infile.name): + with tempfile.NamedTemporaryFile(suffix=".wav") as outfile: + ffmpeg("-i", infile.name, *FFMPEG_WAV_ARGS, outfile.name) + + filename += ".wav" + content_type = "audio/wav" + data = outfile.read() + except UserError as e: + return Response(content=str(e), status_code=400) if content_type.startswith("image/"): with Image(blob=data) as img: @@ -305,9 +296,7 @@ def _api_docs_page(request): page = workflow.page_cls(request=request) state = page.get_root_published_run().saved_run.to_dict() - request_body = get_example_request_body( - page.RequestModel, state, include_all=include_all - ) + request_body = page.get_example_request_body(state, include_all=include_all) response_body = page.get_example_response_body( state, as_async=as_async, include_all=include_all ) diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index 6ea45c7d3..d1df62a58 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -31,7 +31,7 @@ def run(): add_model(model_ids[m], m.name) except KeyError: pass - add_model("gooey-gpu/wav2lip_gan.pth", "wav2lip") + add_model("wav2lip_gan.pth", "wav2lip") def add_model(model_id, model_name): diff --git a/templates/account.html b/templates/account.html index 2755facc1..6c14bc501 100644 --- a/templates/account.html +++ b/templates/account.html @@ -1,5 +1,11 @@ {% extends 'base.html' %} +{% block head %} + + + +{% endblock %} + {% block content %} -{% endblock content %} \ No newline at end of file +{% endblock content %} diff --git a/templates/base.html b/templates/base.html index e2c180de7..5af87990a 100644 --- a/templates/base.html +++ b/templates/base.html @@ -7,7 +7,8 @@ - {{ title }} + {% block title %}{{ title }}{% endblock title %} + {% block head %}{% endblock head %} diff --git a/templates/login_options.html b/templates/login_options.html index c287f13f0..4825e7451 100644 --- a/templates/login_options.html +++ b/templates/login_options.html @@ -1,7 +1,9 @@ {% extends 'base.html' %} +{% block title %}Login to Gooey.AI{% endblock title %} + {% block head %} - Login - Gooey.AI + {% endblock %} diff --git a/templates/low_balance_email.html b/templates/low_balance_email.html new file mode 100644 index 000000000..ca4a65e61 --- /dev/null +++ b/templates/low_balance_email.html @@ -0,0 +1,17 @@ +

+ Hey {{ user.display_name }}! +

+ +

+ This is a friendly reminder that your Gooey.AI balance is now just {{ user.balance }} credits. + Your account has consumed {{ total_credits_consumed }} credits in the last {{ settings.LOW_BALANCE_EMAIL_DAYS }} days. +

+ To buy more credits, please visit https://gooey.ai/account. +

+ As always, email us at sales@gooey.ai if you have any questions. +

+ Thanks again for your business, +
+ Sean and the Gooey.AI team +

+{{ "{{{ pm:unsubscribe }}}" }} \ No newline at end of file diff --git a/templates/report_email.html b/templates/report_email.html index 098aca9bb..9e3afbf6a 100644 --- a/templates/report_email.html +++ b/templates/report_email.html @@ -19,6 +19,8 @@
Reason for Report: {{ reason_for_report }}

+ In the meantime, you can go through our docs or our video tutorials. +

{% if run_uid != user.uid %} Creator User ID: {{ run_uid }}
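The test changes below swap the module-level `get_example_request_body(page_cls.RequestModel, state)` helper for a `page_cls.get_example_request_body(state)` classmethod, pairing it with the `get_example_preferred_fields` hooks added to the recipe pages above. The `daras_ai_v2/base.py` side of that refactor isn't included in this diff; a rough sketch of how such a classmethod could combine required and preferred fields (the class name and filtering logic here are assumptions):

```python
import typing

from pydantic import BaseModel


class BasePageSketch:
    """Hypothetical stand-in for BasePage; the real base.py change isn't shown."""

    RequestModel: typing.Type[BaseModel]

    @classmethod
    def get_example_preferred_fields(cls, state: dict) -> list[str]:
        # recipe pages override this to pin their most illustrative fields
        return []

    @classmethod
    def get_example_request_body(cls, state: dict, include_all: bool = False) -> dict:
        preferred = set(cls.get_example_preferred_fields(state))
        return {
            name: state.get(name)
            # pydantic v1 exposes the model's fields via __fields__
            for name, field in cls.RequestModel.__fields__.items()
            # always keep required fields; keep optional ones only if preferred,
            # unless the caller asked for everything
            if include_all or field.required or name in preferred
        }
```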
diff --git a/tests/test_apis.py b/tests/test_apis.py index 96c27c7ef..ddd6e0e48 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -6,10 +6,7 @@ from auth.auth_backend import force_authentication from bots.models import SavedRun, Workflow from daras_ai_v2.all_pages import all_test_pages -from daras_ai_v2.base import ( - BasePage, - get_example_request_body, -) +from daras_ai_v2.base import BasePage from server import app MAX_WORKERS = 20 @@ -27,7 +24,7 @@ def _test_api_sync(page_cls: typing.Type[BasePage]): state = page_cls.recipe_doc_sr().state r = client.post( f"/v2/{page_cls.slug_versions[0]}/", - json=get_example_request_body(page_cls.RequestModel, state), + json=page_cls.get_example_request_body(state), headers={"Authorization": f"Token None"}, allow_redirects=False, ) @@ -45,7 +42,7 @@ def _test_api_async(page_cls: typing.Type[BasePage]): r = client.post( f"/v3/{page_cls.slug_versions[0]}/async/", - json=get_example_request_body(page_cls.RequestModel, state), + json=page_cls.get_example_request_body(state), headers={"Authorization": f"Token None"}, allow_redirects=False, ) @@ -81,7 +78,7 @@ def _test_apis_examples(sr: SavedRun): page_cls = Workflow(sr.workflow).page_cls r = client.post( f"/v2/{page_cls.slug_versions[0]}/?example_id={sr.example_id}", - json=get_example_request_body(page_cls.RequestModel, state), + json=page_cls.get_example_request_body(state), headers={"Authorization": f"Token None"}, allow_redirects=False, ) diff --git a/tests/test_low_balance_email_check.py b/tests/test_low_balance_email_check.py new file mode 100644 index 000000000..66203a19b --- /dev/null +++ b/tests/test_low_balance_email_check.py @@ -0,0 +1,127 @@ +from django.utils import timezone + +from app_users.models import AppUserTransaction +from bots.models import AppUser +from celeryapp.tasks import run_low_balance_email_check +from daras_ai_v2 import settings +from daras_ai_v2.send_email import pytest_outbox + + +def test_dont_send_email_if_feature_is_disabled(transactional_db): + user = AppUser.objects.create( + uid="test_user", is_paying=True, balance=0, is_anonymous=False + ) + settings.LOW_BALANCE_EMAIL_ENABLED = False + run_low_balance_email_check(user.uid) + assert not pytest_outbox + + +def test_dont_send_email_if_user_is_not_paying(transactional_db): + user = AppUser.objects.create( + uid="test_user", is_paying=False, balance=0, is_anonymous=False + ) + settings.LOW_BALANCE_EMAIL_ENABLED = True + run_low_balance_email_check(user.uid) + assert not pytest_outbox + + +def test_dont_send_email_if_user_has_enough_balance(transactional_db): + user = AppUser.objects.create( + uid="test_user", is_paying=True, balance=500, is_anonymous=False + ) + settings.LOW_BALANCE_EMAIL_CREDITS = 100 + settings.LOW_BALANCE_EMAIL_ENABLED = True + run_low_balance_email_check(user.uid) + assert not pytest_outbox + + +def test_dont_send_email_if_user_has_been_emailed_recently(transactional_db): + user = AppUser.objects.create( + uid="test_user", + is_paying=True, + balance=66, + is_anonymous=False, + low_balance_email_sent_at=timezone.now(), + ) + settings.LOW_BALANCE_EMAIL_ENABLED = True + settings.LOW_BALANCE_EMAIL_DAYS = 1 + settings.LOW_BALANCE_EMAIL_CREDITS = 100 + run_low_balance_email_check(user.uid) + assert not pytest_outbox + + +def test_send_email_if_user_has_been_emailed_recently_but_made_a_purchase( + transactional_db, +): + user = AppUser.objects.create( + uid="test_user", + is_paying=True, + balance=22, + is_anonymous=False, + low_balance_email_sent_at=timezone.now(), + ) + 
AppUserTransaction.objects.create( + invoice_id="test_invoice_1", + user=user, + amount=100, + created_at=timezone.now(), + end_balance=100, + ) + AppUserTransaction.objects.create( + invoice_id="test_invoice_2", + user=user, + amount=-78, + created_at=timezone.now(), + end_balance=22, + ) + settings.LOW_BALANCE_EMAIL_ENABLED = True + settings.LOW_BALANCE_EMAIL_DAYS = 1 + settings.LOW_BALANCE_EMAIL_CREDITS = 100 + run_low_balance_email_check(user.uid) + + assert len(pytest_outbox) == 1 + assert " 22" in pytest_outbox[0]["html_body"] + assert " 78" in pytest_outbox[0]["html_body"] + + pytest_outbox.clear() + run_low_balance_email_check(user.uid) + assert not pytest_outbox + + +def test_send_email(transactional_db): + user = AppUser.objects.create( + uid="test_user", + is_paying=True, + balance=66, + is_anonymous=False, + ) + AppUserTransaction.objects.create( + invoice_id="test_invoice_1", + user=user, + amount=-100, + created_at=timezone.now() - timezone.timedelta(days=2), + end_balance=150 + 66, + ) + AppUserTransaction.objects.create( + invoice_id="test_invoice_2", + user=user, + amount=-150, + created_at=timezone.now(), + end_balance=66, + ) + + settings.LOW_BALANCE_EMAIL_ENABLED = True + settings.LOW_BALANCE_EMAIL_DAYS = 1 + settings.LOW_BALANCE_EMAIL_CREDITS = 100 + + run_low_balance_email_check(user.uid) + assert len(pytest_outbox) == 1 + body = pytest_outbox[0]["html_body"] + assert " 66" in body + assert " 150" in body + assert " pm:unsubscribe" in body + assert " 100" not in body + + pytest_outbox.clear() + run_low_balance_email_check(user.uid) + assert not pytest_outbox diff --git a/tests/test_public_endpoints.py b/tests/test_public_endpoints.py index be18647d8..31d4abea2 100644 --- a/tests/test_public_endpoints.py +++ b/tests/test_public_endpoints.py @@ -2,7 +2,7 @@ from starlette.routing import Route from starlette.testclient import TestClient -from bots.models import SavedRun +from bots.models import PublishedRun, Workflow from daras_ai_v2.all_pages import all_api_pages from daras_ai_v2.tabs_widget import MenuTabs from routers import facebook_api @@ -31,39 +31,36 @@ @pytest.mark.django_db -@pytest.mark.parametrize("path", route_paths) -def test_all_get(path): +def test_all_get(threadpool_subtest): + for path in route_paths: + threadpool_subtest(_test_get_path, path) + + +def _test_get_path(path): r = client.get(path, allow_redirects=False) assert r.ok -page_slugs = [slug for page_cls in all_api_pages for slug in page_cls.slug_versions] -tabs = list(MenuTabs.paths.values()) +@pytest.mark.django_db +def test_all_slugs(threadpool_subtest): + for page_cls in all_api_pages: + for slug in page_cls.slug_versions: + for tab in MenuTabs.paths.values(): + url = f"/{slug}/{tab}" + threadpool_subtest(_test_post_path, url) @pytest.mark.django_db -@pytest.mark.parametrize("slug", page_slugs) -@pytest.mark.parametrize("tab", tabs) -def test_page_slugs(slug, tab): - r = client.post( - f"/{slug}/{tab}", - json={}, - allow_redirects=True, - ) - assert r.status_code == 200 +def test_all_examples(threadpool_subtest): + qs = PublishedRun.objects.exclude( + is_approved_example=False, published_run_id="" + ).order_by("workflow") + for pr in qs: + slug = Workflow(pr.workflow).page_cls.slug_versions[-1] + url = f"/{slug}?example_id={pr.published_run_id}" + threadpool_subtest(_test_post_path, url) -@pytest.mark.django_db -def test_example_slugs(subtests): - for page_cls in all_api_pages: - for tab in tabs: - for example_id in SavedRun.objects.filter( - workflow=page_cls.workflow, - hidden=False, - 
example_id__isnull=False, - ).values_list("example_id", flat=True): - slug = page_cls.slug_versions[0] - url = f"/{slug}/{tab}?example_id={example_id}" - with subtests.test(msg=url): - r = client.post(url, json={}, allow_redirects=True) - assert r.status_code == 200 +def _test_post_path(url): + r = client.post(url, json={}, allow_redirects=True) + assert r.status_code == 200 diff --git a/tests/test_translation.py b/tests/test_translation.py index d06b56d63..3ecd6fb24 100644 --- a/tests/test_translation.py +++ b/tests/test_translation.py @@ -1,3 +1,4 @@ +from conftest import flaky from daras_ai_v2.asr import run_google_translate @@ -14,7 +15,7 @@ # hindi ( "ान का नर्सरी खेत में रोकने के लिए कितने दिन में तैयार हो जाता है", - "in how many days does the seed nursery become ready to be planted in the field?", + "how many days does it take for the seed nursery to be ready to be planted in the field?", ), # telugu ( @@ -48,6 +49,7 @@ def test_run_google_translate(threadpool_subtest): threadpool_subtest(_test_run_google_translate_one, text, expected) +@flaky def _test_run_google_translate_one( text: str, expected: str, glossary_url=None, target_lang="en" ):
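`_test_run_google_translate_one` above is now wrapped with `@flaky` from `conftest`, whose definition isn't part of this diff. A minimal sketch of such a decorator, assuming it simply re-runs a nondeterministic test a fixed number of times before letting the last failure propagate (the retry count and caught exception type are guesses):

```python
import typing
from functools import wraps


def flaky(fn: typing.Callable) -> typing.Callable:
    max_tries = 3  # assumed; the real conftest may use a different count

    @wraps(fn)
    def wrapper(*args, **kwargs):
        for attempt in range(max_tries):
            try:
                return fn(*args, **kwargs)
            except AssertionError:
                # translations vary between runs; retry before failing for real
                if attempt == max_tries - 1:
                    raise

    return wrapper
```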