Skip to content

Commit

Permalink
Merge branch 'master' into integration_ux
Browse files Browse the repository at this point in the history
  • Loading branch information
SanderGi committed Mar 1, 2024
2 parents 6fdf326 + 6241711 commit affd5a0
Show file tree
Hide file tree
Showing 70 changed files with 1,500 additions and 790 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ pg_restore --no-privileges --no-owner -d $PGDATABASE $fname
cid=$(docker ps | grep gooey-api-prod | cut -d " " -f 1 | head -1)
# exec the script to create the fixture
docker exec -it $cid poetry run ./manage.py runscript create_fixture
```

```bash
# copy the fixture outside container
docker cp $cid:/app/fixture.json .
# print the absolute path
Expand Down
1 change: 1 addition & 0 deletions app_users/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class AppUserAdmin(admin.ModelAdmin):
"view_transactions",
"open_in_firebase",
"open_in_stripe",
"low_balance_email_sent_at",
]

@admin.display(description="User Runs")
Expand Down
18 changes: 18 additions & 0 deletions app_users/migrations/0012_appuser_low_balance_email_sent_at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.7 on 2024-02-14 07:23

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('app_users', '0011_appusertransaction_charged_amount_and_more'),
]

operations = [
migrations.AddField(
model_name='appuser',
name='low_balance_email_sent_at',
field=models.DateTimeField(blank=True, null=True),
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Generated by Django 4.2.7 on 2024-02-28 14:16

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('app_users', '0012_appuser_low_balance_email_sent_at'),
]

operations = [
migrations.AddIndex(
model_name='appusertransaction',
index=models.Index(fields=['user', 'amount', '-created_at'], name='app_users_a_user_id_9b2e8d_idx'),
),
migrations.AddIndex(
model_name='appusertransaction',
index=models.Index(fields=['-created_at'], name='app_users_a_created_3c27fe_idx'),
),
]
12 changes: 11 additions & 1 deletion app_users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class AppUser(models.Model):
stripe_customer_id = models.CharField(max_length=255, default="", blank=True)
is_paying = models.BooleanField("paid", default=False)

low_balance_email_sent_at = models.DateTimeField(null=True, blank=True)

created_at = models.DateTimeField(
"created", editable=False, blank=True, default=timezone.now
)
Expand Down Expand Up @@ -207,7 +209,11 @@ def search_stripe_customer(self) -> stripe.Customer | None:
if not self.uid:
return None
if self.stripe_customer_id:
return stripe.Customer.retrieve(self.stripe_customer_id)
try:
return stripe.Customer.retrieve(self.stripe_customer_id)
except stripe.error.InvalidRequestError as e:
if e.http_status != 404:
raise
try:
customer = stripe.Customer.search(
query=f'metadata["uid"]:"{self.uid}"'
Expand Down Expand Up @@ -263,6 +269,10 @@ class AppUserTransaction(models.Model):

class Meta:
verbose_name = "Transaction"
indexes = [
models.Index(fields=["user", "amount", "-created_at"]),
models.Index(fields=["-created_at"]),
]

def __str__(self):
return f"{self.invoice_id} ({self.amount})"
17 changes: 13 additions & 4 deletions bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,25 @@ class SavedRunAdmin(admin.ModelAdmin):
django.db.models.JSONField: {"widget": JSONEditorWidget},
}

def get_queryset(self, request):
return (
super()
.get_queryset(request)
.prefetch_related(
"parent_version",
"parent_version__published_run",
"parent_version__published_run__saved_run",
)
)

def lookup_allowed(self, key, value):
if key in ["parent_version__published_run__id__exact"]:
return True
return super().lookup_allowed(key, value)

def view_user(self, saved_run: SavedRun):
return change_obj_url(
AppUser.objects.get(uid=saved_run.uid),
label=f"{saved_run.uid}",
)
user = AppUser.objects.get(uid=saved_run.uid)
return change_obj_url(user)

view_user.short_description = "View User"

Expand Down
18 changes: 18 additions & 0 deletions bots/migrations/0060_conversation_reset_at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.7 on 2024-02-20 16:49

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('bots', '0059_savedrun_is_api_call'),
]

