Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pre/post function triggers #398

Merged
merged 3 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,18 @@ def api_integration_stats_url(self, bi: BotIntegration):
)


@admin.register(PublishedRunVersion)
class PublishedRunVersionAdmin(admin.ModelAdmin):
search_fields = ["id", "version_id", "published_run__published_run_id"]
autocomplete_fields = ["published_run", "saved_run", "changed_by"]


class PublishedRunVersionInline(admin.TabularInline):
model = PublishedRunVersion
extra = 0
autocomplete_fields = PublishedRunVersionAdmin.autocomplete_fields


@admin.register(PublishedRun)
class PublishedRunAdmin(admin.ModelAdmin):
list_display = [
Expand All @@ -290,6 +302,7 @@ class PublishedRunAdmin(admin.ModelAdmin):
"created_at",
"updated_at",
]
inlines = [PublishedRunVersionInline]

def view_user(self, published_run: PublishedRun):
if published_run.created_by is None:
Expand Down Expand Up @@ -425,12 +438,6 @@ def rerun_tasks(self, request, queryset):
)


@admin.register(PublishedRunVersion)
class PublishedRunVersionAdmin(admin.ModelAdmin):
search_fields = ["id", "version_id", "published_run__published_run_id"]
autocomplete_fields = ["published_run", "saved_run", "changed_by"]


class LastActiveDeltaFilter(admin.SimpleListFilter):
title = Conversation.last_active_delta.short_description
parameter_name = Conversation.last_active_delta.__name__
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.7 on 2024-07-05 13:44

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('bots', '0075_alter_publishedrun_workflow_alter_savedrun_workflow_and_more'),
]

operations = [
migrations.AlterField(
model_name='workflowmetadata',
name='default_image',
field=models.URLField(blank=True, default='', help_text='Image shown on explore page'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='help_url',
field=models.URLField(blank=True, default='', help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='meta_keywords',
field=models.JSONField(blank=True, default=list, help_text='(Not implemented)'),
),
migrations.AlterField(
model_name='workflowmetadata',
name='short_title',
field=models.TextField(help_text='Title used in breadcrumbs'),
),
]
10 changes: 10 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from bots.custom_fields import PostgresJSONEncoder, CustomURLField
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
from functions.models import CalledFunction, CalledFunctionResponse
from gooeysite.custom_create import get_or_create_lazy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -370,6 +371,15 @@ def get_creator(self) -> AppUser | None:
def open_in_gooey(self):
return open_in_new_tab(self.get_app_url(), label=self.get_app_url())

def api_output(self, state: dict = None) -> dict:
state = state or self.state
if self.state.get("functions"):
state["called_functions"] = [
CalledFunctionResponse.from_db(called_fn)
for called_fn in self.called_functions.all()
]
return state


def _parse_dt(dt) -> datetime.datetime | None:
if isinstance(dt, str):
Expand Down
12 changes: 4 additions & 8 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from daras_ai_v2.auto_recharge import auto_recharge_user
from daras_ai_v2.base import StateKeys, BasePage
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.fastapi_tricks import extract_model_fields
from daras_ai_v2.redis_cache import redis_lock
from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email
from daras_ai_v2.settings import templates
Expand Down Expand Up @@ -82,13 +83,8 @@ def save(done=False):
}
output = (
status
|
# extract outputs from local state
{
k: v
for k, v in st.session_state.items()
if k in page.ResponseModel.__fields__
}
# extract outputs from session state
| extract_model_fields(page.ResponseModel, st.session_state)
| extra_output
)
# send outputs to ui
Expand All @@ -97,7 +93,7 @@ def save(done=False):
page.dump_state_to_sr(st.session_state | output, sr)

try:
gen = page.run(st.session_state)
gen = page.main(sr, st.session_state)
save()
while True:
# record time
Expand Down
119 changes: 85 additions & 34 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
from copy import deepcopy, copy
from enum import Enum
from functools import cached_property
from itertools import pairwise
from random import Random
from time import sleep
Expand All @@ -18,7 +19,7 @@
from fastapi import HTTPException
from firebase_admin import auth
from furl import furl
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sentry_sdk.tracing import (
TRANSACTION_SOURCE_ROUTE,
)
Expand Down Expand Up @@ -48,11 +49,12 @@
from daras_ai_v2.db import (
ANONYMOUS_USER_COOKIE,
)
from daras_ai_v2.fastapi_tricks import get_route_path
from daras_ai_v2.fastapi_tricks import get_route_path, extract_model_fields
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.html_spinner_widget import html_spinner
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_preview_url import meta_preview_url
from daras_ai_v2.prompt_vars import variables_input
from daras_ai_v2.query_params import (
gooey_get_query_params,
)
Expand All @@ -64,14 +66,23 @@
from daras_ai_v2.user_date_widgets import (
render_local_dt_attrs,
)
from functions.models import (
RecipeFunction,
FunctionTrigger,
)
from functions.recipe_functions import (
functions_input,
call_recipe_functions,
is_functions_enabled,
render_called_functions,
)
from gooey_ui import (
realtime_clear_subs,
RedirectException,
)
from gooey_ui.components.modal import Modal
from gooey_ui.components.pills import pill
from gooey_ui.pubsub import realtime_pull
from gooeysite.custom_create import get_or_create_lazy
from routers.account import AccountTabs
from routers.root import RecipeTabs

Expand Down Expand Up @@ -117,7 +128,28 @@ class BasePage:

explore_image: str = None

RequestModel: typing.Type[BaseModel]
template_keys: typing.Iterable[str] = (
"task_instructions",
"query_instructions",
"keyword_instructions",
"input_prompt",
"bot_script",
"text_prompt",
"search_query",
"title",
)

class RequestModel(BaseModel):
functions: list[RecipeFunction] | None = Field(
None,
title="🧩 Functions",
)
variables: dict[str, typing.Any] = Field(
None,
title="⌥ Variables",
description="Variables to be used as Jinja prompt templates and in functions as arguments",
)

ResponseModel: typing.Type[BaseModel]

price = settings.CREDITS_TO_DEDUCT_PER_RUN
Expand Down Expand Up @@ -1310,19 +1342,28 @@ def render_run_cost(self):
st.caption(ret, line_clamp=1, unsafe_allow_html=True)

def _render_step_row(self):
with st.expander("**ℹ️ Details**"):
key = "details-expander"
with st.expander("**ℹ️ Details**", key=key):
if not st.session_state.get(key):
return
col1, col2 = st.columns([1, 2])
with col1:
self.render_description()
with col2:
placeholder = st.div()
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre
)
try:
self.render_steps()
except NotImplementedError:
pass
else:
with placeholder:
st.write("##### 👣 Steps")
render_called_functions(
saved_run=self.get_current_sr(), trigger=FunctionTrigger.post
)

