diff --git a/bots/tasks.py b/bots/tasks.py index a6c3cebe8..4e00b4e78 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -1,4 +1,5 @@ import json +from json import JSONDecodeError from celery import shared_task from django.db.models import QuerySet @@ -65,8 +66,15 @@ def msg_analysis(msg_id: int): raise RuntimeError(sr.error_msg) # save the result as json + output_text = flatten(sr.state["output_text"].values())[0] + try: + analysis_result = json.loads(output_text) + except JSONDecodeError: + analysis_result = { + "error": "Failed to parse the analysis result. Please check your script.", + } Message.objects.filter(id=msg_id).update( - analysis_result=json.loads(flatten(sr.state["output_text"].values())[0]), + analysis_result=analysis_result, ) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 8d13a6f7d..1805ba245 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -1,8 +1,11 @@ +from fastapi import HTTPException +import html import traceback import typing from time import time from types import SimpleNamespace +import requests import sentry_sdk import gooey_ui as st @@ -11,7 +14,8 @@ from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings -from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage +from daras_ai_v2.base import StateKeys, BasePage +from daras_ai_v2.exceptions import UserError from daras_ai_v2.send_email import send_email_via_postmark from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push @@ -95,8 +99,25 @@ def save(done=False): # render errors nicely except Exception as e: run_time += time() - start_time - traceback.print_exc() - sentry_sdk.capture_exception(e) + + if isinstance(e, HTTPException) and e.status_code == 402: + error_msg = page.generate_credit_error_message( + example_id=query_params.get("example_id"), + run_id=run_id, + uid=uid, + ) + try: + raise UserError(error_msg) from e + except UserError as e: + 
def err_msg_for_exc(e: Exception) -> str:
    """Return a short, human-readable error message for *e*, suitable for display to users.

    Handles, in order:
      - requests.HTTPError: surfaces the GPU worker's structured error body when present,
        otherwise an escaped preview of the response text
      - fastapi.HTTPException: status code + detail
      - UserError: the user-safe message as-is
      - anything else: "ExceptionType: str(e)"
    """
    if isinstance(e, requests.HTTPError):
        response: requests.Response = e.response
        try:
            err_body = response.json()
        except requests.JSONDecodeError:
            err_str = response.text
        else:
            format_exc = err_body.get("format_exc")
            if format_exc:
                # log the remote traceback for debugging; not shown to the user
                print("⚡️ " + format_exc)
            err_type = err_body.get("type")
            err_str = err_body.get("str")
            if err_type and err_str:
                return f"(GPU) {err_type}: {err_str}"
            err_str = str(err_body)
        # cap the preview and escape it, since it may be rendered as HTML
        return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}"
    elif isinstance(e, HTTPException):
        # bug fix: removed the stray trailing ")" that leaked into the message
        return f"(HTTP {e.status_code}) {e.detail}"
    elif isinstance(e, UserError):
        return e.message
    else:
        return f"{type(e).__name__}: {e}"
