diff --git a/bots/migrations/0076_alter_workflowmetadata_default_image_and_more.py b/bots/migrations/0076_alter_workflowmetadata_default_image_and_more.py new file mode 100644 index 000000000..8be0bb4ed --- /dev/null +++ b/bots/migrations/0076_alter_workflowmetadata_default_image_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.7 on 2024-07-05 13:44 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0075_alter_publishedrun_workflow_alter_savedrun_workflow_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='workflowmetadata', + name='default_image', + field=models.URLField(blank=True, default='', help_text='Image shown on explore page'), + ), + migrations.AlterField( + model_name='workflowmetadata', + name='help_url', + field=models.URLField(blank=True, default='', help_text='(Not implemented)'), + ), + migrations.AlterField( + model_name='workflowmetadata', + name='meta_keywords', + field=models.JSONField(blank=True, default=list, help_text='(Not implemented)'), + ), + migrations.AlterField( + model_name='workflowmetadata', + name='short_title', + field=models.TextField(help_text='Title used in breadcrumbs'), + ), + ] diff --git a/bots/models.py b/bots/models.py index fbb124726..a2c5715a1 100644 --- a/bots/models.py +++ b/bots/models.py @@ -17,6 +17,7 @@ from bots.custom_fields import PostgresJSONEncoder, CustomURLField from daras_ai_v2.crypto import get_random_doc_id from daras_ai_v2.language_model import format_chat_entry +from functions.models import CalledFunction, CalledFunctionResponse from gooeysite.custom_create import get_or_create_lazy if typing.TYPE_CHECKING: @@ -370,6 +371,15 @@ def get_creator(self) -> AppUser | None: def open_in_gooey(self): return open_in_new_tab(self.get_app_url(), label=self.get_app_url()) + def api_output(self, state: dict = None) -> dict: + state = state or self.state + if self.state.get("functions"): + state["called_functions"] = [ + CalledFunctionResponse.from_db(called_fn) + for called_fn in self.called_functions.all() + ] + return state + def _parse_dt(dt) -> datetime.datetime | None: if isinstance(dt, str): diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index f535b6ff5..dbdb356b3 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -21,6 +21,7 @@ from daras_ai_v2.auto_recharge import auto_recharge_user from daras_ai_v2.base import StateKeys, BasePage from daras_ai_v2.exceptions import UserError +from daras_ai_v2.fastapi_tricks import extract_model_fields from daras_ai_v2.redis_cache import redis_lock from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email from daras_ai_v2.settings import templates @@ -82,13 +83,8 @@ def save(done=False): } output = ( status - | - # extract outputs from local state - { - k: v - for k, v in st.session_state.items() - if k in page.ResponseModel.__fields__ - } + # extract outputs from session state + | extract_model_fields(page.ResponseModel, st.session_state) | extra_output ) # send outputs to ui @@ -97,7 +93,7 @@ def save(done=False): page.dump_state_to_sr(st.session_state | output, sr) try: - gen = page.run(st.session_state) + gen = page.main(sr, st.session_state) save() while True: # record time diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 6b641314d..cde7ecbe0 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -6,6 +6,7 @@ import uuid from copy import deepcopy, copy from enum import Enum +from functools import cached_property from itertools import pairwise from random import Random from time import sleep @@ -18,7 +19,7 @@ from fastapi import HTTPException from firebase_admin import auth from furl import furl -from pydantic import BaseModel +from pydantic import BaseModel, Field from sentry_sdk.tracing import ( TRANSACTION_SOURCE_ROUTE, ) @@ -48,11 +49,12 @@ from daras_ai_v2.db import ( ANONYMOUS_USER_COOKIE, ) -from daras_ai_v2.fastapi_tricks import get_route_path +from daras_ai_v2.fastapi_tricks import get_route_path, extract_model_fields from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.html_spinner_widget import html_spinner from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_preview_url import meta_preview_url +from daras_ai_v2.prompt_vars import variables_input from daras_ai_v2.query_params import ( gooey_get_query_params, ) @@ -64,6 +66,16 @@ from daras_ai_v2.user_date_widgets import ( render_local_dt_attrs, ) +from functions.models import ( + RecipeFunction, + FunctionTrigger, +) +from functions.recipe_functions import ( + functions_input, + call_recipe_functions, + is_functions_enabled, + render_called_functions, +) from gooey_ui import ( realtime_clear_subs, RedirectException, @@ -71,7 +83,6 @@ from gooey_ui.components.modal import Modal from gooey_ui.components.pills import pill from gooey_ui.pubsub import realtime_pull -from gooeysite.custom_create import get_or_create_lazy from routers.account import AccountTabs from routers.root import RecipeTabs @@ -117,7 +128,28 @@ class BasePage: explore_image: str = None - RequestModel: typing.Type[BaseModel] + template_keys: typing.Iterable[str] = ( + "task_instructions", + "query_instructions", + "keyword_instructions", + "input_prompt", + "bot_script", + "text_prompt", + "search_query", + "title", + ) + + class RequestModel(BaseModel): + functions: list[RecipeFunction] | None = Field( + None, + title="🧩 Functions", + ) + variables: dict[str, typing.Any] = Field( + None, + title="âŒĨ Variables", + description="Variables to be used as Jinja prompt templates and in functions as arguments", + ) + ResponseModel: typing.Type[BaseModel] price = settings.CREDITS_TO_DEDUCT_PER_RUN @@ -1310,12 +1342,18 @@ def render_run_cost(self): st.caption(ret, line_clamp=1, unsafe_allow_html=True) def _render_step_row(self): - with st.expander("**ℹī¸ Details**"): + key = "details-expander" + with st.expander("**ℹī¸ Details**", key=key): + if not st.session_state.get(key): + return col1, col2 = st.columns([1, 2]) with col1: self.render_description() with col2: placeholder = st.div() + render_called_functions( + saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre + ) try: self.render_steps() except NotImplementedError: @@ -1323,6 +1361,9 @@ def _render_step_row(self): else: with placeholder: st.write("##### đŸ‘Ŗ Steps") + render_called_functions( + saved_run=self.get_current_sr(), trigger=FunctionTrigger.post + ) def _render_help(self): placeholder = st.div() @@ -1338,9 +1379,13 @@ def _render_help(self): """ ) + key = "discord-expander" with st.expander( - f"**🙋đŸŊ‍♀ī¸ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**" + f"**🙋đŸŊ‍♀ī¸ Need more help? [Join our Discord]({settings.DISCORD_INVITE_URL})**", + key=key, ): + if not st.session_state.get(key): + return st.markdown( """
@@ -1353,6 +1398,27 @@ def _render_help(self): def render_usage_guide(self): raise NotImplementedError + def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]: + yield from call_recipe_functions( + saved_run=sr, + current_user=self.request.user, + request_model=self.RequestModel, + response_model=self.ResponseModel, + state=state, + trigger=FunctionTrigger.pre, + ) + + yield from self.run(state) + + yield from call_recipe_functions( + saved_run=sr, + current_user=self.request.user, + request_model=self.RequestModel, + response_model=self.ResponseModel, + state=state, + trigger=FunctionTrigger.post, + ) + def run(self, state: dict) -> typing.Iterator[str | None]: # initialize request and response request = self.RequestModel.parse_obj(state) @@ -1373,7 +1439,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: self.ResponseModel.validate(response) def run_v2( - self, request: BaseModel, response: BaseModel + self, request: RequestModel, response: BaseModel ) -> typing.Iterator[str | None]: raise NotImplementedError @@ -1400,8 +1466,11 @@ def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): def _render_input_col(self): self.render_form_v2() + self.render_variables() + with st.expander("⚙ī¸ Settings"): self.render_settings() + submitted = self.render_submit_button() with st.div(style={"textAlign": "right"}): st.caption( @@ -1410,6 +1479,13 @@ def _render_input_col(self): ) return submitted + def render_variables(self): + st.write("---") + functions_input(self.request.user) + variables_input( + template_keys=self.template_keys, allow_add=is_functions_enabled() + ) + @classmethod def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState: if state.get(StateKeys.run_status): @@ -2058,7 +2134,8 @@ def get_example_response_body( run_id=run_id, uid=self.request.user and self.request.user.uid, ) - output = extract_model_fields(self.ResponseModel, state, include_all=True) + sr = self.get_current_sr() + output = sr.api_output(extract_model_fields(self.ResponseModel, state)) if as_async: return dict( run_id=run_id, @@ -2143,32 +2220,6 @@ def render_output_caption(): ) -def extract_model_fields( - model: typing.Type[BaseModel], - state: dict, - include_all: bool = False, - preferred_fields: list[str] = None, - diff_from: dict | None = None, -) -> dict: - """ - Include a field in result if: - - include_all is true - - field is required - - field is preferred - - diff_from is provided and field value differs from diff_from - """ - return { - field_name: state.get(field_name) - for field_name, field in model.__fields__.items() - if ( - include_all - or field.required - or (preferred_fields and field_name in preferred_fields) - or (diff_from and state.get(field_name) != diff_from.get(field_name)) - ) - } - - def extract_nested_str(obj) -> str: if isinstance(obj, str): return obj diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py new file mode 100644 index 000000000..b9aacb843 --- /dev/null +++ b/daras_ai_v2/custom_enum.py @@ -0,0 +1,31 @@ +import typing +from enum import Enum + +import typing_extensions + +T = typing.TypeVar("T", bound="GooeyEnum") + + +class GooeyEnum(Enum): + @classmethod + def db_choices(cls): + return [(e.db_value, e.label) for e in cls] + + @classmethod + def from_db(cls, db_value) -> typing_extensions.Self: + for e in cls: + if e.db_value == db_value: + return e + raise ValueError(f"Invalid {cls.__name__} {db_value=}") + + @classmethod + @property + def api_choices(cls): + return typing.Literal[tuple(e.name for e in cls)] + + @classmethod + def from_api(cls, name: str) -> typing_extensions.Self: + for e in cls: + if e.name == name: + return e + raise ValueError(f"Invalid {cls.__name__} {name=}") diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 00d6c9983..10ae86a7d 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -10,7 +10,7 @@ from daras_ai_v2.embedding_model import EmbeddingModels from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder -from daras_ai_v2.prompt_vars import prompt_vars_widget +from daras_ai_v2.prompt_vars import variables_input from daras_ai_v2.search_ref import CitationStyles _user_media_url_prefix = os.path.join( @@ -110,9 +110,6 @@ def query_instructions_widget(): key="query_instructions", height=300, ) - prompt_vars_widget( - "query_instructions", - ) def keyword_instructions_widget(): @@ -124,9 +121,6 @@ def keyword_instructions_widget(): key="keyword_instructions", height=300, ) - prompt_vars_widget( - "keyword_instructions", - ) def doc_extract_selector(current_user: AppUser | None): diff --git a/daras_ai_v2/fastapi_tricks.py b/daras_ai_v2/fastapi_tricks.py index 494393ab2..3d080e5be 100644 --- a/daras_ai_v2/fastapi_tricks.py +++ b/daras_ai_v2/fastapi_tricks.py @@ -5,6 +5,7 @@ from fastapi import Depends from fastapi.routing import APIRoute from furl import furl +from pydantic import BaseModel from starlette.requests import Request from daras_ai_v2 import settings @@ -64,3 +65,29 @@ def get_route_path(route_fn: typing.Callable, params: dict = None) -> str: from server import app return os.path.join(app.url_path_for(route_fn.__name__, **(params or {})), "") + + +def extract_model_fields( + model: typing.Type[BaseModel], + state: dict, + include_all: bool = True, + preferred_fields: list[str] = None, + diff_from: dict | None = None, +) -> dict: + """ + Include a field in result if: + - include_all is true + - field is required + - field is preferred + - diff_from is provided and field value differs from diff_from + """ + return { + field_name: state.get(field_name) + for field_name, field in model.__fields__.items() + if ( + include_all + or field.required + or (preferred_fields and field_name in preferred_fields) + or (diff_from and state.get(field_name) != diff_from.get(field_name)) + ) + } diff --git a/daras_ai_v2/functions.py b/daras_ai_v2/functions.py deleted file mode 100644 index 2c6c03348..000000000 --- a/daras_ai_v2/functions.py +++ /dev/null @@ -1,70 +0,0 @@ -import json -import tempfile -import typing -from enum import Enum - -from daras_ai.image_input import upload_file_from_bytes -from daras_ai_v2.settings import templates - - -def json_to_pdf(filename: str, data: str) -> str: - html = templates.get_template("form_output.html").render(data=json.loads(data)) - pdf_bytes = html_to_pdf(html) - if not filename.endswith(".pdf"): - filename += ".pdf" - return upload_file_from_bytes(filename, pdf_bytes, "application/pdf") - - -def html_to_pdf(html: str) -> bytes: - from playwright.sync_api import sync_playwright - - with sync_playwright() as p: - browser = p.chromium.launch() - page = browser.new_page() - page.set_content(html) - with tempfile.NamedTemporaryFile(suffix=".pdf") as outfile: - page.pdf(path=outfile.name, format="A4") - ret = outfile.read() - browser.close() - - return ret - - -class LLMTools(Enum): - json_to_pdf = ( - json_to_pdf, - "Save JSON as PDF", - { - "type": "function", - "function": { - "name": json_to_pdf.__name__, - "description": "Save JSON data to PDF", - "parameters": { - "type": "object", - "properties": { - "filename": { - "type": "string", - "description": "A short but descriptive filename for the PDF", - }, - "data": { - "type": "string", - "description": "The JSON data to write to the PDF", - }, - }, - "required": ["filename", "data"], - }, - }, - }, - ) - # send_reply_buttons = (print, "Send back reply buttons to the user.", {}) - - def __new__(cls, fn: typing.Callable, label: str, spec: dict): - obj = object.__new__(cls) - obj._value_ = fn.__name__ - obj.fn = fn - obj.label = label - obj.spec = spec - return obj - - # def __init__(self, *args, **kwargs): - # self._value_ = self.name diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 0679f21ac..4ab64fe4e 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -24,7 +24,7 @@ from daras_ai.image_input import gs_url_to_uri, bytes_to_cv2_img, cv2_img_to_bytes from daras_ai_v2.asr import get_google_auth_session from daras_ai_v2.exceptions import raise_for_status, UserError -from daras_ai_v2.functions import LLMTools +from functions.recipe_functions import LLMTools from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.text_splitter import ( default_length_function, diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py index b6c2b4795..4ab9878a2 100644 --- a/daras_ai_v2/prompt_vars.py +++ b/daras_ai_v2/prompt_vars.py @@ -1,3 +1,5 @@ +import json +import typing from datetime import datetime from types import SimpleNamespace @@ -8,36 +10,105 @@ import gooey_ui as st -def prompt_vars_widget(*keys: str, variables_key: str = "variables"): +def variables_input( + *, + template_keys: typing.Iterable[str], + label: str = "###### âŒĨ Variables", + key: str = "variables", + allow_add: bool = False, +): + from daras_ai_v2.workflow_url_input import del_button + # find all variables in the prompts env = jinja2.sandbox.SandboxedEnvironment() - template_vars = set() + template_var_names = set() err = None - for k in keys: + for k in template_keys: try: parsed = env.parse(st.session_state.get(k, "")) except jinja2.exceptions.TemplateSyntaxError as e: err = e else: - template_vars |= jinja2.meta.find_undeclared_variables(parsed) + template_var_names |= jinja2.meta.find_undeclared_variables(parsed) + + old_vars = st.session_state.get(key, {}) + + var_add_key = f"--{key}:add_btn" + var_name_key = f"--{key}:add_name" + if st.session_state.pop(var_add_key, None): + if var_name := st.session_state.pop(var_name_key, None): + old_vars[var_name] = "" - # don't mistake globals for vars - template_vars -= set(context_globals().keys()) + all_var_names = ( + (template_var_names | set(old_vars)) + - set(context_globals().keys()) # dont show global context variables + - set(st.session_state.keys()) # dont show other session state variables + ) - if not (template_vars or err): - return + st.session_state[key] = new_vars = {} + title_shown = False + for name in sorted(all_var_names): + var_key = f"--{key}:{name}" - st.write("###### âŒĨ Variables") - old_state = st.session_state.get(variables_key, {}) - new_state = {} - for name in sorted(template_vars): - if name in st.session_state: + del_key = f"--{var_key}:del" + if st.session_state.get(del_key, None): continue - var_key = f"__{variables_key}_{name}" - st.session_state.setdefault(var_key, old_state.get(name, "")) - new_state[name] = st.text_area("`" + name + "`", key=var_key, height=300) - st.session_state[variables_key] = new_state + + if not title_shown: + st.write(label) + title_shown = True + + col1, col2 = st.columns([11, 1], responsive=False) + with col1: + value = old_vars.get(name) + try: + new_text_value = st.session_state[var_key] + except KeyError: + if value is None: + value = "" + is_json = isinstance(value, (dict, list)) + if is_json: + value = json.dumps(value, indent=2) + st.session_state[var_key] = str(value) + else: + try: + value = json.loads(new_text_value) + is_json = isinstance(value, (dict, list)) + if not is_json: + value = new_text_value + except json.JSONDecodeError: + is_json = False + value = new_text_value + new_vars[name] = value + + st.text_area( + "**```" + name + "```**" + (" (JSON)" if is_json else ""), + key=var_key, + height=300, + ) + if name not in template_var_names: + with col2, st.div(className="pt-3 mt-4"): + del_button(key=del_key) + + if allow_add: + if not title_shown: + st.write(label) + st.newline() + col1, col2, _ = st.columns([6, 2, 4], responsive=False) + with col1: + with st.div(style=dict(fontFamily="var(--bs-font-monospace)")): + st.text_input( + "", + key=var_name_key, + placeholder="my_var_name", + ) + with col2: + st.button( + ' Add', + key=var_add_key, + type="tertiary", + ) if err: st.error(f"{type(err).__qualname__}: {err.message}") diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index bec46ad4f..58878a9af 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -62,6 +62,7 @@ "embeddings", "handles", "payments", + "functions", ] MIDDLEWARE = [ diff --git a/daras_ai_v2/workflow_url_input.py b/daras_ai_v2/workflow_url_input.py index 76a7e6142..3660ba358 100644 --- a/daras_ai_v2/workflow_url_input.py +++ b/daras_ai_v2/workflow_url_input.py @@ -23,7 +23,7 @@ def workflow_url_input( current_user: AppUser | None = None, allow_none: bool = False, ) -> tuple[typing.Type[BasePage], SavedRun, PublishedRun | None] | None: - added_options = init_workflow_selector(internal_state, key) + init_workflow_selector(internal_state, key) col1, col2, col3, col4 = st.columns([9, 1, 1, 1], responsive=False) if not internal_state.get("workflow") and internal_state.get("url"): @@ -40,7 +40,7 @@ def workflow_url_input( internal_state["workflow"] = page_cls.workflow with col1: options = get_published_run_options(page_cls, current_user=current_user) - options.update(added_options) + options.update(internal_state.get("--added_workflows", {})) with st.div(className="pt-1"): url = st.selectbox( "", @@ -96,7 +96,7 @@ def del_button(key: str): def init_workflow_selector( internal_state: dict, key: str, -) -> dict: +): if st.session_state.get(key + ":edit-done"): st.session_state.pop(key + ":edit-mode", None) st.session_state.pop(key + ":edit-done", None) @@ -109,7 +109,7 @@ def init_workflow_selector( try: _, sr, pr = url_to_runs(str(internal_state["url"])) except Exception: - return {} + return workflow = sr.workflow page_cls = Workflow(workflow).page_cls @@ -122,9 +122,7 @@ def init_workflow_selector( internal_state["workflow"] = workflow internal_state["url"] = url - return {url: title} - - return {} + internal_state.setdefault("--added_workflows", {})[url] = title def url_to_runs( diff --git a/functions/__init__.py b/functions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions/admin.py b/functions/admin.py new file mode 100644 index 000000000..977ee15ef --- /dev/null +++ b/functions/admin.py @@ -0,0 +1,9 @@ +from django.contrib import admin + +from functions.models import CalledFunction + + +# Register your models here. +@admin.register(CalledFunction) +class CalledFunctionAdmin(admin.ModelAdmin): + autocomplete_fields = ["saved_run", "function_run"] diff --git a/functions/apps.py b/functions/apps.py new file mode 100644 index 000000000..3c317752d --- /dev/null +++ b/functions/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class FunctionsConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "functions" diff --git a/function-executor.js b/functions/executor.js similarity index 74% rename from function-executor.js rename to functions/executor.js index 8bcc63dd6..097cf8123 100644 --- a/function-executor.js +++ b/functions/executor.js @@ -1,17 +1,21 @@ +// +// To update this, run: +// deployctl deploy --include functions/executor.js functions/executor.js --prod +// (Exclude --prod when testing in development) +// Deno.serve(async (req) => { if (!isAuthenticated(req)) { return new Response("Unauthorized", { status: 401 }); } let logs = captureConsole(); - let code = await req.json(); + let { code, variables } = await req.json(); let status, response; try { - let Deno = undefined; // Deno should not available to user code - let retval = eval(code); + let retval = isolatedEval(code, variables); if (retval instanceof Function) { - retval = retval(); + retval = retval(variables); } if (retval instanceof Promise) { retval = await retval; @@ -27,6 +31,14 @@ Deno.serve(async (req) => { return new Response(body, { status }); }); +function isolatedEval(code, variables) { + // Hide global objects + let Deno = undefined; + let globalThis = undefined; + let window = undefined; + return eval(code); +} + function isAuthenticated(req) { let authorization = req.headers.get("Authorization"); if (!authorization) return false; diff --git a/functions/migrations/0001_initial.py b/functions/migrations/0001_initial.py new file mode 100644 index 000000000..b04707d16 --- /dev/null +++ b/functions/migrations/0001_initial.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.7 on 2024-07-05 13:44 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('bots', '0076_alter_workflowmetadata_default_image_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='CalledFunction', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('trigger', models.IntegerField(choices=[(1, 'Pre'), (2, 'Post')])), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('function_run', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='called_by_runs', to='bots.savedrun')), + ('saved_run', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='called_functions', to='bots.savedrun')), + ], + ), + ] diff --git a/functions/migrations/__init__.py b/functions/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions/models.py b/functions/models.py new file mode 100644 index 000000000..0901fb109 --- /dev/null +++ b/functions/models.py @@ -0,0 +1,63 @@ +import typing + +from django.db import models +from pydantic import BaseModel, Field + +from daras_ai_v2.custom_enum import GooeyEnum +from daras_ai_v2.pydantic_validation import FieldHttpUrl + + +class _TriggerData(typing.NamedTuple): + label: str + db_value: int + + +class FunctionTrigger(_TriggerData, GooeyEnum): + pre = _TriggerData(label="Pre", db_value=1) + post = _TriggerData(label="Post", db_value=2) + + +class RecipeFunction(BaseModel): + url: FieldHttpUrl = Field( + title="URL", + description="The URL of the [function](https://gooey.ai/functions) to call.", + ) + trigger: FunctionTrigger.api_choices = Field( + title="Trigger", + description="When to run this function. `pre` runs before the recipe, `post` runs after the recipe.", + ) + + +class CalledFunctionResponse(BaseModel): + url: str + trigger: FunctionTrigger.api_choices + return_value: typing.Any + + @classmethod + def from_db(cls, called_fn: "CalledFunction") -> "CalledFunctionResponse": + return cls( + url=called_fn.function_run.get_app_url(), + trigger=FunctionTrigger.from_db(called_fn.trigger).name, + return_value=called_fn.function_run.state.get("return_value"), + ) + + +class CalledFunction(models.Model): + saved_run = models.ForeignKey( + "bots.SavedRun", + on_delete=models.CASCADE, + related_name="called_functions", + ) + function_run = models.ForeignKey( + "bots.SavedRun", + on_delete=models.CASCADE, + related_name="called_by_runs", + ) + trigger = models.IntegerField( + choices=FunctionTrigger.db_choices(), + ) + + created_at = models.DateTimeField(auto_now_add=True) + + def __str__(self): + return f"{self.saved_run} -> {self.function_run} ({self.trigger})" diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py new file mode 100644 index 000000000..1ba22be2a --- /dev/null +++ b/functions/recipe_functions.py @@ -0,0 +1,216 @@ +import json +import tempfile +import typing +from enum import Enum + +from pydantic import BaseModel + +import gooey_ui as st +from app_users.models import AppUser +from daras_ai.image_input import upload_file_from_bytes +from daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.fastapi_tricks import extract_model_fields +from daras_ai_v2.field_render import field_title_desc +from daras_ai_v2.settings import templates +from functions.models import CalledFunction, FunctionTrigger + +if typing.TYPE_CHECKING: + from bots.models import SavedRun + + +def call_recipe_functions( + *, + saved_run: "SavedRun", + current_user: AppUser, + request_model: typing.Type[BaseModel], + response_model: typing.Type[BaseModel], + state: dict, + trigger: FunctionTrigger, +): + from daras_ai_v2.workflow_url_input import url_to_runs + from gooeysite.bg_db_conn import get_celery_result_db_safe + + request = request_model.parse_obj(state) + + functions = getattr(request, "functions", None) or [] + functions = [fun for fun in functions if fun.trigger == trigger.name] + if not functions: + return + variables = state.setdefault("variables", {}) + + yield f"Running {trigger.name} hooks..." + + for fun in functions: + # run the function + page_cls, sr, pr = url_to_runs(fun.url) + result, sr = sr.submit_api_call( + current_user=current_user, + request_body=dict( + variables=sr.state.get("variables", {}) + | variables + | dict( + request=json.loads( + request.json(exclude_unset=True, exclude={"variables"}) + ), + response=extract_model_fields(response_model, state), + ), + ), + ) + + CalledFunction.objects.create( + saved_run=saved_run, function_run=sr, trigger=trigger.db_value + ) + + # wait for the result if its a pre request function + if trigger == FunctionTrigger.post: + continue + get_celery_result_db_safe(result) + sr.refresh_from_db() + # if failed, raise error + if sr.error_msg: + raise RuntimeError(sr.error_msg) + + # save the output from the function + return_value = sr.state.get("return_value") + if return_value is None: + continue + if isinstance(return_value, dict): + for k, v in return_value.items(): + if k in request_model.__fields__ or k in response_model.__fields__: + state[k] = v + else: + variables[k] = v + else: + variables["return_value"] = return_value + + +def render_called_functions(*, saved_run: "SavedRun", trigger: FunctionTrigger): + from recipes.Functions import FunctionsPage + from daras_ai_v2.breadcrumbs import get_title_breadcrumbs + + if not is_functions_enabled(): + return + qs = saved_run.called_functions.filter(trigger=trigger.db_value) + if not qs.exists(): + return + for called_fn in qs: + tb = get_title_breadcrumbs( + FunctionsPage, + called_fn.function_run, + called_fn.function_run.parent_published_run(), + ) + title = (tb.published_title and tb.published_title.title) or tb.h1_title + st.write(f"###### 🧩 Called [{title}]({called_fn.function_run.get_app_url()})") + return_value = called_fn.function_run.state.get("return_value") + if return_value is not None: + st.json(return_value) + st.newline() + + +def is_functions_enabled(key="functions") -> bool: + return bool(st.session_state.get(f"--enable-{key}")) + + +def functions_input(current_user: AppUser, key="functions"): + from recipes.BulkRunner import list_view_editor + from daras_ai_v2.base import BasePage + + def render_function_input(list_key: str, del_key: str, d: dict): + from daras_ai_v2.workflow_url_input import workflow_url_input + from recipes.Functions import FunctionsPage + + col1, col2 = st.columns([2, 10], responsive=False) + with col1: + col1.node.props["className"] += " pt-1" + d["trigger"] = enum_selector( + enum_cls=FunctionTrigger, + use_selectbox=True, + key=list_key + ":trigger", + value=d.get("trigger"), + ) + with col2: + workflow_url_input( + page_cls=FunctionsPage, + key=list_key + ":url", + internal_state=d, + del_key=del_key, + current_user=current_user, + ) + + if st.checkbox( + f"##### {field_title_desc(BasePage.RequestModel, key)}", + key=f"--enable-{key}", + value=key in st.session_state, + ): + st.session_state.setdefault(key, [{}]) + list_view_editor( + add_btn_label="➕ Add Function", + key=key, + render_inputs=render_function_input, + ) + st.write("---") + else: + st.session_state.pop(key, None) + + +def json_to_pdf(filename: str, data: str) -> str: + html = templates.get_template("form_output.html").render(data=json.loads(data)) + pdf_bytes = html_to_pdf(html) + if not filename.endswith(".pdf"): + filename += ".pdf" + return upload_file_from_bytes(filename, pdf_bytes, "application/pdf") + + +def html_to_pdf(html: str) -> bytes: + from playwright.sync_api import sync_playwright + + with sync_playwright() as p: + browser = p.chromium.launch() + page = browser.new_page() + page.set_content(html) + with tempfile.NamedTemporaryFile(suffix=".pdf") as outfile: + page.pdf(path=outfile.name, format="A4") + ret = outfile.read() + browser.close() + + return ret + + +class LLMTools(Enum): + json_to_pdf = ( + json_to_pdf, + "Save JSON as PDF", + { + "type": "function", + "function": { + "name": json_to_pdf.__name__, + "description": "Save JSON data to PDF", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "A short but descriptive filename for the PDF", + }, + "data": { + "type": "string", + "description": "The JSON data to write to the PDF", + }, + }, + "required": ["filename", "data"], + }, + }, + }, + ) + # send_reply_buttons = (print, "Send back reply buttons to the user.", {}) + + def __new__(cls, fn: typing.Callable, label: str, spec: dict): + obj = object.__new__(cls) + obj._value_ = fn.__name__ + obj.fn = fn + obj.label = label + obj.spec = spec + return obj + + # def __init__(self, *args, **kwargs): + # self._value_ = self.name diff --git a/functions/tests.py b/functions/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/functions/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/functions/views.py b/functions/views.py new file mode 100644 index 000000000..91ea44a21 --- /dev/null +++ b/functions/views.py @@ -0,0 +1,3 @@ +from django.shortcuts import render + +# Create your views here. diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index fdeb5183f..c87e3a90d 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -551,12 +551,13 @@ def anchor( form_submit_button = button -def expander(label: str, *, expanded: bool = False, **props): +def expander(label: str, *, expanded: bool = False, key: str = None, **props): node = state.RenderTreeNode( name="expander", props=dict( label=dedent(label), open=expanded, + name=key or md5_values(label, expanded, props), **props, ), ) diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index b1a60157a..3712a6b8f 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -17,7 +17,7 @@ ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.loom_video_widget import youtube_video -from daras_ai_v2.prompt_vars import prompt_vars_widget, render_prompt_vars +from daras_ai_v2.prompt_vars import variables_input, render_prompt_vars DEFAULT_COMPARE_LM_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/fef06d86-1f70-11ef-b8ee-02420a00015b/LLMs.jpg" @@ -50,8 +50,6 @@ class RequestModel(BasePage.RequestModel): max_tokens: int | None sampling_temperature: float | None - variables: dict[str, typing.Any] | None - response_format_type: ResponseFormatType = Field( None, title="Response Format", @@ -81,7 +79,6 @@ def render_form_v2(self): help="What a fine day..", height=300, ) - prompt_vars_widget("input_prompt") enum_multiselect( LargeLanguageModels, diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 99776b1a1..b864d6700 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -22,7 +22,7 @@ ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.loom_video_widget import youtube_video -from daras_ai_v2.prompt_vars import prompt_vars_widget, render_prompt_vars +from daras_ai_v2.prompt_vars import variables_input, render_prompt_vars from daras_ai_v2.query_generator import generate_final_search_query from daras_ai_v2.search_ref import ( SearchReference, @@ -78,8 +78,6 @@ class RequestModel(DocSearchRequest, BasePage.RequestModel): citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None - variables: dict[str, typing.Any] | None - class ResponseModel(BaseModel): output_text: list[str] @@ -94,7 +92,6 @@ def get_example_preferred_fields(self, state: dict) -> list[str]: def render_form_v2(self): st.text_area("#### Search Query", key="search_query") bulk_documents_uploader("#### Documents") - prompt_vars_widget("task_instructions", "query_instructions") def validate_form_v2(self): search_query = st.session_state.get("search_query", "").strip() diff --git a/recipes/Functions.py b/recipes/Functions.py index c0ea22de5..b2b1c23dd 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -1,4 +1,3 @@ -import json import typing import requests @@ -9,6 +8,7 @@ from daras_ai_v2 import settings from daras_ai_v2.base import BasePage from daras_ai_v2.field_render import field_title_desc +from daras_ai_v2.prompt_vars import variables_input class ConsoleLogs(BaseModel): @@ -27,6 +27,11 @@ class RequestModel(BaseModel): title="Code", description="The JS code to be executed.", ) + variables: dict[str, typing.Any] = Field( + {}, + title="Variables", + description="Variables to be used in the code", + ) class ResponseModel(BaseModel): return_value: typing.Any = Field( @@ -51,10 +56,11 @@ def run_v2( response: "FunctionsPage.ResponseModel", ) -> typing.Iterator[str | None]: yield "Running your code..." + # this will run functions/executor.js in deno deploy r = requests.post( settings.DENO_FUNCTIONS_URL, headers={"Authorization": f"Basic {settings.DENO_FUNCTIONS_AUTH_TOKEN}"}, - json=request.code, + json=dict(code=request.code, variables=request.variables or {}), ) data = r.json() response.logs = data.get("logs") @@ -67,9 +73,11 @@ def render_form_v2(self): st.text_area( "##### " + field_title_desc(self.RequestModel, "code"), key="code", - height=500, ) + def render_variables(self): + variables_input(template_keys=["code"], allow_add=True) + def render_output(self): if error := st.session_state.get("error"): with st.tag("pre", className="bg-danger bg-opacity-25"): diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 8328f9dc6..270e53f2f 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -17,7 +17,7 @@ ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.loom_video_widget import youtube_video -from daras_ai_v2.prompt_vars import render_prompt_vars, prompt_vars_widget +from daras_ai_v2.prompt_vars import render_prompt_vars, variables_input from daras_ai_v2.query_generator import generate_final_search_query from daras_ai_v2.search_ref import ( SearchReference, @@ -100,8 +100,6 @@ class RequestModel(GoogleSearchMixin, BasePage.RequestModel): "dense_weight" ].field_info - variables: dict[str, typing.Any] | None - class ResponseModel(BaseModel): output_text: list[str] @@ -115,7 +113,6 @@ class ResponseModel(BaseModel): def render_form_v2(self): st.text_area("#### Google Search Query", key="search_query") st.text_input("Search on a specific site *(optional)*", key="site_filter") - prompt_vars_widget("task_instructions", "query_instructions") def validate_form_v2(self): assert st.session_state.get( diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 099b1d72b..2f81abc43 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -52,7 +52,6 @@ from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.exceptions import UserError from daras_ai_v2.field_render import field_title_desc, field_desc, field_title -from daras_ai_v2.functions import LLMTools from daras_ai_v2.glossary import validate_glossary_document from daras_ai_v2.language_model import ( run_language_model, @@ -71,7 +70,7 @@ from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModel from daras_ai_v2.lipsync_settings_widgets import lipsync_settings from daras_ai_v2.loom_video_widget import youtube_video -from daras_ai_v2.prompt_vars import render_prompt_vars, prompt_vars_widget +from daras_ai_v2.prompt_vars import render_prompt_vars from daras_ai_v2.pydantic_validation import FieldHttpUrl from daras_ai_v2.query_generator import generate_final_search_query from daras_ai_v2.query_params import gooey_get_query_params @@ -89,6 +88,7 @@ text_to_speech_provider_selector, ) from daras_ai_v2.vector_search import DocSearchRequest +from functions.recipe_functions import LLMTools from gooey_ui import RedirectException from recipes.DocSearch import ( get_top_k_references, @@ -247,8 +247,6 @@ class RequestModelBase(BasePage.RequestModel): LipsyncModel.Wav2Lip.name ) - variables: dict[str, typing.Any] | None - tools: list[LLMTools] | None = Field( title="🛠ī¸ Tools", description="Give your copilot superpowers by giving it access to tools. Powered by [Function calling](https://platform.openai.com/docs/guides/function-calling).", @@ -331,9 +329,6 @@ def render_form_v2(self): key="bot_script", height=300, ) - prompt_vars_widget( - "bot_script", - ) enum_selector( LargeLanguageModels, @@ -519,9 +514,6 @@ def render_settings(self): key="task_instructions", height=300, ) - prompt_vars_widget( - "task_instructions", - ) citation_style_selector() st.checkbox("🔗 Shorten Citation URLs", key="use_url_shortener") @@ -604,7 +596,10 @@ def render_output(self): references = st.session_state.get("references", []) if not references: return - with st.expander("💁‍♀ī¸ Sources"): + key = "sources-expander" + with st.expander("💁‍♀ī¸ Sources", key=key): + if not st.session_state.get(key): + return for idx, ref in enumerate(references): st.write(f"**{idx + 1}**. [{ref['title']}]({ref['url']})") text_output( diff --git a/routers/api.py b/routers/api.py index 979af1c3a..816837dd9 100644 --- a/routers/api.py +++ b/routers/api.py @@ -19,22 +19,21 @@ from starlette.datastructures import UploadFile from starlette.requests import Request -from celeryapp.tasks import auto_recharge -from daras_ai_v2.auto_recharge import user_should_auto_recharge import gooey_ui as st from app_users.models import AppUser from auth.token_authentication import api_auth_header from bots.models import RetentionPolicy +from celeryapp.tasks import auto_recharge from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings from daras_ai_v2.all_pages import all_api_pages +from daras_ai_v2.auto_recharge import user_should_auto_recharge from daras_ai_v2.base import ( BasePage, - StateKeys, RecipeRunState, ) from daras_ai_v2.fastapi_tricks import fastapi_request_form -from daras_ai_v2.ratelimits import ensure_rate_limits +from functions.models import CalledFunctionResponse from gooeysite.bg_db_conn import get_celery_result_db_safe from routers.account import AccountTabs @@ -119,9 +118,14 @@ def script_to_api(page_cls: typing.Type[BasePage]): settings=(RunSettings, RunSettings()), ) # encapsulate the response model with the ApiResponseModel + response_output_model = create_model( + page_cls.__name__ + "Output", + __base__=page_cls.ResponseModel, + called_functions=(list[CalledFunctionResponse], None), + ) response_model = create_model( page_cls.__name__ + "Response", - __base__=ApiResponseModelV2[page_cls.ResponseModel], + __base__=ApiResponseModelV2[response_output_model], ) common_errs = { @@ -244,7 +248,7 @@ def run_api_form( response_model = create_model( page_cls.__name__ + "StatusResponse", - __base__=AsyncStatusResponseModelV3[page_cls.ResponseModel], + __base__=AsyncStatusResponseModelV3[response_output_model], ) @app.get( @@ -281,7 +285,7 @@ def get_run_status( status = self.get_run_state(sr.to_dict()) ret |= {"detail": sr.run_status or "", "status": status} if status == RecipeRunState.completed and sr.state: - ret |= {"output": sr.state} + ret |= {"output": sr.api_output()} if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) @@ -419,20 +423,18 @@ def build_api_response( # wait for the result get_celery_result_db_safe(result) sr = page.run_doc_sr(run_id, uid) - state = sr.to_dict() if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) # check for errors - err_msg = state.get(StateKeys.error_msg) - if err_msg: + if sr.error_msg: raise HTTPException( status_code=500, detail={ "id": run_id, "url": web_url, "created_at": sr.created_at.isoformat(), - "error": err_msg, + "error": sr.error_msg, }, ) else: @@ -441,7 +443,7 @@ def build_api_response( "id": run_id, "url": web_url, "created_at": sr.created_at.isoformat(), - "output": state, + "output": sr.api_output(), } diff --git a/scripts/deno-deploy.sh b/scripts/deno-deploy.sh deleted file mode 100755 index d97b110c9..000000000 --- a/scripts/deno-deploy.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - -set -ex - -deployctl deploy --include function-executor.js function-executor.js