def _render_help(self):
placeholder = st.div()
Expand All @@ -1338,9 +1379,13 @@ def _render_help(self):
"""
)

key = "discord-expander"
with st.expander(
f"**🙋🏽‍♀️ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**"
f"**🙋🏽‍♀️ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**",
key=key,
):
if not st.session_state.get(key):
return
st.markdown(
"""
<div style="position: relative; padding-bottom: 56.25%; height: 500px; max-width: 500px;">
Expand All @@ -1353,6 +1398,27 @@ def _render_help(self):
def render_usage_guide(self):
raise NotImplementedError

def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
state=state,
trigger=FunctionTrigger.pre,
)

yield from self.run(state)

yield from call_recipe_functions(
saved_run=sr,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
state=state,
trigger=FunctionTrigger.post,
)

def run(self, state: dict) -> typing.Iterator[str | None]:
# initialize request and response
request = self.RequestModel.parse_obj(state)
Expand All @@ -1373,7 +1439,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
self.ResponseModel.validate(response)

def run_v2(
self, request: BaseModel, response: BaseModel
self, request: RequestModel, response: BaseModel
) -> typing.Iterator[str | None]:
raise NotImplementedError

Expand All @@ -1400,8 +1466,11 @@ def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool):

def _render_input_col(self):
self.render_form_v2()
self.render_variables()

with st.expander("⚙️ Settings"):
self.render_settings()

submitted = self.render_submit_button()
with st.div(style={"textAlign": "right"}):
st.caption(
Expand All @@ -1410,6 +1479,13 @@ def _render_input_col(self):
)
return submitted

def render_variables(self):
st.write("---")
functions_input(self.request.user)
variables_input(
template_keys=self.template_keys, allow_add=is_functions_enabled()
)

@classmethod
def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState:
if state.get(StateKeys.run_status):
Expand Down Expand Up @@ -2058,7 +2134,8 @@ def get_example_response_body(
run_id=run_id,
uid=self.request.user and self.request.user.uid,
)
output = extract_model_fields(self.ResponseModel, state, include_all=True)
sr = self.get_current_sr()
output = sr.api_output(extract_model_fields(self.ResponseModel, state))
if as_async:
return dict(
run_id=run_id,
Expand Down Expand Up @@ -2143,32 +2220,6 @@ def render_output_caption():
)


def extract_model_fields(
model: typing.Type[BaseModel],
state: dict,
include_all: bool = False,
preferred_fields: list[str] = None,
diff_from: dict | None = None,
) -> dict:
"""
Include a field in result if:
- include_all is true
- field is required
- field is preferred
- diff_from is provided and field value differs from diff_from
"""
return {
field_name: state.get(field_name)
for field_name, field in model.__fields__.items()
if (
include_all
or field.required
or (preferred_fields and field_name in preferred_fields)
or (diff_from and state.get(field_name) != diff_from.get(field_name))
)
}


def extract_nested_str(obj) -> str:
if isinstance(obj, str):
return obj
Expand Down
31 changes: 31 additions & 0 deletions daras_ai_v2/custom_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import typing
from enum import Enum

import typing_extensions

T = typing.TypeVar("T", bound="GooeyEnum")


class GooeyEnum(Enum):
@classmethod
def db_choices(cls):
return [(e.db_value, e.label) for e in cls]

@classmethod
def from_db(cls, db_value) -> typing_extensions.Self:
for e in cls:
if e.db_value == db_value:
return e
raise ValueError(f"Invalid {cls.__name__} {db_value=}")

@classmethod
@property
def api_choices(cls):
return typing.Literal[tuple(e.name for e in cls)]

@classmethod
def from_api(cls, name: str) -> typing_extensions.Self:
for e in cls:
if e.name == name:
return e
raise ValueError(f"Invalid {cls.__name__} {name=}")
Loading