Skip to content

Commit

Permalink
refactor & split tool construction, bind() and __call__
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Dec 5, 2024
1 parent 77b6c47 commit 3a75df0
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 142 deletions.
10 changes: 6 additions & 4 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
is_functions_enabled,
render_called_functions,
LLMTool,
get_tools_from_state,
)
from payments.auto_recharge import (
should_attempt_auto_recharge,
Expand Down Expand Up @@ -1346,9 +1347,9 @@ def update_flag_for_run(self, is_flagged: bool):
sr.save(update_fields=["is_flagged"])
gui.session_state["is_flagged"] = is_flagged

def get_current_llm_tools(self) -> dict[str, LLMTool]:
return dict(
call_recipe_functions(
def get_current_llm_tools(self) -> list[LLMTool]:
return [
tool.bind(
saved_run=self.current_sr,
workspace=self.current_workspace,
current_user=self.request.user,
Expand All @@ -1357,7 +1358,8 @@ def get_current_llm_tools(self) -> dict[str, LLMTool]:
state=gui.session_state,
trigger=FunctionTrigger.prompt,
)
)
for tool in get_tools_from_state(gui.session_state, FunctionTrigger.prompt)
]

@cached_property
def current_workspace(self) -> Workspace:
Expand Down
232 changes: 115 additions & 117 deletions functions/recipe_functions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import json
import tempfile
import typing
from enum import Enum
from functools import partial

import gooey_gui as gui
from django.utils.text import slugify

from app_users.models import AppUser
from daras_ai.image_input import upload_file_from_bytes
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.field_render import field_title_desc
from daras_ai_v2.settings import templates
from functions.models import CalledFunction, FunctionTrigger

if typing.TYPE_CHECKING:
Expand All @@ -20,38 +15,62 @@
from workspaces.models import Workspace


def call_recipe_functions(
*,
saved_run: "SavedRun",
workspace: "Workspace",
current_user: AppUser,
request_model: typing.Type["BasePage.RequestModel"],
response_model: typing.Type["BasePage.ResponseModel"],
state: dict,
trigger: FunctionTrigger,
) -> typing.Generator[typing.Union[str, tuple[str, "LLMTool"]], None, None]:
from daras_ai_v2.workflow_url_input import url_to_runs

functions = state.get("functions") or []
functions = [fun for fun in functions if fun.get("trigger") == trigger.name]
if not functions:
return

request = request_model.parse_obj(state)
variables = state.setdefault("variables", {})
fn_vars = dict(
web_url=saved_run.get_app_url(),
request=json.loads(request.json(exclude_unset=True, exclude={"variables"})),
response={k: v for k, v in state.items() if k in response_model.__fields__},
)

if trigger != FunctionTrigger.prompt:
yield f"Running {trigger.name} hooks..."
class LLMTool:
def __init__(
self,
function_url: str,
name: str,
label: str,
spec: dict,
):
self.function_url = function_url
self.name = name
self.label = label
self.spec = spec

def bind(
self,
saved_run: "SavedRun",
workspace: "Workspace",
current_user: AppUser,
request_model: typing.Type["BasePage.RequestModel"],
response_model: typing.Type["BasePage.ResponseModel"],
state: dict,
trigger: FunctionTrigger,
) -> "LLMTool":
self.saved_run = saved_run
self.workspace = workspace
self.current_user = current_user
self.request_model = request_model
self.response_model = response_model
self.state = state
self.trigger = trigger
return self

def __call__(self, **kwargs):
from daras_ai_v2.workflow_url_input import url_to_runs

try:
self.saved_run
except AttributeError:
raise RuntimeError("This LLMTool instance is not yet bound")

request = self.request_model.parse_obj(self.state)
variables = self.state.setdefault("variables", {})
fn_vars = dict(
web_url=self.saved_run.get_app_url(),
request=json.loads(request.json(exclude_unset=True, exclude={"variables"})),
response={
k: v
for k, v in self.state.items()
if k in self.response_model.__fields__
},
)

def run(sr, pr, /, **kwargs):
_, sr, pr = url_to_runs(self.function_url)
result, sr = sr.submit_api_call(
workspace=workspace,
current_user=current_user,
workspace=self.workspace,
current_user=self.current_user,
parent_pr=pr,
request_body=dict(
variables=sr.state.get("variables", {}) | variables | fn_vars | kwargs,
Expand All @@ -60,11 +79,11 @@ def run(sr, pr, /, **kwargs):
)

CalledFunction.objects.create(
saved_run=saved_run, function_run=sr, trigger=trigger.db_value
saved_run=self.saved_run, function_run=sr, trigger=self.trigger.db_value
)

# wait for the result if its a pre request function
if trigger == FunctionTrigger.post:
if self.trigger == FunctionTrigger.post:
return
sr.wait_for_celery_result(result)
# if failed, raise error
Expand All @@ -77,39 +96,75 @@ def run(sr, pr, /, **kwargs):
return
if isinstance(return_value, dict):
for k, v in return_value.items():
if k in request_model.__fields__ or k in response_model.__fields__:
state[k] = v
if (
k in self.request_model.__fields__
or k in self.response_model.__fields__
):
self.state[k] = v
else:
variables[k] = v
else:
variables["return_value"] = return_value

return return_value

for fun in functions:
_, sr, pr = url_to_runs(fun.get("url"))
if trigger != FunctionTrigger.prompt:
run(sr, pr)
else:
fn_name = slugify(pr.title).replace("-", "_")
yield fn_name, LLMTool(
fn=partial(run, sr, pr),
label=pr.title,
spec={
"type": "function",
"function": {
"name": fn_name,
"description": pr.notes,
"parameters": {
"type": "object",
"properties": {
key: {"type": get_json_type(value)}
for key, value in sr.state.get("variables", {}).items()
},

def call_recipe_functions(
*,
saved_run: "SavedRun",
workspace: "Workspace",
current_user: AppUser,
request_model: typing.Type["BasePage.RequestModel"],
response_model: typing.Type["BasePage.ResponseModel"],
state: dict,
trigger: FunctionTrigger,
) -> typing.Iterable[str]:
yield f"Running {trigger.name} hooks..."
for tool in get_tools_from_state(state, trigger):
tool.bind(
saved_run=saved_run,
workspace=workspace,
current_user=current_user,
request_model=request_model,
response_model=response_model,
state=state,
trigger=trigger,
)
tool()


def get_tools_from_state(
state: dict, trigger: FunctionTrigger
) -> typing.Iterable[LLMTool]:
from daras_ai_v2.workflow_url_input import url_to_runs

functions = state.get("functions")
if not functions:
return
for function in functions:
if function.get("trigger") != trigger.name:
continue
_, sr, pr = url_to_runs(function.get("url"))
fn_name = slugify(pr.title).replace("-", "_")
yield LLMTool(
function_url=function.get("url"),
name=fn_name,
label=pr.title,
spec={
"type": "function",
"function": {
"name": fn_name,
"description": pr.notes,
"parameters": {
"type": "object",
"properties": {
key: {"type": get_json_type(value)}
for key, value in sr.state.get("variables", {}).items()
},
},
},
)
},
)


def get_json_type(val) -> str:
Expand Down Expand Up @@ -210,60 +265,3 @@ def render_function_input(list_key: str, del_key: str, d: dict):
)
else:
gui.session_state.pop(key, None)


def json_to_pdf(filename: str, data: str) -> str:
html = templates.get_template("form_output.html").render(data=json.loads(data))
pdf_bytes = html_to_pdf(html)
if not filename.endswith(".pdf"):
filename += ".pdf"
return upload_file_from_bytes(filename, pdf_bytes, "application/pdf")


def html_to_pdf(html: str) -> bytes:
from playwright.sync_api import sync_playwright

with sync_playwright() as p:
browser = p.chromium.launch()
page = browser.new_page()
page.set_content(html)
with tempfile.NamedTemporaryFile(suffix=".pdf") as outfile:
page.pdf(path=outfile.name, format="A4")
ret = outfile.read()
browser.close()

return ret


class LLMTool(typing.NamedTuple):
fn: typing.Callable
label: str
spec: dict


class LLMTools(LLMTool, Enum):
json_to_pdf = LLMTool(
fn=json_to_pdf,
label="Save JSON as PDF",
spec={
"type": "function",
"function": {
"name": json_to_pdf.__name__,
"description": "Save JSON data to PDF",
"parameters": {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "A short but descriptive filename for the PDF",
},
"data": {
"type": "string",
"description": "The JSON data to write to the PDF",
},
},
"required": ["filename", "data"],
},
},
},
)
34 changes: 13 additions & 21 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
from functions.recipe_functions import (
LLMTool,
render_called_functions,
call_recipe_functions,
get_tools_from_state,
)
from recipes.DocSearch import (
get_top_k_references,
Expand Down Expand Up @@ -746,6 +746,15 @@ def render_steps(self):
gui.write("###### `references`")
gui.json(references)

if gui.session_state.get("functions"):
prompt_funcs = list(
get_tools_from_state(gui.session_state, FunctionTrigger.prompt)
)
if prompt_funcs:
gui.write(f"🧩 `{FunctionTrigger.prompt.name} functions`")
for name, tool in prompt_funcs:
gui.json(tool.spec.get("function", tool.spec), depth=3)

final_prompt = gui.session_state.get("final_prompt")
if final_prompt:
if isinstance(final_prompt, str):
Expand All @@ -757,21 +766,6 @@ def render_steps(self):
)
gui.json(final_prompt, depth=5)

if gui.session_state.get("functions"):
prompt_funcs = call_recipe_functions(
saved_run=self.current_sr,
workspace=None,
current_user=self.request.user,
request_model=self.RequestModel,
response_model=self.ResponseModel,
state=gui.session_state,
trigger=FunctionTrigger.prompt,
)
if prompt_funcs:
gui.write(f"🧩 `{FunctionTrigger.prompt.name} functions`")
for name, tool in prompt_funcs:
gui.json(tool.spec.get("function", tool.spec), depth=3)

for k in ["raw_output_text", "output_text", "raw_tts_text"]:
for idx, text in enumerate(gui.session_state.get(k) or []):
gui.text_area(
Expand Down Expand Up @@ -1058,8 +1052,6 @@ def llm_loop(
) -> typing.Iterator[str | None]:
yield f"Summarizing with {model.value}..."

tools = self.get_current_llm_tools()

chunks: typing.Generator[list[dict], None, None] = run_language_model(
model=request.selected_model,
messages=response.final_prompt,
Expand All @@ -1068,7 +1060,7 @@ def llm_loop(
temperature=request.sampling_temperature,
avoid_repetition=request.avoid_repetition,
response_format_type=request.response_format_type,
tools=list(tools.values()),
tools=self.get_current_llm_tools(),
stream=True,
)

Expand Down Expand Up @@ -1133,7 +1125,7 @@ def llm_loop(
)
)
for call in tool_calls:
result = yield from exec_tool_call(call, tools)
result = yield from exec_tool_call(call, self.get_current_llm_tools())
response.final_prompt.append(
dict(
role="tool",
Expand Down Expand Up @@ -1727,7 +1719,7 @@ def exec_tool_call(call: dict, tools: dict[str, "LLMTool"]):
tool = tools[tool_name]
yield f"🛠 {tool.label}..."
kwargs = json.loads(call["function"]["arguments"])
return tool.fn(**kwargs)
return tool(**kwargs)


class ConnectChoice(typing.NamedTuple):
Expand Down

0 comments on commit 3a75df0

Please sign in to comment.