Skip to content

Commit

Permalink
save function calls to db and show in steps
Browse files Browse the repository at this point in the history
add pre/post functions to all recipes
enable variables in functions
lazy load details and steps
  • Loading branch information
devxpy committed Jul 5, 2024
1 parent 91772e6 commit d20e058
Show file tree
Hide file tree
Showing 30 changed files with 664 additions and 192 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.7 on 2024-07-05 13:44

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('bots', '0075_alter_publishedrun_workflow_alter_savedrun_workflow_and_more'),
]

operations = [
migrations.AlterField(
model_name='workflowmetadata',
name='default_image',
field=models.URLField(blank=True, default='', help_text='Image shown on explore page'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='help_url',
field=models.URLField(blank=True, default='', help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='meta_keywords',
field=models.JSONField(blank=True, default=list, help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='short_title',
field=models.TextField(help_text='Title used in breadcrumbs'),
),
]
10 changes: 10 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from bots.custom_fields import PostgresJSONEncoder, CustomURLField
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
from functions.models import CalledFunction, CalledFunctionResponse
from gooeysite.custom_create import get_or_create_lazy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -370,6 +371,15 @@ def get_creator(self) -> AppUser | None:
def open_in_gooey(self):
return open_in_new_tab(self.get_app_url(), label=self.get_app_url())

def api_output(self, state: dict = None) -> dict:
state = state or self.state
if self.state.get("functions"):
state["called_functions"] = [
CalledFunctionResponse.from_db(called_fn)
for called_fn in self.called_functions.all()
]
return state


def _parse_dt(dt) -> datetime.datetime | None:
if isinstance(dt, str):
Expand Down
12 changes: 4 additions & 8 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from daras_ai_v2.auto_recharge import auto_recharge_user
from daras_ai_v2.base import StateKeys, BasePage
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.fastapi_tricks import extract_model_fields
from daras_ai_v2.redis_cache import redis_lock
from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email
from daras_ai_v2.settings import templates
Expand Down Expand Up @@ -82,13 +83,8 @@ def save(done=False):
}
output = (
status
|
# extract outputs from local state
{
k: v
for k, v in st.session_state.items()
if k in page.ResponseModel.__fields__
}
# extract outputs from session state
| extract_model_fields(page.ResponseModel, st.session_state)
| extra_output
)
# send outputs to ui
Expand All @@ -97,7 +93,7 @@ def save(done=False):
page.dump_state_to_sr(st.session_state | output, sr)

try:
gen = page.run(st.session_state)
gen = page.main(sr, st.session_state)
save()
while True:
# record time
Expand Down
119 changes: 85 additions & 34 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
from copy import deepcopy, copy
from enum import Enum
from functools import cached_property
from itertools import pairwise
from random import Random
from time import sleep
Expand All @@ -18,7 +19,7 @@
from fastapi import HTTPException
from firebase_admin import auth
from furl import furl
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sentry_sdk.tracing import (
TRANSACTION_SOURCE_ROUTE,
)
Expand Down Expand Up @@ -48,11 +49,12 @@
from daras_ai_v2.db import (
ANONYMOUS_USER_COOKIE,
)
from daras_ai_v2.fastapi_tricks import get_route_path
from daras_ai_v2.fastapi_tricks import get_route_path, extract_model_fields
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.html_spinner_widget import html_spinner
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_preview_url import meta_preview_url
from daras_ai_v2.prompt_vars import variables_input
from daras_ai_v2.query_params import (
gooey_get_query_params,
)
Expand All @@ -64,14 +66,23 @@
from daras_ai_v2.user_date_widgets import (
render_local_dt_attrs,
)
from functions.models import (
RecipeFunction,
FunctionTrigger,
)
from functions.recipe_functions import (
functions_input,
call_recipe_functions,
is_functions_enabled,
render_called_functions,
)
from gooey_ui import (
realtime_clear_subs,
RedirectException,
)
from gooey_ui.components.modal import Modal
from gooey_ui.components.pills import pill
from gooey_ui.pubsub import realtime_pull
from gooeysite.custom_create import get_or_create_lazy
from routers.account import AccountTabs
from routers.root import RecipeTabs

Expand Down Expand Up @@ -117,7 +128,28 @@ class BasePage:

explore_image: str = None

RequestModel: typing.Type[BaseModel]
template_keys: typing.Iterable[str] = (
"task_instructions",
"query_instructions",
"keyword_instructions",
"input_prompt",
"bot_script",
"text_prompt",
"search_query",
"title",
)

class RequestModel(BaseModel):
functions: list[RecipeFunction] | None = Field(
None,
title="🧩 Functions",
)
variables: dict[str, typing.Any] = Field(
None,
title="⌥ Variables",
description="Variables to be used as Jinja prompt templates and in functions as arguments",
)

ResponseModel: typing.Type[BaseModel]

price = settings.CREDITS_TO_DEDUCT_PER_RUN
Expand Down Expand Up @@ -1310,19 +1342,28 @@ def render_run_cost(self):
st.caption(ret, line_clamp=1, unsafe_allow_html=True)

def _render_step_row(self):
with st.expander("**ℹ️ Details**"):
key = "details-expander"
with st.expander("**ℹ️ Details**", key=key):
if not st.session_state.get(key):
return
col1, col2 = st.columns([1, 2])
with col1:
self.render_description()
with col2:
placeholder = st.div()
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre
)
try:
self.render_steps()
except NotImplementedError:
pass
else:
with placeholder:
st.write("##### 👣 Steps")
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.post
)

