Skip to content

Commit

Permalink
Merge branch 'master' into twilio
Browse files Browse the repository at this point in the history
  • Loading branch information
SanderGi committed Jul 8, 2024
2 parents 62e6519 + 714a663 commit 25a9e5d
Show file tree
Hide file tree
Showing 63 changed files with 1,243 additions and 393 deletions.
36 changes: 17 additions & 19 deletions bots/admin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import json
from types import SimpleNamespace

import django.db.models
from django import forms
Expand Down Expand Up @@ -287,6 +288,18 @@ def api_integration_stats_url(self, bi: BotIntegration):
)


@admin.register(PublishedRunVersion)
class PublishedRunVersionAdmin(admin.ModelAdmin):
search_fields = ["id", "version_id", "published_run__published_run_id"]
autocomplete_fields = ["published_run", "saved_run", "changed_by"]


class PublishedRunVersionInline(admin.TabularInline):
model = PublishedRunVersion
extra = 0
autocomplete_fields = PublishedRunVersionAdmin.autocomplete_fields


@admin.register(PublishedRun)
class PublishedRunAdmin(admin.ModelAdmin):
list_display = [
Expand All @@ -308,6 +321,7 @@ class PublishedRunAdmin(admin.ModelAdmin):
"created_at",
"updated_at",
]
inlines = [PublishedRunVersionInline]

def view_user(self, published_run: PublishedRun):
if published_run.created_by is None:
Expand Down Expand Up @@ -423,32 +437,16 @@ def view_usage_cost(self, saved_run: SavedRun):
def rerun_tasks(self, request, queryset):
sr: SavedRun
for sr in queryset.all():
page_cls = Workflow(sr.workflow).page_cls
pr = sr.parent_published_run()
gui_runner.delay(
page_cls=page_cls,
user_id=AppUser.objects.get(uid=sr.uid).id,
run_id=sr.run_id,
uid=sr.uid,
state=sr.to_dict(),
channel=page_cls.realtime_channel_name(sr.run_id, sr.uid),
query_params=page_cls.clean_query_params(
example_id=pr and pr.published_run_id, run_id=sr.run_id, uid=sr.uid
),
is_api_call=sr.is_api_call,
page = Workflow(sr.workflow).page_cls(
request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid))
)
page.call_runner_task(sr)
self.message_user(
request,
f"Started re-running {queryset.count()} tasks in the background.",
)


@admin.register(PublishedRunVersion)
class PublishedRunVersionAdmin(admin.ModelAdmin):
search_fields = ["id", "version_id", "published_run__published_run_id"]
autocomplete_fields = ["published_run", "saved_run", "changed_by"]


class LastActiveDeltaFilter(admin.SimpleListFilter):
title = Conversation.last_active_delta.short_description
parameter_name = Conversation.last_active_delta.__name__
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.7 on 2024-07-05 13:44

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('bots', '0075_alter_publishedrun_workflow_alter_savedrun_workflow_and_more'),
]

operations = [
migrations.AlterField(
model_name='workflowmetadata',
name='default_image',
field=models.URLField(blank=True, default='', help_text='Image shown on explore page'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='help_url',
field=models.URLField(blank=True, default='', help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='meta_keywords',
field=models.JSONField(blank=True, default=list, help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='short_title',
field=models.TextField(help_text='Title used in breadcrumbs'),
),
]
37 changes: 34 additions & 3 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from bots.custom_fields import PostgresJSONEncoder, CustomURLField
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
from functions.models import CalledFunction, CalledFunctionResponse
from gooeysite.custom_create import get_or_create_lazy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -344,23 +345,30 @@ def submit_api_call(
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
parent_pr: "PublishedRun" = None,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
from routers.api import submit_api_call

# run in a thread to avoid messing up threadlocals
with ThreadPool(1) as pool:
if parent_pr and parent_pr.saved_run == self:
# avoid passing run_id and uid for examples
query_params = dict(example_id=parent_pr.published_run_id)
else:
query_params = dict(
example_id=self.example_id, run_id=self.run_id, uid=self.uid
)
page, result, run_id, uid = pool.apply(
submit_api_call,
kwds=dict(
page_cls=Workflow(self.workflow).page_cls,
query_params=dict(
example_id=self.example_id, run_id=self.run_id, uid=self.uid
),
query_params=query_params,
user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
),
)

return result, page.run_doc_sr(run_id, uid)

def get_creator(self) -> AppUser | None:
Expand All @@ -373,6 +381,15 @@ def get_creator(self) -> AppUser | None:
def open_in_gooey(self):
return open_in_new_tab(self.get_app_url(), label=self.get_app_url())

def api_output(self, state: dict = None) -> dict:
state = state or self.state
if self.state.get("functions"):
state["called_functions"] = [
CalledFunctionResponse.from_db(called_fn)
for called_fn in self.called_functions.all()
]
return state


def _parse_dt(dt) -> datetime.datetime | None:
if isinstance(dt, str):
Expand Down Expand Up @@ -1740,6 +1757,20 @@ def get_run_count(self):
or 0
)

def submit_api_call(
self,
*,
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
return self.saved_run.submit_api_call(
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
parent_pr=self,
)


class PublishedRunVersion(models.Model):
version_id = models.CharField(max_length=128, unique=True)
Expand Down
4 changes: 3 additions & 1 deletion bots/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None):

# make the api call
result, sr = analysis_sr.submit_api_call(
current_user=billing_account, request_body=dict(variables=variables)
current_user=billing_account,
request_body=dict(variables=variables),
parent_pr=anal.published_run,
)

# save the run before the result is ready
Expand Down
142 changes: 55 additions & 87 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,124 +26,93 @@
from daras_ai_v2.settings import templates
from gooey_ui.pubsub import realtime_push
from gooey_ui.state import set_query_params
from gooeysite.bg_db_conn import db_middleware, next_db_safe
from gooeysite.bg_db_conn import db_middleware

DEFAULT_RUN_STATUS = "Running..."


@app.task
def gui_runner(
*,
page_cls: typing.Type[BasePage],
user_id: str,
user_id: int,
run_id: str,
uid: str,
state: dict,
channel: str,
query_params: dict = None,
is_api_call: bool = False,
):
page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id)))