operations = [
migrations.AddField(
model_name='conversation',
name='reset_at',
field=models.DateTimeField(blank=True, default=None, null=True),
),
]
7 changes: 6 additions & 1 deletion bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,7 @@ class Conversation(models.Model):
)

created_at = models.DateTimeField(auto_now_add=True)
reset_at = models.DateTimeField(null=True, blank=True, default=None)

objects = ConversationQuerySet.as_manager()

Expand Down Expand Up @@ -1013,7 +1014,11 @@ def to_df_analysis_format(
)
return df

def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]:
def as_llm_context(
self, limit: int = 50, reset_at: datetime.datetime = None
) -> list["ConversationEntry"]:
if reset_at:
self = self.filter(created_at__gt=reset_at)
msgs = self.order_by("-created_at").prefetch_related("attachments")[:limit]
entries = [None] * len(msgs)
for i, msg in enumerate(reversed(msgs)):
Expand Down
10 changes: 9 additions & 1 deletion bots/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from json import JSONDecodeError

from celery import shared_task
from django.db.models import QuerySet
Expand Down Expand Up @@ -65,8 +66,15 @@ def msg_analysis(msg_id: int):
raise RuntimeError(sr.error_msg)

# save the result as json
output_text = flatten(sr.state["output_text"].values())[0]
try:
analysis_result = json.loads(output_text)
except JSONDecodeError:
analysis_result = {
"error": "Failed to parse the analysis result. Please check your script.",
}
Message.objects.filter(id=msg_id).update(
analysis_result=json.loads(flatten(sr.state["output_text"].values())[0]),
analysis_result=analysis_result,
)


Expand Down
107 changes: 103 additions & 4 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import datetime
import html
import traceback
import typing
from time import time
from types import SimpleNamespace

import requests
import sentry_sdk
from django.db.models import Sum
from django.utils import timezone
from fastapi import HTTPException

import gooey_ui as st
from app_users.models import AppUser
from app_users.models import AppUser, AppUserTransaction
from bots.models import SavedRun
from celeryapp.celeryconfig import app
from daras_ai.image_input import truncate_text_words
from daras_ai_v2 import settings
from daras_ai_v2.base import StateKeys, err_msg_for_exc, BasePage
from daras_ai_v2.base import StateKeys, BasePage
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.send_email import send_email_via_postmark
from daras_ai_v2.send_email import send_low_balance_email
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
from gooey_ui.state import set_query_params
Expand All @@ -32,6 +40,17 @@ def gui_runner(
is_api_call: bool = False,
):
page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id)))

def event_processor(event, hint):
event["request"] = {
"method": "POST",
"url": page.app_url(query_params=query_params),
"data": state,
}
return event

page.setup_sentry(event_processor=event_processor)

sr = page.run_doc_sr(run_id, uid)
sr.is_api_call = is_api_call

Expand Down Expand Up @@ -95,8 +114,25 @@ def save(done=False):
# render errors nicely
except Exception as e:
run_time += time() - start_time
traceback.print_exc()
sentry_sdk.capture_exception(e)

if isinstance(e, HTTPException) and e.status_code == 402:
error_msg = page.generate_credit_error_message(
example_id=query_params.get("example_id"),
run_id=run_id,
uid=uid,
)
try:
raise UserError(error_msg) from e
except UserError as e:
sentry_sdk.capture_exception(e, level=e.sentry_level)
break

if isinstance(e, UserError):
sentry_level = e.sentry_level
else:
sentry_level = "error"
traceback.print_exc()
sentry_sdk.capture_exception(e, level=sentry_level)
error_msg = err_msg_for_exc(e)
break
finally:
Expand All @@ -105,6 +141,69 @@ def save(done=False):
save(done=True)
if not is_api_call:
send_email_on_completion(page, sr)
run_low_balance_email_check(uid)