def _render_help(self):
placeholder = st.div()
Expand All @@ -1338,9 +1379,13 @@ def _render_help(self):
"""
)

key = "discord-expander"
with st.expander(
f"**🙋🏽‍♀️ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**"
f"**🙋🏽‍♀️ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**",
key=key,
):
if not st.session_state.get(key):
return
st.markdown(
"""
<div style="position: relative; padding-bottom: 56.25%; height: 500px; max-width: 500px;">
Expand All @@ -1353,6 +1398,27 @@ def _render_help(self):
def render_usage_guide(self):
raise NotImplementedError

def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
state=state,
trigger=FunctionTrigger.pre,
)

yield from self.run(state)

yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
state=state,
trigger=FunctionTrigger.post,
)

def run(self, state: dict) -> typing.Iterator[str | None]:
# initialize request and response
request = self.RequestModel.parse_obj(state)
Expand All @@ -1373,7 +1439,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
self.ResponseModel.validate(response)

def run_v2(
self, request: BaseModel, response: BaseModel
self, request: RequestModel, response: BaseModel
) -> typing.Iterator[str | None]:
raise NotImplementedError

Expand All @@ -1400,8 +1466,11 @@ def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool):

def _render_input_col(self):
self.render_form_v2()
self.render_variables()

with st.expander("⚙️ Settings"):
self.render_settings()

submitted = self.render_submit_button()
with st.div(style={"textAlign": "right"}):
st.caption(
Expand All @@ -1410,6 +1479,13 @@ def _render_input_col(self):
)
return submitted

def render_variables(self):
st.write("---")
functions_input(self.request.user)
variables_input(
template_keys=self.template_keys, allow_add=is_functions_enabled()
)

@classmethod
def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState:
if state.get(StateKeys.run_status):
Expand Down Expand Up @@ -2058,7 +2134,8 @@ def get_example_response_body(
run_id=run_id,
uid=self.request.user and self.request.user.uid,
)
output = extract_model_fields(self.ResponseModel, state, include_all=True)
sr = self.get_current_sr()
output = sr.api_output(extract_model_fields(self.ResponseModel, state))
if as_async:
return dict(
run_id=run_id,
Expand Down Expand Up @@ -2143,32 +2220,6 @@ def render_output_caption():
)


def extract_model_fields(
model: typing.Type[BaseModel],
state: dict,
include_all: bool = False,
preferred_fields: list[str] = None,
diff_from: dict | None = None,
) -> dict:
"""
Include a field in result if:
- include_all is true
- field is required
- field is preferred
- diff_from is provided and field value differs from diff_from
"""
return {
field_name: state.get(field_name)
for field_name, field in model.__fields__.items()
if (
include_all
or field.required
or (preferred_fields and field_name in preferred_fields)
or (diff_from and state.get(field_name) != diff_from.get(field_name))
)
}


def extract_nested_str(obj) -> str:
if isinstance(obj, str):
return obj
Expand Down
31 changes: 31 additions & 0 deletions daras_ai_v2/custom_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import typing
from enum import Enum

import typing_extensions

T = typing.TypeVar("T", bound="GooeyEnum")


class GooeyEnum(Enum):
@classmethod
def db_choices(cls):
return [(e.db_value, e.label) for e in cls]

@classmethod
def from_db(cls, db_value) -> typing_extensions.Self:
for e in cls:
if e.db_value == db_value:
return e
raise ValueError(f"Invalid {cls.__name__} {db_value=}")

@classmethod
@property
def api_choices(cls):
return typing.Literal[tuple(e.name for e in cls)]

@classmethod
def from_api(cls, name: str) -> typing_extensions.Self:
for e in cls:
if e.name == name:
return e
raise ValueError(f"Invalid {cls.__name__} {name=}")
8 changes: 1 addition & 7 deletions daras_ai_v2/doc_search_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from daras_ai_v2.embedding_model import EmbeddingModels
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder
from daras_ai_v2.prompt_vars import prompt_vars_widget
from daras_ai_v2.prompt_vars import variables_input
from daras_ai_v2.search_ref import CitationStyles

_user_media_url_prefix = os.path.join(
Expand Down Expand Up @@ -110,9 +110,6 @@ def query_instructions_widget():
key="query_instructions",
height=300,
)
prompt_vars_widget(
"query_instructions",
)


def keyword_instructions_widget():
Expand All @@ -124,9 +121,6 @@ def keyword_instructions_widget():
key="keyword_instructions",
height=300,
)
prompt_vars_widget(
"keyword_instructions",
)


def doc_extract_selector(current_user: AppUser | None):
Expand Down
Loading

0 comments on commit d20e058

Please sign in to comment.