def event_processor(event, hint):
event["request"] = {
"method": "POST",
"url": page.app_url(query_params=query_params),
"data": state,
}
return event

page.setup_sentry(event_processor=event_processor)

sr = page.run_doc_sr(run_id, uid)
sr.is_api_call = is_api_call

st.set_session_state(state)
run_time = 0
yield_val = None
start_time = time()
error_msg = None
set_query_params(query_params or {})

@db_middleware
def save(done=False):
def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False):
if isinstance(yield_val, tuple):
run_status, extra_output = yield_val
else:
run_status = yield_val
extra_output = {}

if done:
# clear run status
run_status = None
else:
# set run status to the yield value of generator
run_status = yield_val or "Running..."
if isinstance(run_status, tuple):
run_status, extra_output = run_status
else:
extra_output = {}
# set run status and run time
status = {
StateKeys.run_time: run_time,
StateKeys.error_msg: error_msg,
StateKeys.run_status: run_status,
}
run_status = run_status or DEFAULT_RUN_STATUS

output = (
status
|
# extract outputs from local state
# extract status of the run
{
StateKeys.error_msg: error_msg,
StateKeys.run_time: time() - start_time,
StateKeys.run_status: run_status,
}
# extract outputs from local state
| {
k: v
for k, v in st.session_state.items()
if k in page.ResponseModel.__fields__
}
# add extra outputs from the run
| extra_output
)

# send outputs to ui
realtime_push(channel, output)
# save to db
page.dump_state_to_sr(st.session_state | output, sr)

user = AppUser.objects.get(id=user_id)
page = page_cls(request=SimpleNamespace(user=user))
page.setup_sentry()
sr = page.run_doc_sr(run_id, uid)
st.set_session_state(sr.to_dict())
set_query_params(dict(run_id=run_id, uid=uid))

try:
gen = page.run(st.session_state)
save()
while True:
# record time
start_time = time()
save_on_step()
for val in page.main(sr, st.session_state):
save_on_step(val)
# render errors nicely
except Exception as e:
if isinstance(e, HTTPException) and e.status_code == 402:
error_msg = page.generate_credit_error_message(run_id, uid)
try:
# advance the generator (to further progress of run())
yield_val = next_db_safe(gen)
# increment total time taken after every iteration
run_time += time() - start_time
continue
# run completed
except StopIteration:
run_time += time() - start_time
sr.transaction, sr.price = page.deduct_credits(st.session_state)
break
# render errors nicely
except Exception as e:
run_time += time() - start_time

if isinstance(e, HTTPException) and e.status_code == 402:
error_msg = page.generate_credit_error_message(
example_id=query_params.get("example_id"),
run_id=run_id,
uid=uid,
)
try:
raise UserError(error_msg) from e
except UserError as e:
sentry_sdk.capture_exception(e, level=e.sentry_level)
break

if isinstance(e, UserError):
sentry_level = e.sentry_level
else:
sentry_level = "error"
traceback.print_exc()
sentry_sdk.capture_exception(e, level=sentry_level)
error_msg = err_msg_for_exc(e)
break
finally:
save()
raise UserError(error_msg) from e
except UserError as e:
sentry_sdk.capture_exception(e, level=e.sentry_level)
else:
if isinstance(e, UserError):
sentry_level = e.sentry_level
else:
sentry_level = "error"
traceback.print_exc()
sentry_sdk.capture_exception(e, level=sentry_level)
error_msg = err_msg_for_exc(e)
# run completed successfully, deduct credits
else:
sr.transaction, sr.price = page.deduct_credits(st.session_state)
finally:
save(done=True)
if not is_api_call:
save_on_step(done=True)
if not sr.is_api_call:
send_email_on_completion(page, sr)
run_low_balance_email_check(uid)
run_low_balance_email_check(user)


def err_msg_for_exc(e: Exception):
Expand Down Expand Up @@ -172,11 +141,10 @@ def err_msg_for_exc(e: Exception):
return f"{type(e).__name__}: {e}"


def run_low_balance_email_check(uid: str):
def run_low_balance_email_check(user: AppUser):
# don't send email if feature is disabled
if not settings.LOW_BALANCE_EMAIL_ENABLED:
return
user = AppUser.objects.get(uid=uid)
# don't send email if user is not paying or has enough balance
if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS:
return
Expand Down
Loading

0 comments on commit 25a9e5d

Please sign in to comment.