Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UserError exception type to separate server vs user errors #241

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Generated by Django 4.2.7 on 2023-12-30 00:52

from django.db import migrations, models
from bots.models import FinishReason


def forwards_func(apps, schema_editor):
saved_run = apps.get_model("bots", "SavedRun")
saved_run.objects.filter(run_status="", error_msg="").update(
finish_reason=FinishReason.DONE,
)
saved_run.objects.exclude(error_msg="").update(
finish_reason=FinishReason.SERVER_ERROR,
)


def backwards_func(apps, schema_editor):
pass


class Migration(migrations.Migration):
dependencies = [
("bots", "0053_alter_publishedrun_workflow_alter_savedrun_workflow_and_more"),
]

operations = [
migrations.AddField(
model_name="savedrun",
name="finish_reason",
field=models.IntegerField(
choices=[(1, "User Error"), (2, "Server Error"), (3, "Done")],
default=None,
null=True,
),
),
migrations.RunPython(forwards_func, backwards_func),
]
14 changes: 14 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
EPOCH = datetime.datetime.utcfromtimestamp(0)


class FinishReason(models.IntegerChoices):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either change this to be ErrorReason, or save RecipeRunState in db

USER_ERROR = 1
SERVER_ERROR = 2
DONE = 3


class PublishedRunVisibility(models.IntegerChoices):
UNLISTED = 1
PUBLIC = 2
Expand Down Expand Up @@ -157,6 +163,11 @@ class SavedRun(models.Model):

state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder)

finish_reason = models.IntegerField(
choices=FinishReason.choices,
null=True,
default=None,
)
error_msg = models.TextField(default="", blank=True)
run_time = models.DurationField(default=datetime.timedelta, blank=True)
run_status = models.TextField(default="", blank=True)
Expand Down Expand Up @@ -225,6 +236,8 @@ def to_dict(self) -> dict:
ret[StateKeys.created_at] = self.created_at
if self.error_msg:
ret[StateKeys.error_msg] = self.error_msg
if self.finish_reason:
ret[StateKeys.finish_reason] = self.finish_reason
if self.run_time:
ret[StateKeys.run_time] = self.run_time.total_seconds()
if self.run_status:
Expand Down Expand Up @@ -255,6 +268,7 @@ def copy_from_firebase_state(self, state: dict) -> "SavedRun":
if created_at:
self.created_at = created_at
self.error_msg = state.pop(StateKeys.error_msg, None) or ""
self.finish_reason = state.pop(StateKeys.finish_reason, None)
self.run_time = datetime.timedelta(
seconds=state.pop(StateKeys.run_time, None) or 0
)
Expand Down
62 changes: 36 additions & 26 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from types import SimpleNamespace

import sentry_sdk
import pydantic

import gooey_ui as st
from app_users.models import AppUser
from bots.models import SavedRun
from bots.models import SavedRun, FinishReason
from celeryapp.celeryconfig import app
from daras_ai.image_input import truncate_text_words
from daras_ai_v2 import settings
from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage
from daras_ai_v2.base import BasePage, StateKeys, err_msg_for_exc
from daras_ai_v2.exceptions import UserError, raise_as_user_error
from daras_ai_v2.send_email import send_email_via_postmark
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
Expand All @@ -27,15 +29,16 @@ def gui_runner(
uid: str,
state: dict,
channel: str,
query_params: dict = None,
query_params: dict | None = None,
is_api_call: bool = False,
):
page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id)))
sr = page.run_doc_sr(run_id, uid)

st.set_session_state(state)
run_time = 0
run_time = 0.0
yield_val = None
finish_reason = None
error_msg = None
set_query_params(query_params or {})

Expand All @@ -53,6 +56,7 @@ def save(done=False):
# set run status and run time
status = {
StateKeys.run_time: run_time,
StateKeys.finish_reason: finish_reason,
StateKeys.error_msg: error_msg,
StateKeys.run_status: run_status,
}
Expand All @@ -75,30 +79,36 @@ def save(done=False):
try:
gen = page.run(st.session_state)
save()
while True:
# record time
start_time = time()
try:
# advance the generator (to further progress of run())
yield_val = next(gen)
# increment total time taken after every iteration
run_time += time() - start_time
continue
# run completed
except StopIteration:
run_time += time() - start_time
sr.transaction, sr.price = page.deduct_credits(st.session_state)
break
start_time = time()
try:
with raise_as_user_error([pydantic.ValidationError]):
for yield_val in gen:
# increment total time taken after every iteration
run_time += time() - start_time
save()
except UserError as e:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except UserError,pydantic.ValidationError