def err_msg_for_exc(e: Exception):
if isinstance(e, requests.HTTPError):
response: requests.Response = e.response
try:
err_body = response.json()
except requests.JSONDecodeError:
err_str = response.text
else:
format_exc = err_body.get("format_exc")
if format_exc:
print("⚡️ " + format_exc)
err_type = err_body.get("type")
err_str = err_body.get("str")
if err_type and err_str:
return f"(GPU) {err_type}: {err_str}"
err_str = str(err_body)
return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}"
elif isinstance(e, HTTPException):
return f"(HTTP {e.status_code}) {e.detail})"
elif isinstance(e, UserError):
return e.message
else:
return f"{type(e).__name__}: {e}"


def run_low_balance_email_check(uid: str):
# don't send email if feature is disabled
if not settings.LOW_BALANCE_EMAIL_ENABLED:
return
user = AppUser.objects.get(uid=uid)
# don't send email if user is not paying or has enough balance
if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS:
return
last_purchase = (
AppUserTransaction.objects.filter(user=user, amount__gt=0)
.order_by("-created_at")
.first()
)
email_date_cutoff = timezone.now() - datetime.timedelta(
days=settings.LOW_BALANCE_EMAIL_DAYS
)
# send email if user has not been sent email in last X days or last purchase was after last email sent
if (
# user has not been sent any email
not user.low_balance_email_sent_at
# user was sent email before X days
or (user.low_balance_email_sent_at < email_date_cutoff)
# user has made a purchase after last email sent
or (last_purchase and last_purchase.created_at > user.low_balance_email_sent_at)
):
# calculate total credits consumed in last X days
total_credits_consumed = abs(
AppUserTransaction.objects.filter(
user=user, amount__lt=0, created_at__gte=email_date_cutoff
).aggregate(Sum("amount"))["amount__sum"]
or 0
)
send_low_balance_email(user=user, total_credits_consumed=total_credits_consumed)
user.low_balance_email_sent_at = timezone.now()
user.save(update_fields=["low_balance_email_sent_at"])


def send_email_on_completion(page: BasePage, sr: SavedRun):
Expand Down
23 changes: 22 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
from auth import auth_backend
from celeryapp import app
from daras_ai_v2.base import BasePage
from daras_ai_v2.send_email import pytest_outbox


def flaky(fn):
max_tries = 5

@wraps(fn)
def wrapper(*args, **kwargs):
for i in range(max_tries):
try:
return fn(*args, **kwargs)
except Exception:
if i == max_tries - 1:
raise

return wrapper


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -44,7 +60,7 @@ def _mock_gui_runner(


@pytest.fixture
def threadpool_subtest(subtests, max_workers: int = 8):
def threadpool_subtest(subtests, max_workers: int = 128):
ts = []

def submit(fn, *args, msg=None, **kwargs):
Expand All @@ -68,6 +84,11 @@ def runner(*args, **kwargs):
t.join()


@pytest.fixture(autouse=True)
def clear_pytest_outbox():
pytest_outbox.clear()


# class DummyDatabaseBlocker(pytest_django.plugin._DatabaseBlocker):
# class _dj_db_wrapper:
# def ensure_connection(self):
Expand Down
4 changes: 3 additions & 1 deletion daras_ai/extract_face.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from daras_ai_v2.exceptions import UserError


def extract_and_reposition_face_cv2(
orig_img,
Expand Down Expand Up @@ -118,7 +120,7 @@ def face_oval_hull_generator(image_cv2):
results = face_mesh.process(cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB))

if not results.multi_face_landmarks:
raise ValueError("Face not found")
raise UserError("Face not found")

for landmark_list in results.multi_face_landmarks:
idx_to_coordinates = build_idx_to_coordinates_dict(
Expand Down
Loading

0 comments on commit affd5a0

Please sign in to comment.