gooey_ui as st from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import ( + raise_for_status, + UserError, + ffmpeg, + call_cmd, + ffprobe, +) from daras_ai_v2.functional import map_parallel from daras_ai_v2.gdrive_downloader import ( is_gdrive_url, @@ -27,19 +33,50 @@ SHORT_FILE_CUTOFF = 5 * 1024 * 1024 # 1 MB - TRANSLITERATION_SUPPORTED = {"ar", "bn", " gu", "hi", "ja", "kn", "ru", "ta", "te"} # below CHIRP list was found experimentally since the supported languages list by google is actually wrong: -CHIRP_SUPPORTED = {"af-ZA", "sq-AL", "am-ET", "ar-EG", "hy-AM", "as-IN", "ast-ES", "az-AZ", "eu-ES", "be-BY", "bs-BA", "bg-BG", "my-MM", "ca-ES", "ceb-PH", "ckb-IQ", "zh-Hans-CN", "yue-Hant-HK", "hr-HR", "cs-CZ", "da-DK", "nl-NL", "en-AU", "en-IN", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-CA", "fr-FR", "gl-ES", "ka-GE", "de-DE", "el-GR", "gu-IN", "ha-NG", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID", "it-IT", "ja-JP", "jv-ID", "kea-CV", "kam-KE", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "ky-KG", "lo-LA", "lv-LV", "ln-CD", "lt-LT", "luo-KE", "lb-LU", "mk-MK", "ms-MY", "ml-IN", "mt-MT", "mi-NZ", "mr-IN", "mn-MN", "ne-NP", "ny-MW", "oc-FR", "ps-AF", "fa-IR", "pl-PL", "pt-BR", "pa-Guru-IN", "ro-RO", "ru-RU", "nso-ZA", "sr-RS", "sn-ZW", "sd-IN", "si-LK", "sk-SK", "sl-SI", "so-SO", "es-ES", "es-US", "su-ID", "sw", "sv-SE", "tg-TJ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", "uz-UZ", "vi-VN", "cy-GB", "wo-SN", "yo-NG", "zu-ZA"} # fmt: skip - -WHISPER_SUPPORTED = {"af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi", "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy"} # fmt: skip +CHIRP_SUPPORTED 
= {"af-ZA", "sq-AL", "am-ET", "ar-EG", "hy-AM", "as-IN", "ast-ES", "az-AZ", "eu-ES", "be-BY", "bs-BA", + "bg-BG", "my-MM", "ca-ES", "ceb-PH", "ckb-IQ", "zh-Hans-CN", "yue-Hant-HK", "hr-HR", "cs-CZ", + "da-DK", "nl-NL", "en-AU", "en-IN", "en-GB", "en-US", "et-EE", "fil-PH", "fi-FI", "fr-CA", "fr-FR", + "gl-ES", "ka-GE", "de-DE", "el-GR", "gu-IN", "ha-NG", "iw-IL", "hi-IN", "hu-HU", "is-IS", "id-ID", + "it-IT", "ja-JP", "jv-ID", "kea-CV", "kam-KE", "kn-IN", "kk-KZ", "km-KH", "ko-KR", "ky-KG", "lo-LA", + "lv-LV", "ln-CD", "lt-LT", "luo-KE", "lb-LU", "mk-MK", "ms-MY", "ml-IN", "mt-MT", "mi-NZ", "mr-IN", + "mn-MN", "ne-NP", "ny-MW", "oc-FR", "ps-AF", "fa-IR", "pl-PL", "pt-BR", "pa-Guru-IN", "ro-RO", + "ru-RU", "nso-ZA", "sr-RS", "sn-ZW", "sd-IN", "si-LK", "sk-SK", "sl-SI", "so-SO", "es-ES", "es-US", + "su-ID", "sw", "sv-SE", "tg-TJ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-PK", "uz-UZ", + "vi-VN", "cy-GB", "wo-SN", "yo-NG", "zu-ZA"} # fmt: skip + +WHISPER_SUPPORTED = {"af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da", "nl", "en", "et", "fi", + "fr", "gl", "de", "el", "he", "hi", "hu", "is", "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", + "mk", "ms", "mr", "mi", "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw", + "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy"} # fmt: skip # See page 14 of https://scontent-sea1-1.xx.fbcdn.net/v/t39.2365-6/369747868_602316515432698_2401716319310287708_n.pdf?_nc_cat=106&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=_5cpNOcftdYAX8rCrVo&_nc_ht=scontent-sea1-1.xx&oh=00_AfDVkx7XubifELxmB_Un-yEYMJavBHFzPnvTbTlalbd_1Q&oe=65141B39 # For now, below are listed the languages that support ASR. Note that Seamless only accepts ISO 639-3 codes. 
-SEAMLESS_SUPPORTED = {"afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", "ces", "ckb", "cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khk", "khm", "kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", "nno", "nob", "npi", "nya", "oci", "ory", "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa", "srp", "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", "xho", "yor", "yue", "zlm", "zul"} # fmt: skip - -AZURE_SUPPORTED = {"af-ZA", "am-ET", "ar-AE", "ar-BH", "ar-DZ", "ar-EG", "ar-IL", "ar-IQ", "ar-JO", "ar-KW", "ar-LB", "ar-LY", "ar-MA", "ar-OM", "ar-PS", "ar-QA", "ar-SA", "ar-SY", "ar-TN", "ar-YE", "az-AZ", "bg-BG", "bn-IN", "bs-BA", "ca-ES", "cs-CZ", "cy-GB", "da-DK", "de-AT", "de-CH", "de-DE", "el-GR", "en-AU", "en-CA", "en-GB", "en-GH", "en-HK", "en-IE", "en-IN", "en-KE", "en-NG", "en-NZ", "en-PH", "en-SG", "en-TZ", "en-US", "en-ZA", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-CU", "es-DO", "es-EC", "es-ES", "es-GQ", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PE", "es-PR", "es-PY", "es-SV", "es-US", "es-UY", "es-VE", "et-EE", "eu-ES", "fa-IR", "fi-FI", "fil-PH", "fr-BE", "fr-CA", "fr-CH", "fr-FR", "ga-IE", "gl-ES", "gu-IN", "he-IL", "hi-IN", "hr-HR", "hu-HU", "hy-AM", "id-ID", "is-IS", "it-CH", "it-IT", "ja-JP", "jv-ID", "ka-GE", "kk-KZ", "km-KH", "kn-IN", "ko-KR", "lo-LA", "lt-LT", "lv-LV", "mk-MK", "ml-IN", "mn-MN", "mr-IN", "ms-MY", "mt-MT", "my-MM", "nb-NO", "ne-NP", "nl-BE", "nl-NL", "pa-IN", "pl-PL", "ps-AF", "pt-BR", "pt-PT", "ro-RO", "ru-RU", "si-LK", "sk-SK", "sl-SI", "so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", 
"zh-CN", "zh-CN-shandong", "zh-CN-sichuan", "zh-HK", "zh-TW", "zu-ZA"} # fmt: skip +SEAMLESS_SUPPORTED = {"afr", "amh", "arb", "ary", "arz", "asm", "ast", "azj", "bel", "ben", "bos", "bul", "cat", "ceb", + "ces", "ckb", "cmn", "cym", "dan", "deu", "ell", "eng", "est", "eus", "fin", "fra", "gaz", "gle", + "glg", "guj", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn", "kam", + "kan", "kat", "kaz", "kea", "khk", "khm", "kir", "kor", "lao", "lit", "ltz", "lug", "luo", "lvs", + "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld", "nno", "nob", "npi", "nya", "oci", "ory", + "pan", "pbt", "pes", "pol", "por", "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa", "srp", + "swe", "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", "xho", "yor", + "yue", "zlm", "zul"} # fmt: skip + +AZURE_SUPPORTED = {"af-ZA", "am-ET", "ar-AE", "ar-BH", "ar-DZ", "ar-EG", "ar-IL", "ar-IQ", "ar-JO", "ar-KW", "ar-LB", + "ar-LY", "ar-MA", "ar-OM", "ar-PS", "ar-QA", "ar-SA", "ar-SY", "ar-TN", "ar-YE", "az-AZ", "bg-BG", + "bn-IN", "bs-BA", "ca-ES", "cs-CZ", "cy-GB", "da-DK", "de-AT", "de-CH", "de-DE", "el-GR", "en-AU", + "en-CA", "en-GB", "en-GH", "en-HK", "en-IE", "en-IN", "en-KE", "en-NG", "en-NZ", "en-PH", "en-SG", + "en-TZ", "en-US", "en-ZA", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-CU", "es-DO", "es-EC", + "es-ES", "es-GQ", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PE", "es-PR", "es-PY", "es-SV", + "es-US", "es-UY", "es-VE", "et-EE", "eu-ES", "fa-IR", "fi-FI", "fil-PH", "fr-BE", "fr-CA", "fr-CH", + "fr-FR", "ga-IE", "gl-ES", "gu-IN", "he-IL", "hi-IN", "hr-HR", "hu-HU", "hy-AM", "id-ID", "is-IS", + "it-CH", "it-IT", "ja-JP", "jv-ID", "ka-GE", "kk-KZ", "km-KH", "kn-IN", "ko-KR", "lo-LA", "lt-LT", + "lv-LV", "mk-MK", "ml-IN", "mn-MN", "mr-IN", "ms-MY", "mt-MT", "my-MM", "nb-NO", "ne-NP", "nl-BE", + "nl-NL", "pa-IN", "pl-PL", "ps-AF", "pt-BR", "pt-PT", "ro-RO", "ru-RU", "si-LK", "sk-SK", "sl-SI", + "so-SO", "sq-AL", "sr-RS", 
"sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", + "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", "zh-CN-sichuan", "zh-HK", + "zh-TW", "zu-ZA"} # fmt: skip MAX_POLLS = 100 # https://deepgram.com/product/languages for the "general" model: @@ -575,7 +612,7 @@ def run_asr( assert data.get("chunks"), f"{selected_model.value} can't generate VTT" return generate_vtt(data["chunks"]) case _: - raise ValueError(f"Invalid output format: {output_format}") + raise UserError(f"Invalid output format: {output_format}") def _get_or_create_recognizer( @@ -683,7 +720,7 @@ def download_youtube_to_wav(youtube_url: str) -> tuple[str, int]: infile = os.path.join(tmpdir, "infile") outfile = os.path.join(tmpdir, "outfile.wav") # run yt-dlp to download audio - args = [ + call_cmd( "yt-dlp", "--no-playlist", "--format", @@ -691,13 +728,9 @@ def download_youtube_to_wav(youtube_url: str) -> tuple[str, int]: "--output", infile, youtube_url, - ] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) + ) # convert audio to single channel wav - args = ["ffmpeg", "-y", "-i", infile, *FFMPEG_WAV_ARGS, outfile] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) + ffmpeg("-i", infile, *FFMPEG_WAV_ARGS, outfile) # read wav file into memory with open(outfile, "rb") as f: wavdata = f.read() @@ -728,43 +761,12 @@ def audio_bytes_to_wav(audio_bytes: bytes) -> tuple[bytes | None, int]: with tempfile.NamedTemporaryFile(suffix=".wav") as outfile: # convert audio to single channel wav - args = [ - "ffmpeg", - "-y", - "-i", - infile.name, - *FFMPEG_WAV_ARGS, - outfile.name, - ] - print("\t$ " + " ".join(args)) - try: - subprocess.check_output(args, stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - ffmpeg_output_error = ValueError(e.output, e) - raise ValueError( - "Invalid audio file. Could not convert audio to wav format. 
Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')" - ) from ffmpeg_output_error + ffmpeg("-i", infile.name, *FFMPEG_WAV_ARGS, outfile.name) return outfile.read(), os.path.getsize(outfile.name) def check_wav_audio_format(filename: str) -> bool: - args = [ - "ffprobe", - "-v", - "quiet", - "-print_format", - "json", - "-show_streams", - filename, - ] - print("\t$ " + " ".join(args)) - try: - data = json.loads(subprocess.check_output(args, stderr=subprocess.STDOUT)) - except subprocess.CalledProcessError as e: - ffmpeg_output_error = ValueError(e.output, e) - raise ValueError( - "Invalid audio file. Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')" - ) from ffmpeg_output_error + data = ffprobe(filename) return ( len(data["streams"]) == 1 and data["streams"][0]["codec_name"] == "pcm_s16le" diff --git a/daras_ai_v2/azure_image_moderation.py b/daras_ai_v2/azure_image_moderation.py index 30da7a561..156c8fda6 100644 --- a/daras_ai_v2/azure_image_moderation.py +++ b/daras_ai_v2/azure_image_moderation.py @@ -1,7 +1,5 @@ -from typing import Any - -from furl import furl import requests +from furl import furl from daras_ai_v2 import settings from daras_ai_v2.exceptions import raise_for_status @@ -11,7 +9,7 @@ def get_auth_headers(): return {"Ocp-Apim-Subscription-Key": settings.AZURE_IMAGE_MODERATION_KEY} -def run_moderator(image_url: str, cache: bool) -> dict[str, Any]: +def is_image_nsfw(image_url: str, cache: bool = False) -> bool: url = str( furl(settings.AZURE_IMAGE_MODERATION_ENDPOINT) / "contentmoderator/moderate/v1.0/ProcessImage/Evaluate" @@ -22,10 +20,7 @@ def run_moderator(image_url: str, cache: bool) -> dict[str, Any]: headers=get_auth_headers(), json={"DataRepresentation": "URL", "Value": image_url}, ) + if r.status_code == 400 and b"Image Size Error" in r.content: + return False raise_for_status(r) - return r.json() - - -def 
class UserError(Exception):
    """An error whose message is safe to show directly to end users.

    sentry_level controls how the error is reported to Sentry (default "info",
    since user-caused errors are usually not actionable bugs for us).
    """

    def __init__(self, message: str, sentry_level: str = "info"):
        self.message = message
        self.sentry_level = sentry_level
        super().__init__(message)