# handled errors caused due to user input
run_time += time() - start_time
finish_reason = FinishReason.USER_ERROR
traceback.print_exc()
sentry_sdk.capture_exception(e)
error_msg = err_msg_for_exc(e)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just do error_msg = str(e)

except Exception as e:
# render errors nicely
except Exception as e:
run_time += time() - start_time
traceback.print_exc()
sentry_sdk.capture_exception(e)
error_msg = err_msg_for_exc(e)
break
finally:
save()
run_time += time() - start_time
finish_reason = FinishReason.SERVER_ERROR
traceback.print_exc()
sentry_sdk.capture_exception(e)
error_msg = err_msg_for_exc(e)
else:
# run completed
run_time += time() - start_time
finish_reason = FinishReason.DONE
sr.transaction, sr.price = page.deduct_credits(st.session_state)
finally:
if not finish_reason:
finish_reason = FinishReason.SERVER_ERROR
error_msg = "Something went wrong. Please try again later."
save(done=True)
if not is_api_call:
send_email_on_completion(page, sr)
Expand Down
6 changes: 5 additions & 1 deletion daras_ai/extract_face.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np


class FaceNotFoundException(ValueError):
pass


def extract_and_reposition_face_cv2(
orig_img,
*,
Expand Down Expand Up @@ -118,7 +122,7 @@ def face_oval_hull_generator(image_cv2):
results = face_mesh.process(cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB))

if not results.multi_face_landmarks:
raise ValueError("Face not found")
raise FaceNotFoundException("Face not found")

