Skip to content

Commit

Permalink
save function calls to db and show in steps
Browse files Browse the repository at this point in the history
add pre/post functions to all recipes
enable variables in functions
lazy load details and steps
  • Loading branch information
devxpy committed Jul 5, 2024
1 parent ad6851c commit 05db35c
Show file tree
Hide file tree
Showing 29 changed files with 664 additions and 181 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.7 on 2024-07-05 13:44

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('bots', '0075_alter_publishedrun_workflow_alter_savedrun_workflow_and_more'),
]

operations = [
migrations.AlterField(
model_name='workflowmetadata',
name='default_image',
field=models.URLField(blank=True, default='', help_text='Image shown on explore page'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='help_url',
field=models.URLField(blank=True, default='', help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='meta_keywords',
field=models.JSONField(blank=True, default=list, help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='short_title',
field=models.TextField(help_text='Title used in breadcrumbs'),
),
]
10 changes: 10 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from bots.custom_fields import PostgresJSONEncoder, CustomURLField
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
from functions.models import CalledFunction, CalledFunctionResponse
from gooeysite.custom_create import get_or_create_lazy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -370,6 +371,15 @@ def get_creator(self) -> AppUser | None:
def open_in_gooey(self):
return open_in_new_tab(self.get_app_url(), label=self.get_app_url())

def api_output(self, state: dict = None) -> dict:
state = state or self.state
if self.state.get("functions"):
state["called_functions"] = [
CalledFunctionResponse.from_db(called_fn)
for called_fn in self.called_functions.all()
]
return state


def _parse_dt(dt) -> datetime.datetime | None:
if isinstance(dt, str):
Expand Down
2 changes: 1 addition & 1 deletion celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def save(done=False):
page.dump_state_to_sr(st.session_state | output, sr)

try:
gen = page.run(st.session_state)
gen = page.main(sr, st.session_state)
save()
while True:
# record time
Expand Down
138 changes: 111 additions & 27 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import html
import inspect
import json
import math
import typing
import uuid
Expand All @@ -18,7 +19,7 @@
from fastapi import HTTPException
from firebase_admin import auth
from furl import furl
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sentry_sdk.tracing import (
TRANSACTION_SOURCE_ROUTE,
)
Expand Down Expand Up @@ -53,6 +54,7 @@
from daras_ai_v2.html_spinner_widget import html_spinner
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_preview_url import meta_preview_url
from daras_ai_v2.prompt_vars import variables_input
from daras_ai_v2.query_params import (
gooey_get_query_params,
)
Expand All @@ -64,14 +66,23 @@
from daras_ai_v2.user_date_widgets import (
render_local_dt_attrs,
)
from functions.models import (
RecipeFunction,
FunctionTrigger,
)
from functions.recipe_functions import (
functions_input,
call_recipe_functions,
is_functions_enabled,
render_called_functions,
)
from gooey_ui import (
realtime_clear_subs,
RedirectException,
)
from gooey_ui.components.modal import Modal
from gooey_ui.components.pills import pill
from gooey_ui.pubsub import realtime_pull
from gooeysite.custom_create import get_or_create_lazy
from routers.account import AccountTabs
from routers.root import RecipeTabs

Expand Down Expand Up @@ -117,7 +128,28 @@ class BasePage:

explore_image: str = None

RequestModel: typing.Type[BaseModel]
template_keys: typing.Iterable[str] = (
"task_instructions",
"query_instructions",
"keyword_instructions",
"input_prompt",
"bot_script",
"text_prompt",
"search_query",
"title",
)

class RequestModel(BaseModel):
functions: list[RecipeFunction] | None = Field(
None,
title="🧩 Functions",
)
variables: dict[str, typing.Any] = Field(
None,
title="⌥ Variables",
description="Variables to be used as Jinja prompt templates and in functions as arguments",
)

ResponseModel: typing.Type[BaseModel]

price = settings.CREDITS_TO_DEDUCT_PER_RUN
Expand Down Expand Up @@ -1310,19 +1342,28 @@ def render_run_cost(self):
st.caption(ret, line_clamp=1, unsafe_allow_html=True)

def _render_step_row(self):
with st.expander("**ℹ️ Details**"):
key = "details-expander"
with st.expander("**ℹ️ Details**", key=key):
if not st.session_state.get(key):
return
col1, col2 = st.columns([1, 2])
with col1:
self.render_description()
with col2:
placeholder = st.div()
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre
)
try:
self.render_steps()
except NotImplementedError:
pass
else:
with placeholder:
st.write("##### 👣 Steps")
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.post
)

def _render_help(self):
placeholder = st.div()
Expand All @@ -1338,9 +1379,13 @@ def _render_help(self):
"""
)

key = "discord-expander"
with st.expander(
f"**🙋🏽‍♀️ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**"
f"**🙋🏽‍♀️ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**",
key=key,
):
if not st.session_state.get(key):
return
st.markdown(
"""
<div style="position: relative; padding-bottom: 56.25%; height: 500px; max-width: 500px;">
Expand All @@ -1353,6 +1398,27 @@ def _render_help(self):
def render_usage_guide(self):
raise NotImplementedError

def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
state=state,
trigger=FunctionTrigger.pre,
)

yield from self.run(state)

yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
state=state,
trigger=FunctionTrigger.post,
)

def run(self, state: dict) -> typing.Iterator[str | None]:
# initialize request and response
request = self.RequestModel.parse_obj(state)
Expand All @@ -1373,7 +1439,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
self.ResponseModel.validate(response)

