diff --git a/bots/migrations/0076_alter_workflowmetadata_default_image_and_more.py b/bots/migrations/0076_alter_workflowmetadata_default_image_and_more.py
new file mode 100644
index 000000000..8be0bb4ed
--- /dev/null
+++ b/bots/migrations/0076_alter_workflowmetadata_default_image_and_more.py
@@ -0,0 +1,33 @@
+# Generated by Django 4.2.7 on 2024-07-05 13:44
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('bots', '0075_alter_publishedrun_workflow_alter_savedrun_workflow_and_more'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='workflowmetadata',
+ name='default_image',
+ field=models.URLField(blank=True, default='', help_text='Image shown on explore page'),
+ ),
+ migrations.AlterField(
+ model_name='workflowmetadata',
+ name='help_url',
+ field=models.URLField(blank=True, default='', help_text='(Not implemented)'),
+ ),
+ migrations.AlterField(
+ model_name='workflowmetadata',
+ name='meta_keywords',
+ field=models.JSONField(blank=True, default=list, help_text='(Not implemented)'),
+ ),
+ migrations.AlterField(
+ model_name='workflowmetadata',
+ name='short_title',
+ field=models.TextField(help_text='Title used in breadcrumbs'),
+ ),
+ ]
diff --git a/bots/models.py b/bots/models.py
index fbb124726..a2c5715a1 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -17,6 +17,7 @@
from bots.custom_fields import PostgresJSONEncoder, CustomURLField
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
+from functions.models import CalledFunction, CalledFunctionResponse
from gooeysite.custom_create import get_or_create_lazy
if typing.TYPE_CHECKING:
@@ -370,6 +371,15 @@ def get_creator(self) -> AppUser | None:
def open_in_gooey(self):
return open_in_new_tab(self.get_app_url(), label=self.get_app_url())
+ def api_output(self, state: dict = None) -> dict:
+ state = state or self.state
+ if self.state.get("functions"):
+ state["called_functions"] = [
+ CalledFunctionResponse.from_db(called_fn)
+ for called_fn in self.called_functions.all()
+ ]
+ return state
+
def _parse_dt(dt) -> datetime.datetime | None:
if isinstance(dt, str):
diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py
index f535b6ff5..dbdb356b3 100644
--- a/celeryapp/tasks.py
+++ b/celeryapp/tasks.py
@@ -21,6 +21,7 @@
from daras_ai_v2.auto_recharge import auto_recharge_user
from daras_ai_v2.base import StateKeys, BasePage
from daras_ai_v2.exceptions import UserError
+from daras_ai_v2.fastapi_tricks import extract_model_fields
from daras_ai_v2.redis_cache import redis_lock
from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email
from daras_ai_v2.settings import templates
@@ -82,13 +83,8 @@ def save(done=False):
}
output = (
status
- |
- # extract outputs from local state
- {
- k: v
- for k, v in st.session_state.items()
- if k in page.ResponseModel.__fields__
- }
+ # extract outputs from session state
+ | extract_model_fields(page.ResponseModel, st.session_state)
| extra_output
)
# send outputs to ui
@@ -97,7 +93,7 @@ def save(done=False):
page.dump_state_to_sr(st.session_state | output, sr)
try:
- gen = page.run(st.session_state)
+ gen = page.main(sr, st.session_state)
save()
while True:
# record time
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 6b641314d..cde7ecbe0 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -6,6 +6,7 @@
import uuid
from copy import deepcopy, copy
from enum import Enum
+from functools import cached_property
from itertools import pairwise
from random import Random
from time import sleep
@@ -18,7 +19,7 @@
from fastapi import HTTPException
from firebase_admin import auth
from furl import furl
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sentry_sdk.tracing import (
TRANSACTION_SOURCE_ROUTE,
)
@@ -48,11 +49,12 @@
from daras_ai_v2.db import (
ANONYMOUS_USER_COOKIE,
)
-from daras_ai_v2.fastapi_tricks import get_route_path
+from daras_ai_v2.fastapi_tricks import get_route_path, extract_model_fields
from daras_ai_v2.grid_layout_widget import grid_layout
from daras_ai_v2.html_spinner_widget import html_spinner
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_preview_url import meta_preview_url
+from daras_ai_v2.prompt_vars import variables_input
from daras_ai_v2.query_params import (
gooey_get_query_params,
)
@@ -64,6 +66,16 @@
from daras_ai_v2.user_date_widgets import (
render_local_dt_attrs,
)
+from functions.models import (
+ RecipeFunction,
+ FunctionTrigger,
+)
+from functions.recipe_functions import (
+ functions_input,
+ call_recipe_functions,
+ is_functions_enabled,
+ render_called_functions,
+)
from gooey_ui import (
realtime_clear_subs,
RedirectException,
@@ -71,7 +83,6 @@
from gooey_ui.components.modal import Modal
from gooey_ui.components.pills import pill
from gooey_ui.pubsub import realtime_pull
-from gooeysite.custom_create import get_or_create_lazy
from routers.account import AccountTabs
from routers.root import RecipeTabs
@@ -117,7 +128,28 @@ class BasePage:
explore_image: str = None
- RequestModel: typing.Type[BaseModel]
+ template_keys: typing.Iterable[str] = (
+ "task_instructions",
+ "query_instructions",
+ "keyword_instructions",
+ "input_prompt",
+ "bot_script",
+ "text_prompt",
+ "search_query",
+ "title",
+ )
+
+ class RequestModel(BaseModel):
+ functions: list[RecipeFunction] | None = Field(
+ None,
+ title="𧊠Functions",
+ )
+ variables: dict[str, typing.Any] = Field(
+ None,
+ title="âĨ Variables",
+ description="Variables to be used as Jinja prompt templates and in functions as arguments",
+ )
+
ResponseModel: typing.Type[BaseModel]
price = settings.CREDITS_TO_DEDUCT_PER_RUN
@@ -1310,12 +1342,18 @@ def render_run_cost(self):
st.caption(ret, line_clamp=1, unsafe_allow_html=True)
def _render_step_row(self):
- with st.expander("**âšī¸ Details**"):
+ key = "details-expander"
+ with st.expander("**âšī¸ Details**", key=key):
+ if not st.session_state.get(key):
+ return
col1, col2 = st.columns([1, 2])
with col1:
self.render_description()
with col2:
placeholder = st.div()
+ render_called_functions(
+ saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre
+ )
try:
self.render_steps()
except NotImplementedError:
@@ -1323,6 +1361,9 @@ def _render_step_row(self):
else:
with placeholder:
st.write("##### đŖ Steps")
+ render_called_functions(
+ saved_run=self.get_current_sr(), trigger=FunctionTrigger.post
+ )
def _render_help(self):
placeholder = st.div()
@@ -1338,9 +1379,13 @@ def _render_help(self):
"""
)
+ key = "discord-expander"
with st.expander(
- f"**đđŊââī¸ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**"
+ f"**đđŊââī¸ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**",
+ key=key,
):
+ if not st.session_state.get(key):
+ return
st.markdown(
"""
@@ -1353,6 +1398,27 @@ def _render_help(self):
def render_usage_guide(self):
raise NotImplementedError
+ def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]:
+ yield from call_recipe_functions(
+ saved_run=sr,
+ current_user=self.request.user,
+ request_model=self.RequestModel,
+ response_model=self.ResponseModel,
+ state=state,
+ trigger=FunctionTrigger.pre,
+ )
+
+ yield from self.run(state)
+
+ yield from call_recipe_functions(
+ saved_run=sr,
+ current_user=self.request.user,
+ request_model=self.RequestModel,
+ response_model=self.ResponseModel,
+ state=state,
+ trigger=FunctionTrigger.post,
+ )
+
def run(self, state: dict) -> typing.Iterator[str | None]:
# initialize request and response
request = self.RequestModel.parse_obj(state)
@@ -1373,7 +1439,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
self.ResponseModel.validate(response)
def run_v2(
- self, request: BaseModel, response: BaseModel
+ self, request: RequestModel, response: BaseModel
) -> typing.Iterator[str | None]:
raise NotImplementedError
@@ -1400,8 +1466,11 @@ def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool):
def _render_input_col(self):
self.render_form_v2()
+ self.render_variables()
+
with st.expander("âī¸ Settings"):
self.render_settings()
+
submitted = self.render_submit_button()
with st.div(style={"textAlign": "right"}):
st.caption(
@@ -1410,6 +1479,13 @@ def _render_input_col(self):
)
return submitted
+ def render_variables(self):
+ st.write("---")
+ functions_input(self.request.user)
+ variables_input(
+ template_keys=self.template_keys, allow_add=is_functions_enabled()
+ )
+
@classmethod
def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState:
if state.get(StateKeys.run_status):
@@ -2058,7 +2134,8 @@ def get_example_response_body(
run_id=run_id,
uid=self.request.user and self.request.user.uid,
)
- output = extract_model_fields(self.ResponseModel, state, include_all=True)
+ sr = self.get_current_sr()
+ output = sr.api_output(extract_model_fields(self.ResponseModel, state))
if as_async:
return dict(
run_id=run_id,
@@ -2143,32 +2220,6 @@ def render_output_caption():
)
-def extract_model_fields(
- model: typing.Type[BaseModel],
- state: dict,
- include_all: bool = False,
- preferred_fields: list[str] = None,
- diff_from: dict | None = None,
-) -> dict:
- """
- Include a field in result if:
- - include_all is true
- - field is required
- - field is preferred
- - diff_from is provided and field value differs from diff_from
- """
- return {
- field_name: state.get(field_name)
- for field_name, field in model.__fields__.items()
- if (
- include_all
- or field.required
- or (preferred_fields and field_name in preferred_fields)
- or (diff_from and state.get(field_name) != diff_from.get(field_name))
- )
- }
-
-
def extract_nested_str(obj) -> str:
if isinstance(obj, str):
return obj
diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py
new file mode 100644
index 000000000..b9aacb843
--- /dev/null
+++ b/daras_ai_v2/custom_enum.py
@@ -0,0 +1,31 @@
+import typing
+from enum import Enum
+
+import typing_extensions
+
+T = typing.TypeVar("T", bound="GooeyEnum")
+
+
+class GooeyEnum(Enum):
+ @classmethod
+ def db_choices(cls):
+ return [(e.db_value, e.label) for e in cls]
+
+ @classmethod
+ def from_db(cls, db_value) -> typing_extensions.Self:
+ for e in cls:
+ if e.db_value == db_value:
+ return e
+ raise ValueError(f"Invalid {cls.__name__} {db_value=}")
+
+ @classmethod
+ @property
+ def api_choices(cls):
+ return typing.Literal[tuple(e.name for e in cls)]
+
+ @classmethod
+ def from_api(cls, name: str) -> typing_extensions.Self:
+ for e in cls:
+ if e.name == name:
+ return e
+ raise ValueError(f"Invalid {cls.__name__} {name=}")
diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py
index 00d6c9983..10ae86a7d 100644
--- a/daras_ai_v2/doc_search_settings_widgets.py
+++ b/daras_ai_v2/doc_search_settings_widgets.py
@@ -10,7 +10,7 @@
from daras_ai_v2.embedding_model import EmbeddingModels
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder
-from daras_ai_v2.prompt_vars import prompt_vars_widget
+from daras_ai_v2.prompt_vars import variables_input
from daras_ai_v2.search_ref import CitationStyles
_user_media_url_prefix = os.path.join(
@@ -110,9 +110,6 @@ def query_instructions_widget():
key="query_instructions",
height=300,
)
- prompt_vars_widget(
- "query_instructions",
- )
def keyword_instructions_widget():
@@ -124,9 +121,6 @@ def keyword_instructions_widget():
key="keyword_instructions",
height=300,
)
- prompt_vars_widget(
- "keyword_instructions",
- )
def doc_extract_selector(current_user: AppUser | None):
diff --git a/daras_ai_v2/fastapi_tricks.py b/daras_ai_v2/fastapi_tricks.py
index 494393ab2..3d080e5be 100644
--- a/daras_ai_v2/fastapi_tricks.py
+++ b/daras_ai_v2/fastapi_tricks.py
@@ -5,6 +5,7 @@
from fastapi import Depends
from fastapi.routing import APIRoute
from furl import furl
+from pydantic import BaseModel
from starlette.requests import Request
from daras_ai_v2 import settings
@@ -64,3 +65,29 @@ def get_route_path(route_fn: typing.Callable, params: dict = None) -> str:
from server import app
return os.path.join(app.url_path_for(route_fn.__name__, **(params or {})), "")
+
+
+def extract_model_fields(
+ model: typing.Type[BaseModel],
+ state: dict,
+ include_all: bool = True,
+ preferred_fields: list[str] = None,
+ diff_from: dict | None = None,
+) -> dict:
+ """
+ Include a field in result if:
+ - include_all is true
+ - field is required
+ - field is preferred
+ - diff_from is provided and field value differs from diff_from
+ """
+ return {
+ field_name: state.get(field_name)
+ for field_name, field in model.__fields__.items()
+ if (
+ include_all
+ or field.required
+ or (preferred_fields and field_name in preferred_fields)
+ or (diff_from and state.get(field_name) != diff_from.get(field_name))
+ )
+ }
diff --git a/daras_ai_v2/functions.py b/daras_ai_v2/functions.py
deleted file mode 100644
index 2c6c03348..000000000
--- a/daras_ai_v2/functions.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import json
-import tempfile
-import typing
-from enum import Enum
-
-from daras_ai.image_input import upload_file_from_bytes
-from daras_ai_v2.settings import templates
-
-
-def json_to_pdf(filename: str, data: str) -> str:
- html = templates.get_template("form_output.html").render(data=json.loads(data))
- pdf_bytes = html_to_pdf(html)
- if not filename.endswith(".pdf"):
- filename += ".pdf"
- return upload_file_from_bytes(filename, pdf_bytes, "application/pdf")
-
-
-def html_to_pdf(html: str) -> bytes:
- from playwright.sync_api import sync_playwright
-
- with sync_playwright() as p:
- browser = p.chromium.launch()
- page = browser.new_page()
- page.set_content(html)
- with tempfile.NamedTemporaryFile(suffix=".pdf") as outfile:
- page.pdf(path=outfile.name, format="A4")
- ret = outfile.read()
- browser.close()
-
- return ret
-
-
-class LLMTools(Enum):
- json_to_pdf = (
- json_to_pdf,
- "Save JSON as PDF",
- {
- "type": "function",
- "function": {
- "name": json_to_pdf.__name__,
- "description": "Save JSON data to PDF",
- "parameters": {
- "type": "object",
- "properties": {
- "filename": {
- "type": "string",
- "description": "A short but descriptive filename for the PDF",
- },
- "data": {
- "type": "string",
- "description": "The JSON data to write to the PDF",
- },
- },
- "required": ["filename", "data"],
- },
- },
- },
- )
- # send_reply_buttons = (print, "Send back reply buttons to the user.", {})
-
- def __new__(cls, fn: typing.Callable, label: str, spec: dict):
- obj = object.__new__(cls)
- obj._value_ = fn.__name__
- obj.fn = fn
- obj.label = label
- obj.spec = spec
- return obj
-
- # def __init__(self, *args, **kwargs):
- # self._value_ = self.name
diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py
index 0679f21ac..4ab64fe4e 100644
--- a/daras_ai_v2/language_model.py
+++ b/daras_ai_v2/language_model.py
@@ -24,7 +24,7 @@
from daras_ai.image_input import gs_url_to_uri, bytes_to_cv2_img, cv2_img_to_bytes
from daras_ai_v2.asr import get_google_auth_session
from daras_ai_v2.exceptions import raise_for_status, UserError
-from daras_ai_v2.functions import LLMTools
+from functions.recipe_functions import LLMTools
from daras_ai_v2.gpu_server import call_celery_task
from daras_ai_v2.text_splitter import (
default_length_function,
diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py
index b6c2b4795..4ab9878a2 100644
--- a/daras_ai_v2/prompt_vars.py
+++ b/daras_ai_v2/prompt_vars.py
@@ -1,3 +1,5 @@
+import json
+import typing
from datetime import datetime
from types import SimpleNamespace
@@ -8,36 +10,105 @@
import gooey_ui as st
-def prompt_vars_widget(*keys: str, variables_key: str = "variables"):
+def variables_input(
+ *,
+ template_keys: typing.Iterable[str],
+ label: str = "###### âĨ Variables",
+ key: str = "variables",
+ allow_add: bool = False,
+):
+ from daras_ai_v2.workflow_url_input import del_button
+
# find all variables in the prompts
env = jinja2.sandbox.SandboxedEnvironment()
- template_vars = set()
+ template_var_names = set()
err = None
- for k in keys:
+ for k in template_keys:
try:
parsed = env.parse(st.session_state.get(k, ""))
except jinja2.exceptions.TemplateSyntaxError as e:
err = e
else:
- template_vars |= jinja2.meta.find_undeclared_variables(parsed)
+ template_var_names |= jinja2.meta.find_undeclared_variables(parsed)
+
+ old_vars = st.session_state.get(key, {})
+
+ var_add_key = f"--{key}:add_btn"
+ var_name_key = f"--{key}:add_name"
+ if st.session_state.pop(var_add_key, None):
+ if var_name := st.session_state.pop(var_name_key, None):
+ old_vars[var_name] = ""
- # don't mistake globals for vars
- template_vars -= set(context_globals().keys())
+ all_var_names = (
+ (template_var_names | set(old_vars))
+ - set(context_globals().keys()) # dont show global context variables
+ - set(st.session_state.keys()) # dont show other session state variables
+ )
- if not (template_vars or err):
- return
+ st.session_state[key] = new_vars = {}
+ title_shown = False
+ for name in sorted(all_var_names):
+ var_key = f"--{key}:{name}"
- st.write("###### âĨ Variables")
- old_state = st.session_state.get(variables_key, {})
- new_state = {}
- for name in sorted(template_vars):
- if name in st.session_state:
+ del_key = f"--{var_key}:del"
+ if st.session_state.get(del_key, None):
continue
- var_key = f"__{variables_key}_{name}"
- st.session_state.setdefault(var_key, old_state.get(name, ""))
- new_state[name] = st.text_area("`" + name + "`", key=var_key, height=300)
- st.session_state[variables_key] = new_state
+
+ if not title_shown:
+ st.write(label)
+ title_shown = True
+
+ col1, col2 = st.columns([11, 1], responsive=False)
+ with col1:
+ value = old_vars.get(name)
+ try:
+ new_text_value = st.session_state[var_key]
+ except KeyError:
+ if value is None:
+ value = ""
+ is_json = isinstance(value, (dict, list))
+ if is_json:
+ value = json.dumps(value, indent=2)
+ st.session_state[var_key] = str(value)
+ else:
+ try:
+ value = json.loads(new_text_value)
+ is_json = isinstance(value, (dict, list))
+ if not is_json:
+ value = new_text_value
+ except json.JSONDecodeError:
+ is_json = False
+ value = new_text_value
+ new_vars[name] = value
+
+ st.text_area(
+ "**```" + name + "```**" + (" (JSON)" if is_json else ""),
+ key=var_key,
+ height=300,
+ )
+ if name not in template_var_names:
+ with col2, st.div(className="pt-3 mt-4"):
+ del_button(key=del_key)
+
+ if allow_add:
+ if not title_shown:
+ st.write(label)
+ st.newline()
+ col1, col2, _ = st.columns([6, 2, 4], responsive=False)
+ with col1:
+ with st.div(style=dict(fontFamily="var(--bs-font-monospace)")):
+ st.text_input(
+ "",
+ key=var_name_key,
+ placeholder="my_var_name",
+ )
+ with col2:
+ st.button(
+ ' Add',
+ key=var_add_key,
+ type="tertiary",
+ )
if err:
st.error(f"{type(err).__qualname__}: {err.message}")
diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py
index bec46ad4f..58878a9af 100644
--- a/daras_ai_v2/settings.py
+++ b/daras_ai_v2/settings.py
@@ -62,6 +62,7 @@
"embeddings",
"handles",
"payments",
+ "functions",
]
MIDDLEWARE = [
diff --git a/daras_ai_v2/workflow_url_input.py b/daras_ai_v2/workflow_url_input.py
index 76a7e6142..3660ba358 100644
--- a/daras_ai_v2/workflow_url_input.py
+++ b/daras_ai_v2/workflow_url_input.py
@@ -23,7 +23,7 @@ def workflow_url_input(
current_user: AppUser | None = None,
allow_none: bool = False,
) -> tuple[typing.Type[BasePage], SavedRun, PublishedRun | None] | None:
- added_options = init_workflow_selector(internal_state, key)
+ init_workflow_selector(internal_state, key)
col1, col2, col3, col4 = st.columns([9, 1, 1, 1], responsive=False)
if not internal_state.get("workflow") and internal_state.get("url"):
@@ -40,7 +40,7 @@ def workflow_url_input(
internal_state["workflow"] = page_cls.workflow
with col1:
options = get_published_run_options(page_cls, current_user=current_user)
- options.update(added_options)
+ options.update(internal_state.get("--added_workflows", {}))
with st.div(className="pt-1"):
url = st.selectbox(
"",
@@ -96,7 +96,7 @@ def del_button(key: str):
def init_workflow_selector(
internal_state: dict,
key: str,
-) -> dict:
+):
if st.session_state.get(key + ":edit-done"):
st.session_state.pop(key + ":edit-mode", None)
st.session_state.pop(key + ":edit-done", None)
@@ -109,7 +109,7 @@ def init_workflow_selector(
try:
_, sr, pr = url_to_runs(str(internal_state["url"]))
except Exception:
- return {}
+ return
workflow = sr.workflow
page_cls = Workflow(workflow).page_cls
@@ -122,9 +122,7 @@ def init_workflow_selector(
internal_state["workflow"] = workflow
internal_state["url"] = url
- return {url: title}
-
- return {}
+ internal_state.setdefault("--added_workflows", {})[url] = title
def url_to_runs(
diff --git a/functions/__init__.py b/functions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/functions/admin.py b/functions/admin.py
new file mode 100644
index 000000000..977ee15ef
--- /dev/null
+++ b/functions/admin.py
@@ -0,0 +1,9 @@
+from django.contrib import admin
+
+from functions.models import CalledFunction
+
+
+# Register your models here.
+@admin.register(CalledFunction)
+class CalledFunctionAdmin(admin.ModelAdmin):
+ autocomplete_fields = ["saved_run", "function_run"]
diff --git a/functions/apps.py b/functions/apps.py
new file mode 100644
index 000000000..3c317752d
--- /dev/null
+++ b/functions/apps.py
@@ -0,0 +1,6 @@
+from django.apps import AppConfig
+
+
+class FunctionsConfig(AppConfig):
+ default_auto_field = "django.db.models.BigAutoField"
+ name = "functions"
diff --git a/function-executor.js b/functions/executor.js
similarity index 74%
rename from function-executor.js
rename to functions/executor.js
index 8bcc63dd6..097cf8123 100644
--- a/function-executor.js
+++ b/functions/executor.js
@@ -1,17 +1,21 @@
+//
+// To update this, run:
+// deployctl deploy --include functions/executor.js functions/executor.js --prod
+// (Exclude --prod when testing in development)
+//
Deno.serve(async (req) => {
if (!isAuthenticated(req)) {
return new Response("Unauthorized", { status: 401 });
}
let logs = captureConsole();
- let code = await req.json();
+ let { code, variables } = await req.json();
let status, response;
try {
- let Deno = undefined; // Deno should not available to user code
- let retval = eval(code);
+ let retval = isolatedEval(code, variables);
if (retval instanceof Function) {
- retval = retval();
+ retval = retval(variables);
}
if (retval instanceof Promise) {
retval = await retval;
@@ -27,6 +31,14 @@ Deno.serve(async (req) => {
return new Response(body, { status });
});
+function isolatedEval(code, variables) {
+ // Hide global objects
+ let Deno = undefined;
+ let globalThis = undefined;
+ let window = undefined;
+ return eval(code);
+}
+
function isAuthenticated(req) {
let authorization = req.headers.get("Authorization");
if (!authorization) return false;
diff --git a/functions/migrations/0001_initial.py b/functions/migrations/0001_initial.py
new file mode 100644
index 000000000..b04707d16
--- /dev/null
+++ b/functions/migrations/0001_initial.py
@@ -0,0 +1,26 @@
+# Generated by Django 4.2.7 on 2024-07-05 13:44
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ initial = True
+
+ dependencies = [
+ ('bots', '0076_alter_workflowmetadata_default_image_and_more'),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='CalledFunction',
+ fields=[
+ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('trigger', models.IntegerField(choices=[(1, 'Pre'), (2, 'Post')])),
+ ('created_at', models.DateTimeField(auto_now_add=True)),
+ ('function_run', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='called_by_runs', to='bots.savedrun')),
+ ('saved_run', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='called_functions', to='bots.savedrun')),
+ ],
+ ),
+ ]
diff --git a/functions/migrations/__init__.py b/functions/migrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/functions/models.py b/functions/models.py
new file mode 100644
index 000000000..0901fb109
--- /dev/null
+++ b/functions/models.py
@@ -0,0 +1,63 @@
+import typing
+
+from django.db import models
+from pydantic import BaseModel, Field
+
+from daras_ai_v2.custom_enum import GooeyEnum
+from daras_ai_v2.pydantic_validation import FieldHttpUrl
+
+
+class _TriggerData(typing.NamedTuple):
+ label: str
+ db_value: int
+
+
+class FunctionTrigger(_TriggerData, GooeyEnum):
+ pre = _TriggerData(label="Pre", db_value=1)
+ post = _TriggerData(label="Post", db_value=2)
+
+
+class RecipeFunction(BaseModel):
+ url: FieldHttpUrl = Field(
+ title="URL",
+ description="The URL of the [function](https://gooey.ai/functions) to call.",
+ )
+ trigger: FunctionTrigger.api_choices = Field(
+ title="Trigger",
+ description="When to run this function. `pre` runs before the recipe, `post` runs after the recipe.",
+ )
+
+
+class CalledFunctionResponse(BaseModel):
+ url: str
+ trigger: FunctionTrigger.api_choices
+ return_value: typing.Any
+
+ @classmethod
+ def from_db(cls, called_fn: "CalledFunction") -> "CalledFunctionResponse":
+ return cls(
+ url=called_fn.function_run.get_app_url(),
+ trigger=FunctionTrigger.from_db(called_fn.trigger).name,
+ return_value=called_fn.function_run.state.get("return_value"),
+ )
+
+
+class CalledFunction(models.Model):
+ saved_run = models.ForeignKey(
+ "bots.SavedRun",
+ on_delete=models.CASCADE,
+ related_name="called_functions",
+ )
+ function_run = models.ForeignKey(
+ "bots.SavedRun",
+ on_delete=models.CASCADE,
+ related_name="called_by_runs",
+ )
+ trigger = models.IntegerField(
+ choices=FunctionTrigger.db_choices(),
+ )
+
+ created_at = models.DateTimeField(auto_now_add=True)
+
+ def __str__(self):
+ return f"{self.saved_run} -> {self.function_run} ({self.trigger})"
diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py
new file mode 100644
index 000000000..1ba22be2a
--- /dev/null
+++ b/functions/recipe_functions.py
@@ -0,0 +1,216 @@
+import json
+import tempfile
+import typing
+from enum import Enum
+
+from pydantic import BaseModel
+
+import gooey_ui as st
+from app_users.models import AppUser
+from daras_ai.image_input import upload_file_from_bytes
+from daras_ai_v2.enum_selector_widget import enum_selector
+from daras_ai_v2.fastapi_tricks import extract_model_fields
+from daras_ai_v2.field_render import field_title_desc
+from daras_ai_v2.settings import templates
+from functions.models import CalledFunction, FunctionTrigger
+
+if typing.TYPE_CHECKING:
+ from bots.models import SavedRun
+
+
+def call_recipe_functions(
+ *,
+ saved_run: "SavedRun",
+ current_user: AppUser,
+ request_model: typing.Type[BaseModel],
+ response_model: typing.Type[BaseModel],
+ state: dict,
+ trigger: FunctionTrigger,
+):
+ from daras_ai_v2.workflow_url_input import url_to_runs
+ from gooeysite.bg_db_conn import get_celery_result_db_safe
+
+ request = request_model.parse_obj(state)
+
+ functions = getattr(request, "functions", None) or []
+ functions = [fun for fun in functions if fun.trigger == trigger.name]
+ if not functions:
+ return
+ variables = state.setdefault("variables", {})
+
+ yield f"Running {trigger.name} hooks..."
+
+ for fun in functions:
+ # run the function
+ page_cls, sr, pr = url_to_runs(fun.url)
+ result, sr = sr.submit_api_call(
+ current_user=current_user,
+ request_body=dict(
+ variables=sr.state.get("variables", {})
+ | variables
+ | dict(
+ request=json.loads(
+ request.json(exclude_unset=True, exclude={"variables"})
+ ),
+ response=extract_model_fields(response_model, state),
+ ),
+ ),
+ )
+
+ CalledFunction.objects.create(
+ saved_run=saved_run, function_run=sr, trigger=trigger.db_value
+ )
+
+ # wait for the result if its a pre request function
+ if trigger == FunctionTrigger.post:
+ continue
+ get_celery_result_db_safe(result)
+ sr.refresh_from_db()
+ # if failed, raise error
+ if sr.error_msg:
+ raise RuntimeError(sr.error_msg)
+
+ # save the output from the function
+ return_value = sr.state.get("return_value")
+ if return_value is None:
+ continue
+ if isinstance(return_value, dict):
+ for k, v in return_value.items():
+ if k in request_model.__fields__ or k in response_model.__fields__:
+ state[k] = v
+ else:
+ variables[k] = v
+ else:
+ variables["return_value"] = return_value
+
+
+def render_called_functions(*, saved_run: "SavedRun", trigger: FunctionTrigger):
+ from recipes.Functions import FunctionsPage
+ from daras_ai_v2.breadcrumbs import get_title_breadcrumbs
+
+ if not is_functions_enabled():
+ return
+ qs = saved_run.called_functions.filter(trigger=trigger.db_value)
+ if not qs.exists():
+ return
+ for called_fn in qs:
+ tb = get_title_breadcrumbs(
+ FunctionsPage,
+ called_fn.function_run,
+ called_fn.function_run.parent_published_run(),
+ )
+ title = (tb.published_title and tb.published_title.title) or tb.h1_title
+ st.write(f"###### 𧊠Called [{title}]({called_fn.function_run.get_app_url()})")
+ return_value = called_fn.function_run.state.get("return_value")
+ if return_value is not None:
+ st.json(return_value)
+ st.newline()
+
+
+def is_functions_enabled(key="functions") -> bool:
+ return bool(st.session_state.get(f"--enable-{key}"))
+
+
+def functions_input(current_user: AppUser, key="functions"):
+ from recipes.BulkRunner import list_view_editor
+ from daras_ai_v2.base import BasePage
+
+ def render_function_input(list_key: str, del_key: str, d: dict):
+ from daras_ai_v2.workflow_url_input import workflow_url_input
+ from recipes.Functions import FunctionsPage
+
+ col1, col2 = st.columns([2, 10], responsive=False)
+ with col1:
+ col1.node.props["className"] += " pt-1"
+ d["trigger"] = enum_selector(
+ enum_cls=FunctionTrigger,
+ use_selectbox=True,
+ key=list_key + ":trigger",
+ value=d.get("trigger"),
+ )
+ with col2:
+ workflow_url_input(
+ page_cls=FunctionsPage,
+ key=list_key + ":url",
+ internal_state=d,
+ del_key=del_key,
+ current_user=current_user,
+ )
+
+ if st.checkbox(
+ f"##### {field_title_desc(BasePage.RequestModel, key)}",
+ key=f"--enable-{key}",
+ value=key in st.session_state,
+ ):
+ st.session_state.setdefault(key, [{}])
+ list_view_editor(
+ add_btn_label="â Add Function",
+ key=key,
+ render_inputs=render_function_input,
+ )
+ st.write("---")
+ else:
+ st.session_state.pop(key, None)
+
+
+def json_to_pdf(filename: str, data: str) -> str:
+ html = templates.get_template("form_output.html").render(data=json.loads(data))
+ pdf_bytes = html_to_pdf(html)
+ if not filename.endswith(".pdf"):
+ filename += ".pdf"
+ return upload_file_from_bytes(filename, pdf_bytes, "application/pdf")
+
+
+def html_to_pdf(html: str) -> bytes:
+ from playwright.sync_api import sync_playwright
+
+ with sync_playwright() as p:
+ browser = p.chromium.launch()
+ page = browser.new_page()
+ page.set_content(html)
+ with tempfile.NamedTemporaryFile(suffix=".pdf") as outfile:
+ page.pdf(path=outfile.name, format="A4")
+ ret = outfile.read()
+ browser.close()
+
+ return ret
+
+
+class LLMTools(Enum):
+ json_to_pdf = (
+ json_to_pdf,
+ "Save JSON as PDF",
+ {
+ "type": "function",
+ "function": {
+ "name": json_to_pdf.__name__,
+ "description": "Save JSON data to PDF",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "filename": {
+ "type": "string",
+ "description": "A short but descriptive filename for the PDF",
+ },
+ "data": {
+ "type": "string",
+ "description": "The JSON data to write to the PDF",
+ },
+ },
+ "required": ["filename", "data"],
+ },
+ },
+ },
+ )
+ # send_reply_buttons = (print, "Send back reply buttons to the user.", {})
+
+ def __new__(cls, fn: typing.Callable, label: str, spec: dict):
+ obj = object.__new__(cls)
+ obj._value_ = fn.__name__
+ obj.fn = fn
+ obj.label = label
+ obj.spec = spec
+ return obj
+
+ # def __init__(self, *args, **kwargs):
+ # self._value_ = self.name
diff --git a/functions/tests.py b/functions/tests.py
new file mode 100644
index 000000000..7ce503c2d
--- /dev/null
+++ b/functions/tests.py
@@ -0,0 +1,3 @@
+from django.test import TestCase
+
+# Create your tests here.
diff --git a/functions/views.py b/functions/views.py
new file mode 100644
index 000000000..91ea44a21
--- /dev/null
+++ b/functions/views.py
@@ -0,0 +1,3 @@
+from django.shortcuts import render
+
+# Create your views here.
diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py
index fdeb5183f..c87e3a90d 100644
--- a/gooey_ui/components/__init__.py
+++ b/gooey_ui/components/__init__.py
@@ -551,12 +551,13 @@ def anchor(
form_submit_button = button
-def expander(label: str, *, expanded: bool = False, **props):
+def expander(label: str, *, expanded: bool = False, key: str = None, **props):
node = state.RenderTreeNode(
name="expander",
props=dict(
label=dedent(label),
open=expanded,
+ name=key or md5_values(label, expanded, props),
**props,
),
)
diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py
index b1a60157a..3712a6b8f 100644
--- a/recipes/CompareLLM.py
+++ b/recipes/CompareLLM.py
@@ -17,7 +17,7 @@
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
-from daras_ai_v2.prompt_vars import prompt_vars_widget, render_prompt_vars
+from daras_ai_v2.prompt_vars import variables_input, render_prompt_vars
DEFAULT_COMPARE_LM_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/fef06d86-1f70-11ef-b8ee-02420a00015b/LLMs.jpg"
@@ -50,8 +50,6 @@ class RequestModel(BasePage.RequestModel):
max_tokens: int | None
sampling_temperature: float | None
- variables: dict[str, typing.Any] | None
-
response_format_type: ResponseFormatType = Field(
None,
title="Response Format",
@@ -81,7 +79,6 @@ def render_form_v2(self):
help="What a fine day..",
height=300,
)
- prompt_vars_widget("input_prompt")
enum_multiselect(
LargeLanguageModels,
diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py
index 99776b1a1..b864d6700 100644
--- a/recipes/DocSearch.py
+++ b/recipes/DocSearch.py
@@ -22,7 +22,7 @@
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
-from daras_ai_v2.prompt_vars import prompt_vars_widget, render_prompt_vars
+from daras_ai_v2.prompt_vars import variables_input, render_prompt_vars
from daras_ai_v2.query_generator import generate_final_search_query
from daras_ai_v2.search_ref import (
SearchReference,
@@ -78,8 +78,6 @@ class RequestModel(DocSearchRequest, BasePage.RequestModel):
citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None
- variables: dict[str, typing.Any] | None
-
class ResponseModel(BaseModel):
output_text: list[str]
@@ -94,7 +92,6 @@ def get_example_preferred_fields(self, state: dict) -> list[str]:
def render_form_v2(self):
st.text_area("#### Search Query", key="search_query")
bulk_documents_uploader("#### Documents")
- prompt_vars_widget("task_instructions", "query_instructions")
def validate_form_v2(self):
search_query = st.session_state.get("search_query", "").strip()
diff --git a/recipes/Functions.py b/recipes/Functions.py
index c0ea22de5..b2b1c23dd 100644
--- a/recipes/Functions.py
+++ b/recipes/Functions.py
@@ -1,4 +1,3 @@
-import json
import typing
import requests
@@ -9,6 +8,7 @@
from daras_ai_v2 import settings
from daras_ai_v2.base import BasePage
from daras_ai_v2.field_render import field_title_desc
+from daras_ai_v2.prompt_vars import variables_input
class ConsoleLogs(BaseModel):
@@ -27,6 +27,11 @@ class RequestModel(BaseModel):
title="Code",
description="The JS code to be executed.",
)
+ variables: dict[str, typing.Any] = Field(
+ {},
+ title="Variables",
+ description="Variables to be used in the code",
+ )
class ResponseModel(BaseModel):
return_value: typing.Any = Field(
@@ -51,10 +56,11 @@ def run_v2(
response: "FunctionsPage.ResponseModel",
) -> typing.Iterator[str | None]:
yield "Running your code..."
+ # this will run functions/executor.js in deno deploy
r = requests.post(
settings.DENO_FUNCTIONS_URL,
headers={"Authorization": f"Basic {settings.DENO_FUNCTIONS_AUTH_TOKEN}"},
- json=request.code,
+ json=dict(code=request.code, variables=request.variables or {}),
)
data = r.json()
response.logs = data.get("logs")
@@ -67,9 +73,11 @@ def render_form_v2(self):
st.text_area(
"##### " + field_title_desc(self.RequestModel, "code"),
key="code",
- height=500,
)
+ def render_variables(self):
+ variables_input(template_keys=["code"], allow_add=True)
+
def render_output(self):
if error := st.session_state.get("error"):
with st.tag("pre", className="bg-danger bg-opacity-25"):
diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py
index 8328f9dc6..270e53f2f 100644
--- a/recipes/GoogleGPT.py
+++ b/recipes/GoogleGPT.py
@@ -17,7 +17,7 @@
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.loom_video_widget import youtube_video
-from daras_ai_v2.prompt_vars import render_prompt_vars, prompt_vars_widget
+from daras_ai_v2.prompt_vars import render_prompt_vars, variables_input
from daras_ai_v2.query_generator import generate_final_search_query
from daras_ai_v2.search_ref import (
SearchReference,
@@ -100,8 +100,6 @@ class RequestModel(GoogleSearchMixin, BasePage.RequestModel):
"dense_weight"
].field_info
- variables: dict[str, typing.Any] | None
-
class ResponseModel(BaseModel):
output_text: list[str]
@@ -115,7 +113,6 @@ class ResponseModel(BaseModel):
def render_form_v2(self):
st.text_area("#### Google Search Query", key="search_query")
st.text_input("Search on a specific site *(optional)*", key="site_filter")
- prompt_vars_widget("task_instructions", "query_instructions")
def validate_form_v2(self):
assert st.session_state.get(
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 099b1d72b..2f81abc43 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -52,7 +52,6 @@
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.exceptions import UserError
from daras_ai_v2.field_render import field_title_desc, field_desc, field_title
-from daras_ai_v2.functions import LLMTools
from daras_ai_v2.glossary import validate_glossary_document
from daras_ai_v2.language_model import (
run_language_model,
@@ -71,7 +70,7 @@
from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModel
from daras_ai_v2.lipsync_settings_widgets import lipsync_settings
from daras_ai_v2.loom_video_widget import youtube_video
-from daras_ai_v2.prompt_vars import render_prompt_vars, prompt_vars_widget
+from daras_ai_v2.prompt_vars import render_prompt_vars
from daras_ai_v2.pydantic_validation import FieldHttpUrl
from daras_ai_v2.query_generator import generate_final_search_query
from daras_ai_v2.query_params import gooey_get_query_params
@@ -89,6 +88,7 @@
text_to_speech_provider_selector,
)
from daras_ai_v2.vector_search import DocSearchRequest
+from functions.recipe_functions import LLMTools
from gooey_ui import RedirectException
from recipes.DocSearch import (
get_top_k_references,
@@ -247,8 +247,6 @@ class RequestModelBase(BasePage.RequestModel):
LipsyncModel.Wav2Lip.name
)
- variables: dict[str, typing.Any] | None
-
tools: list[LLMTools] | None = Field(
title="đ ī¸ Tools",
description="Give your copilot superpowers by giving it access to tools. Powered by [Function calling](https://platform.openai.com/docs/guides/function-calling).",
@@ -331,9 +329,6 @@ def render_form_v2(self):
key="bot_script",
height=300,
)
- prompt_vars_widget(
- "bot_script",
- )
enum_selector(
LargeLanguageModels,
@@ -519,9 +514,6 @@ def render_settings(self):
key="task_instructions",
height=300,
)
- prompt_vars_widget(
- "task_instructions",
- )
citation_style_selector()
st.checkbox("đ Shorten Citation URLs", key="use_url_shortener")
@@ -604,7 +596,10 @@ def render_output(self):
references = st.session_state.get("references", [])
if not references:
return
- with st.expander("đââī¸ Sources"):
+ key = "sources-expander"
+ with st.expander("đââī¸ Sources", key=key):
+ if not st.session_state.get(key):
+ return
for idx, ref in enumerate(references):
st.write(f"**{idx + 1}**. [{ref['title']}]({ref['url']})")
text_output(
diff --git a/routers/api.py b/routers/api.py
index 979af1c3a..816837dd9 100644
--- a/routers/api.py
+++ b/routers/api.py
@@ -19,22 +19,21 @@
from starlette.datastructures import UploadFile
from starlette.requests import Request
-from celeryapp.tasks import auto_recharge
-from daras_ai_v2.auto_recharge import user_should_auto_recharge
import gooey_ui as st
from app_users.models import AppUser
from auth.token_authentication import api_auth_header
from bots.models import RetentionPolicy
+from celeryapp.tasks import auto_recharge
from daras_ai.image_input import upload_file_from_bytes
from daras_ai_v2 import settings
from daras_ai_v2.all_pages import all_api_pages
+from daras_ai_v2.auto_recharge import user_should_auto_recharge
from daras_ai_v2.base import (
BasePage,
- StateKeys,
RecipeRunState,
)
from daras_ai_v2.fastapi_tricks import fastapi_request_form
-from daras_ai_v2.ratelimits import ensure_rate_limits
+from functions.models import CalledFunctionResponse
from gooeysite.bg_db_conn import get_celery_result_db_safe
from routers.account import AccountTabs
@@ -119,9 +118,14 @@ def script_to_api(page_cls: typing.Type[BasePage]):
settings=(RunSettings, RunSettings()),
)
# encapsulate the response model with the ApiResponseModel
+ response_output_model = create_model(
+ page_cls.__name__ + "Output",
+ __base__=page_cls.ResponseModel,
+ called_functions=(list[CalledFunctionResponse], None),
+ )
response_model = create_model(
page_cls.__name__ + "Response",
- __base__=ApiResponseModelV2[page_cls.ResponseModel],
+ __base__=ApiResponseModelV2[response_output_model],
)
common_errs = {
@@ -244,7 +248,7 @@ def run_api_form(
response_model = create_model(
page_cls.__name__ + "StatusResponse",
- __base__=AsyncStatusResponseModelV3[page_cls.ResponseModel],
+ __base__=AsyncStatusResponseModelV3[response_output_model],
)
@app.get(
@@ -281,7 +285,7 @@ def get_run_status(
status = self.get_run_state(sr.to_dict())
ret |= {"detail": sr.run_status or "", "status": status}
if status == RecipeRunState.completed and sr.state:
- ret |= {"output": sr.state}
+ ret |= {"output": sr.api_output()}
if sr.retention_policy == RetentionPolicy.delete:
sr.state = {}
sr.save(update_fields=["state"])
@@ -419,20 +423,18 @@ def build_api_response(
# wait for the result
get_celery_result_db_safe(result)
sr = page.run_doc_sr(run_id, uid)
- state = sr.to_dict()
if sr.retention_policy == RetentionPolicy.delete:
sr.state = {}
sr.save(update_fields=["state"])
# check for errors
- err_msg = state.get(StateKeys.error_msg)
- if err_msg:
+ if sr.error_msg:
raise HTTPException(
status_code=500,
detail={
"id": run_id,
"url": web_url,
"created_at": sr.created_at.isoformat(),
- "error": err_msg,
+ "error": sr.error_msg,
},
)
else:
@@ -441,7 +443,7 @@ def build_api_response(
"id": run_id,
"url": web_url,
"created_at": sr.created_at.isoformat(),
- "output": state,
+ "output": sr.api_output(),
}
diff --git a/scripts/deno-deploy.sh b/scripts/deno-deploy.sh
deleted file mode 100755
index d97b110c9..000000000
--- a/scripts/deno-deploy.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-#!/usr/bin/env bash
-
-set -ex
-
-deployctl deploy --include function-executor.js function-executor.js