for landmark_list in results.multi_face_landmarks:
idx_to_coordinates = build_idx_to_coordinates_dict(
Expand Down
5 changes: 3 additions & 2 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import gooey_ui as st
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2 import settings
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.gdrive_downloader import (
is_gdrive_url,
Expand Down Expand Up @@ -570,7 +571,7 @@ def run_asr(
assert data.get("chunks"), f"{selected_model.value} can't generate VTT"
return generate_vtt(data["chunks"])
case _:
raise ValueError(f"Invalid output format: {output_format}")
raise UserError(f"Invalid output format: {output_format}")


def _get_or_create_recognizer(
Expand Down Expand Up @@ -756,7 +757,7 @@ def check_wav_audio_format(filename: str) -> bool:
data = json.loads(subprocess.check_output(args, stderr=subprocess.STDOUT))
except subprocess.CalledProcessError as e:
ffmpeg_output_error = ValueError(e.output, e)
raise ValueError(
raise UserError(
"Invalid audio file. Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')"
) from ffmpeg_output_error
return (
Expand Down
25 changes: 23 additions & 2 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import gooey_ui as st
from app_users.models import AppUser, AppUserTransaction
from bots.models import (
FinishReason,
SavedRun,
PublishedRun,
PublishedRunVersion,
Expand All @@ -46,6 +47,7 @@
from daras_ai_v2.db import (
ANONYMOUS_USER_COOKIE,
)
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.html_spinner_widget import html_spinner
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
Expand Down Expand Up @@ -95,6 +97,7 @@ class StateKeys:
created_at = "created_at"
updated_at = "updated_at"

finish_reason = "__finish_reason"
error_msg = "__error_msg"
run_time = "__run_time"
run_status = "__run_status"
Expand Down Expand Up @@ -1381,7 +1384,17 @@ def _render_completed_output(self):

def _render_failed_output(self):
err_msg = st.session_state.get(StateKeys.error_msg)
st.error(err_msg, unsafe_allow_html=True)
finish_reason = st.session_state.get(StateKeys.finish_reason)
self._render_error(err_msg, finish_reason)

def _render_error(self, error_msg: str, finish_reason: FinishReason):
match finish_reason:
case FinishReason.USER_ERROR:
st.warning(error_msg, unsafe_allow_html=True)
case FinishReason.SERVER_ERROR:
st.error(error_msg, unsafe_allow_html=True)
case _:
raise ValueError(f"invalid finish reason for error: {finish_reason}")

def _render_running_output(self):
run_status = st.session_state.get(StateKeys.run_status)
Expand Down Expand Up @@ -1421,6 +1434,7 @@ def on_submit(self):
st.session_state[StateKeys.error_msg] = self.generate_credit_error_message(
example_id, run_id, uid
)
st.session_state[StateKeys.finish_reason] = FinishReason.USER_ERROR
self.run_doc_sr(run_id, uid).set(self.state_to_doc(st.session_state))
else:
self.call_runner_task(example_id, run_id, uid)
Expand All @@ -1439,6 +1453,7 @@ def should_submit_after_login(self) -> bool:
def create_new_run(self):
st.session_state[StateKeys.run_status] = "Starting..."
st.session_state.pop(StateKeys.error_msg, None)
st.session_state.pop(StateKeys.finish_reason, None)
st.session_state.pop(StateKeys.run_time, None)
self._setup_rng_seed()
self.clear_outputs()
Expand Down Expand Up @@ -1529,6 +1544,7 @@ def _setup_rng_seed(self):
def clear_outputs(self):
# clear error msg
st.session_state.pop(StateKeys.error_msg, None)
st.session_state.pop(StateKeys.finish_reason, None)
# clear outputs
for field_name in self.ResponseModel.__fields__:
st.session_state.pop(field_name, None)
Expand Down Expand Up @@ -1608,6 +1624,7 @@ def fields_to_save(self) -> [str]:
for field_name in model.__fields__
] + [
StateKeys.error_msg,
StateKeys.finish_reason,
StateKeys.run_status,
StateKeys.run_time,
]
Expand Down Expand Up @@ -1711,7 +1728,9 @@ def _render_run_preview(self, saved_run: SavedRun):
if saved_run.run_status:
html_spinner(saved_run.run_status)
elif saved_run.error_msg:
st.error(saved_run.error_msg, unsafe_allow_html=True)
self._render_error(
error_msg=saved_run.error_msg, finish_reason=saved_run.finish_reason
)

return self.render_example(saved_run.to_dict())

Expand Down Expand Up @@ -1963,6 +1982,8 @@ def err_msg_for_exc(e):
return f"(GPU) {err_type}: {err_str}"
err_str = str(err_body)
return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}"
elif isinstance(e, UserError):
return str(e)
else:
return f"{type(e).__name__}: {e}"

Expand Down
14 changes: 14 additions & 0 deletions daras_ai_v2/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from contextlib import contextmanager
from typing import Type


class UserError(Exception):
pass


@contextmanager
def raise_as_user_error(excs: list[Type[Exception]]):
try:
yield
except tuple(excs) as e:
raise UserError(str(e)) from e
1 change: 1 addition & 0 deletions daras_ai_v2/face_restoration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import requests

from daras_ai.image_input import upload_file_from_bytes
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.gpu_server import call_gpu_server_b64, GpuEndpoints
from daras_ai_v2.stable_diffusion import sd_upscale

Expand Down
5 changes: 3 additions & 2 deletions daras_ai_v2/safety_checker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from app_users.models import AppUser
from daras_ai_v2.azure_image_moderation import is_image_nsfw
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.functional import flatten
from daras_ai_v2 import settings
from recipes.CompareLLM import CompareLLMPage
Expand Down Expand Up @@ -43,14 +44,14 @@ def safety_checker_text(text_input: str):
if not lines:
continue
if lines[-1].upper().endswith("FLAGGED"):
raise ValueError(
raise UserError(
"Your request was rejected as a result of our safety system. Your prompt may contain text that is not allowed by our safety system."
)


def safety_checker_image(image_url: str, cache: bool = False) -> None:
if is_image_nsfw(image_url=image_url, cache=cache):
raise ValueError(
raise UserError(
"Your request was rejected as a result of our safety system. "
"Your input image may contain contents that are not allowed "
"by our safety system."
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
guess_ext_from_response,
get_mimetype_from_response,
)
from daras_ai_v2.exceptions import UserError
from daras_ai_v2 import settings
from daras_ai_v2.asr import AsrModels, run_asr, run_google_translate
from daras_ai_v2.azure_doc_extract import (
Expand Down Expand Up @@ -270,7 +271,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
meta = gdrive_metadata(url_to_gdrive_file_id(f))
except HttpError as e:
if e.status_code == 404:
raise FileNotFoundError(
raise UserError(
f"Could not download the google doc at {f_url} "
f"Please make sure to make the document public for viewing."
) from e
Expand Down
Loading
Loading