def run_v2(
self, request: BaseModel, response: BaseModel
self, request: RequestModel, response: BaseModel
) -> typing.Iterator[str | None]:
raise NotImplementedError

Expand All @@ -1400,8 +1466,11 @@ def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool):

def _render_input_col(self):
self.render_form_v2()
self.render_variables()

with st.expander("⚙️ Settings"):
self.render_settings()

submitted = self.render_submit_button()
with st.div(style={"textAlign": "right"}):
st.caption(
Expand All @@ -1410,6 +1479,13 @@ def _render_input_col(self):
)
return submitted

def render_variables(self):
st.write("---")
functions_input(self.request.user)
variables_input(
template_keys=self.template_keys, allow_add=is_functions_enabled()
)

@classmethod
def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState:
if state.get(StateKeys.run_status):
Expand Down Expand Up @@ -1506,7 +1582,7 @@ def on_submit(self):
from celeryapp.tasks import auto_recharge

try:
example_id, run_id, uid = self.create_new_run(enable_rate_limits=True)
sr = self.create_new_run(enable_rate_limits=True)
except RateLimitExceeded as e:
st.session_state[StateKeys.run_status] = None
st.session_state[StateKeys.error_msg] = e.detail.get("error", "")
Expand All @@ -1516,15 +1592,15 @@ def on_submit(self):
auto_recharge.delay(user_id=self.request.user.id)

if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits():
st.session_state[StateKeys.run_status] = None
st.session_state[StateKeys.error_msg] = self.generate_credit_error_message(
example_id, run_id, uid
sr.run_status = ""
sr.error_msg = self.generate_credit_error_message(
sr.example_id, sr.run_id, sr.uid
)
self.dump_state_to_sr(st.session_state, self.run_doc_sr(run_id, uid))
sr.save(update_fields=["run_status", "error_msg"])
else:
self.call_runner_task(example_id, run_id, uid)
self.call_runner_task(sr.example_id, sr.run_id, sr.uid)

raise RedirectException(self.app_url(run_id=run_id, uid=uid))
raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid))

def should_submit_after_login(self) -> bool:
return (
Expand All @@ -1534,7 +1610,9 @@ def should_submit_after_login(self) -> bool:
and not self.request.user.is_anonymous
)

def create_new_run(self, *, enable_rate_limits: bool = False, **defaults):
def create_new_run(
self, *, enable_rate_limits: bool = False, **defaults
) -> SavedRun:
st.session_state[StateKeys.run_status] = "Starting..."
st.session_state.pop(StateKeys.error_msg, None)
st.session_state.pop(StateKeys.run_time, None)
Expand Down Expand Up @@ -1574,9 +1652,23 @@ def create_new_run(self, *, enable_rate_limits: bool = False, **defaults):
create=True,
defaults=dict(parent=parent, parent_version=parent_version) | defaults,
)
self.dump_state_to_sr(st.session_state, sr)

return None, run_id, uid
# ensure the request is validated
state = st.session_state | json.loads(
self.RequestModel.parse_obj(st.session_state).json(exclude_unset=True)
)
self.dump_state_to_sr(state, sr)

return sr

def dump_state_to_sr(self, state: dict, sr: SavedRun):
sr.set(
{
field_name: deepcopy(state[field_name])
for field_name in self.fields_to_save()
if field_name in state
}
)

def call_runner_task(self, example_id, run_id, uid, is_api_call=False):
from celeryapp.tasks import gui_runner
Expand Down Expand Up @@ -1673,15 +1765,6 @@ def load_state_defaults(cls, state: dict):
state.setdefault(k, v)
return state

def dump_state_to_sr(self, state: dict, sr: SavedRun):
sr.set(
{
field_name: deepcopy(state[field_name])
for field_name in self.fields_to_save()
if field_name in state
}
)

def fields_to_save(self) -> [str]:
# only save the fields in request/response
return [
Expand Down Expand Up @@ -2058,7 +2141,8 @@ def get_example_response_body(
run_id=run_id,
uid=self.request.user and self.request.user.uid,
)
output = extract_model_fields(self.ResponseModel, state, include_all=True)
sr = self.get_current_sr()
output = sr.api_output(extract_model_fields(self.ResponseModel, state))
if as_async:
return dict(
run_id=run_id,
Expand Down Expand Up @@ -2146,7 +2230,7 @@ def render_output_caption():
def extract_model_fields(
model: typing.Type[BaseModel],
state: dict,
include_all: bool = False,
include_all: bool = True,
preferred_fields: list[str] = None,
diff_from: dict | None = None,
) -> dict:
Expand Down
31 changes: 31 additions & 0 deletions daras_ai_v2/custom_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import typing
from enum import Enum

import typing_extensions

T = typing.TypeVar("T", bound="GooeyEnum")


class GooeyEnum(Enum):
@classmethod
def db_choices(cls):
return [(e.db_value, e.label) for e in cls]

@classmethod
def from_db(cls, db_value) -> typing_extensions.Self:
for e in cls:
if e.db_value == db_value:
return e
raise ValueError(f"Invalid {cls.__name__} {db_value=}")

@classmethod
@property
def api_choices(cls):
return typing.Literal[tuple(e.name for e in cls)]

@classmethod
def from_api(cls, name: str) -> typing_extensions.Self:
for e in cls:
if e.name == name:
return e
raise ValueError(f"Invalid {cls.__name__} {name=}")
Loading

0 comments on commit 05db35c

Please sign in to comment.