FFMPEG_ERR_MSG = (
    "Unsupported File Format\n\n"
    "We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. "
    "You can find a list of supported formats at [FFmpeg Formats](https://ffmpeg.org/general.html#File-Formats)."
)


def ffmpeg(*args) -> str:
    """Run ffmpeg (overwrite output, quiet banner) with *args*; raises UserError on failure."""
    return call_cmd("ffmpeg", "-hide_banner", "-y", *args, err_msg=FFMPEG_ERR_MSG)


def ffprobe(filename: str) -> dict:
    """Return ffprobe's stream metadata for *filename* as a parsed JSON dict.

    Raises UserError (with the friendly FFMPEG_ERR_MSG) if ffprobe fails.
    """
    text = call_cmd(
        "ffprobe",
        "-v",
        "quiet",
        "-print_format",
        "json",
        "-show_streams",
        filename,
        err_msg=FFMPEG_ERR_MSG,
    )
    return json.loads(text)


def call_cmd(*args, err_msg: str = "") -> str:
    """Run a subprocess and return its combined stdout+stderr as text.

    On failure, raises UserError(err_msg) chained from a SubprocessError carrying
    the command's full output, so the output lands in logs/Sentry without being
    shown to the user. If err_msg is empty, a generic "<Cmd> Error" is used.
    """
    logger.info("$ " + " ".join(map(str, args)))
    try:
        return subprocess.check_output(args, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # idiom fix: the original raised-and-caught an intermediate exception just
        # to attach the output; chain it directly instead (same UserError message,
        # same SubprocessError(output) as __cause__, without re-binding `e` twice)
        raise UserError(
            err_msg or f"{str(args[0]).capitalize()} Error"
        ) from subprocess.SubprocessError(e.output)
audio=audio, + pads=pads, + batch_size=256, + # "out_height": 480, + # "smooth": True, + # "fps": 25, + ), + content_type="video/mp4", + filename=f"gooey.ai lipsync.mp4", + )[0] + except ValueError as e: + msg = "\n\n".join(e.args).lower() + if "unsupported" in msg: + raise UserError(msg) from e diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 3f6859164..8ae962cc4 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -1,10 +1,15 @@ +from contextlib import contextmanager + from app_users.models import AppUser +from daras_ai_v2 import settings from daras_ai_v2.azure_image_moderation import is_image_nsfw +from daras_ai_v2.exceptions import UserError from daras_ai_v2.functional import flatten -from daras_ai_v2 import settings from gooeysite.bg_db_conn import get_celery_result_db_safe from recipes.CompareLLM import CompareLLMPage +SAFETY_CHECKER_MSG = "Your request was rejected as a result of our safety system. Your input image may contain contents that are not allowed by our safety system." + def safety_checker(*, text: str | None = None, image: str | None = None): assert text or image, "safety_checker: at least one of text, image is required" @@ -44,15 +49,21 @@ def safety_checker_text(text_input: str): if not lines: continue if lines[-1].upper().endswith("FLAGGED"): - raise ValueError( - "Your request was rejected as a result of our safety system. Your prompt may contain text that is not allowed by our safety system." - ) + raise UserError(SAFETY_CHECKER_MSG) def safety_checker_image(image_url: str, cache: bool = False) -> None: if is_image_nsfw(image_url=image_url, cache=cache): - raise ValueError( - "Your request was rejected as a result of our safety system. " - "Your input image may contain contents that are not allowed " - "by our safety system." 
@contextmanager
def capture_openai_content_policy_violation():
    """Context manager that converts OpenAI content-policy rejections into UserError.

    Any openai.BadRequestError whose response is HTTP 400 and mentions
    "content_policy_violation" is re-raised as UserError(SAFETY_CHECKER_MSG);
    every other BadRequestError propagates unchanged.
    """
    import openai

    try:
        yield
    except openai.BadRequestError as exc:
        is_policy_violation = (
            exc.response.status_code == 400
            and "content_policy_violation" in exc.message
        )
        if not is_policy_violation:
            raise
        raise UserError(SAFETY_CHECKER_MSG) from exc
client.images.generate( + n=num_outputs, + prompt=prompt, + size=f"{edge}x{edge}", + response_format="b64_json", + ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] case _: prompt = add_prompt_prefix(prompt, selected_model) @@ -379,12 +384,13 @@ def img2img( image = resize_img_pad(init_image_bytes, (edge, edge)) client = OpenAI() - response = client.images.create_variation( - image=image, - n=num_outputs, - size=f"{edge}x{edge}", - response_format="b64_json", - ) + with capture_openai_content_policy_violation(): + response = client.images.create_variation( + image=image, + n=num_outputs, + size=f"{edge}x{edge}", + response_format="b64_json", + ) out_imgs = [ resize_img_fit(b64_img_decode(part.b64_json), (width, height)) @@ -503,13 +509,14 @@ def inpainting( image = rgb_img_to_rgba(edit_image_bytes, mask_bytes) client = OpenAI() - response = client.images.edit( - prompt=prompt, - image=image, - n=num_outputs, - size=f"{edge}x{edge}", - response_format="b64_json", - ) + with capture_openai_content_policy_violation(): + response = client.images.edit( + prompt=prompt, + image=image, + n=num_outputs, + size=f"{edge}x{edge}", + response_format="b64_json", + ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] case InpaintingModels.sd_2.name | InpaintingModels.runway_ml.name: diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 4a510d150..6080131ee 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -31,7 +31,7 @@ from daras_ai_v2.doc_search_settings_widgets import ( is_user_uploaded_url, ) -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import flatmap_parallel, map_parallel from daras_ai_v2.gdrive_downloader import ( @@ -258,7 +258,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: meta = 
gdrive_metadata(url_to_gdrive_file_id(f)) except HttpError as e: if e.status_code == 404: - raise FileNotFoundError( + raise UserError( f"Could not download the google doc at {f_url} " f"Please make sure to make the document public for viewing." ) from e @@ -630,17 +630,9 @@ def pandoc_to_text(f_name: str, f_bytes: bytes, to="plain") -> str: tempfile.NamedTemporaryFile("r") as outfile, ): infile.write(f_bytes) - args = [ - "pandoc", - "--standalone", - infile.name, - "--to", - to, - "--output", - outfile.name, - ] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) + call_cmd( + "pandoc", "--standalone", infile.name, "--to", to, "--output", outfile.name + ) return outfile.read() diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index 35a71a12c..b3c9cfbec 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -1,6 +1,5 @@ import typing import uuid -from datetime import datetime, timedelta from django.db.models import TextChoices from pydantic import BaseModel @@ -13,7 +12,6 @@ from daras_ai_v2.gpu_server import call_celery_task_outfile from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.safety_checker import safety_checker -from daras_ai_v2.tabs_widget import MenuTabs DEFAULT_DEFORUMSD_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7dc25196-93fe-11ee-9e3a-02420a0001ce/AI%20Animation%20generator.jpg.png" @@ -455,27 +453,32 @@ def run(self, state: dict): if not self.request.user.disable_safety_checker: safety_checker(text=self.preview_input(state)) - state["output_video"] = call_celery_task_outfile( - "deforum", - pipeline=dict( - model_id=AnimationModels[request.selected_model].value, - seed=request.seed, - ), - inputs=dict( - animation_mode=request.animation_mode, - animation_prompts={ - fp["frame"]: fp["prompt"] for fp in request.animation_prompts - }, - max_frames=request.max_frames, - zoom=request.zoom, - translation_x=request.translation_x, - translation_y=request.translation_y, - 
rotation_3d_x=request.rotation_3d_x, - rotation_3d_y=request.rotation_3d_y, - rotation_3d_z=request.rotation_3d_z, - translation_z="0:(0)", - fps=request.fps, - ), - content_type="video/mp4", - filename=f"gooey.ai animation {request.animation_prompts}.mp4", - )[0] + try: + state["output_video"] = call_celery_task_outfile( + "deforum", + pipeline=dict( + model_id=AnimationModels[request.selected_model].value, + seed=request.seed, + ), + inputs=dict( + animation_mode=request.animation_mode, + animation_prompts={ + fp["frame"]: fp["prompt"] for fp in request.animation_prompts + }, + max_frames=request.max_frames, + zoom=request.zoom, + translation_x=request.translation_x, + translation_y=request.translation_y, + rotation_3d_x=request.rotation_3d_x, + rotation_3d_y=request.rotation_3d_y, + rotation_3d_z=request.rotation_3d_z, + translation_z="0:(0)", + fps=request.fps, + ), + content_type="video/mp4", + filename=f"gooey.ai animation {request.animation_prompts}.mp4", + )[0] + except RuntimeError as e: + msg = "\n\n".join(e.args).lower() + if "key frame string not correctly formatted" in msg: + raise st.UserError(str(e)) from e diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 9cca97806..e9eaad6ae 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -27,7 +27,7 @@ from daras_ai_v2.base import BasePage from daras_ai_v2.doc_search_settings_widgets import document_uploader from daras_ai_v2.enum_selector_widget import enum_selector -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, call_cmd from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import ( apply_parallel, @@ -277,8 +277,11 @@ def extract_info(url: str) -> list[dict | None]: params = dict(ignoreerrors=True, check_formats=False) with yt_dlp.YoutubeDL(params) as ydl: data = ydl.extract_info(url, download=False) - entries = data.get("entries", [data]) - return [e for e in entries if e] + if data: + 
entries = data.get("entries", [data]) + return [e for e in entries if e] + else: + return [{"webpage_url": url, "title": "Youtube Video"}] else: # assume it's a direct link doc_meta = doc_url_to_metadata(url) @@ -326,13 +329,7 @@ def extract_info(url: str) -> list[dict | None]: def get_pdf_num_pages(f_bytes: bytes) -> int: with tempfile.NamedTemporaryFile() as infile: infile.write(f_bytes) - args = ["pdfinfo", infile.name] - print("\t$ " + " ".join(args)) - try: - output = subprocess.check_output(args, stderr=subprocess.STDOUT, text=True) - except subprocess.CalledProcessError as e: - raise ValueError(f"PDF Error: {e.output}") - output = output.lower() + output = call_cmd("pdfinfo", infile.name).lower() for line in output.splitlines(): if not line.startswith("pages:"): continue diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 1512e4523..b83ab4f2c 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -20,7 +20,7 @@ ) from daras_ai_v2.base import BasePage from daras_ai_v2.descriptions import prompting101 -from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.exceptions import raise_for_status, UserError from daras_ai_v2.img_model_settings_widgets import ( output_resolution_setting, img_model_settings, @@ -687,7 +687,7 @@ def generate_and_upload_qr_code( if isinstance(qr_code_data, str): qr_code_data = qr_code_data.strip() if not qr_code_data: - raise ValueError("Please provide QR Code URL, text content, or an image") + raise UserError("Please provide QR Code URL, text content, or an image") using_shortened_url = request.use_url_shortener and is_url(qr_code_data) if using_shortened_url: qr_code_data = ShortenedURL.objects.get_or_create_for_workflow( diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 2d0bc26e0..4c5d2e19b 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -32,6 +32,7 @@ document_uploader, ) from daras_ai_v2.enum_selector_widget import enum_multiselect +from 
daras_ai_v2.exceptions import UserError from daras_ai_v2.field_render import field_title_desc from daras_ai_v2.functions import LLMTools from daras_ai_v2.glossary import glossary_input, validate_glossary_document @@ -807,7 +808,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) max_allowed_tokens = min(max_allowed_tokens, request.max_tokens) if max_allowed_tokens < 0: - raise ValueError("Input Script is too long! Please reduce the script size.") + raise UserError("Input Script is too long! Please reduce the script size.") yield f"Summarizing with {model.value}..." if is_chat_model: diff --git a/routers/root.py b/routers/root.py index e5f27edc6..0de3d9b68 100644 --- a/routers/root.py +++ b/routers/root.py @@ -34,6 +34,7 @@ from daras_ai_v2.bots import request_json from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_scripts from daras_ai_v2.db import FIREBASE_SESSION_COOKIE +from daras_ai_v2.exceptions import ffmpeg from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags from daras_ai_v2.meta_preview_url import meta_preview_url @@ -188,16 +189,7 @@ def file_upload(request: Request, form_data: FormData = Depends(request_form_fil infile.flush() if not check_wav_audio_format(infile.name): with tempfile.NamedTemporaryFile(suffix=".wav") as outfile: - args = [ - "ffmpeg", - "-y", - "-i", - infile.name, - *FFMPEG_WAV_ARGS, - outfile.name, - ] - print("\t$ " + " ".join(args)) - subprocess.check_call(args) + ffmpeg("-i", infile.name, *FFMPEG_WAV_ARGS, outfile.name) filename += ".wav" content_type = "audio/wav"