Skip to content

Commit

Permalink
Refactor consistent usage of submit_api_call()
Browse files Browse the repository at this point in the history
- Introduce `SavedRun.wait_for_celery_result` to encapsulate common logic.
- Change BasePage `endpoint` to method `api_endpoint`.
- Update `submit_api_call`, `build_sync_api_response`, and `build_async_api_response` signatures.
  • Loading branch information
devxpy committed Sep 6, 2024
1 parent 8b59c92 commit c5b9bae
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 81 deletions.
13 changes: 8 additions & 5 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
from functions.models import CalledFunction, CalledFunctionResponse
from gooeysite.bg_db_conn import get_celery_result_db_safe
from gooeysite.custom_create import get_or_create_lazy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -358,8 +359,8 @@ def submit_api_call(
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
parent_pr: "PublishedRun" = None,
deduct_credits: bool = True,
parent_pr: "PublishedRun" = None,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
from routers.api import submit_api_call

Expand All @@ -373,19 +374,21 @@ def submit_api_call(
query_params = page_cls.clean_query_params(
example_id=self.example_id, run_id=self.run_id, uid=self.uid
)
page, result, run_id, uid = pool.apply(
return pool.apply(
submit_api_call,
kwds=dict(
page_cls=page_cls,
query_params=query_params,
user=current_user,
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
deduct_credits=deduct_credits,
),
)

return result, page.current_sr
def wait_for_celery_result(self, result: "celery.result.AsyncResult"):
get_celery_result_db_safe(result)
self.refresh_from_db()

def get_creator(self) -> AppUser | None:
if self.uid:
Expand Down Expand Up @@ -1839,8 +1842,8 @@ def submit_api_call(
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
parent_pr=self,
deduct_credits=deduct_credits,
parent_pr=self,
)


Expand Down
4 changes: 1 addition & 3 deletions bots/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None):
# save the run before the result is ready
Message.objects.filter(id=msg_id).update(analysis_run=sr)

# wait for the result
get_celery_result_db_safe(result)
sr.refresh_from_db()
sr.wait_for_celery_result(result)
# if failed, raise error
if sr.error_msg:
raise RuntimeError(sr.error_msg)
Expand Down
7 changes: 4 additions & 3 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ def __init__(
self.run_user = run_user

@classmethod
@property
def endpoint(cls) -> str:
def api_endpoint(cls) -> str:
return f"/v2/{cls.slug_versions[0]}"

def current_app_url(
Expand Down Expand Up @@ -241,7 +240,9 @@ def api_url(
query_params = dict(run_id=run_id, uid=uid)
elif example_id:
query_params = dict(example_id=example_id)
return furl(settings.API_BASE_URL, query_params=query_params) / cls.endpoint
return (
furl(settings.API_BASE_URL, query_params=query_params) / cls.api_endpoint()
)

@classmethod
def clean_query_params(cls, *, example_id, run_id, uid) -> dict:
Expand Down
22 changes: 10 additions & 12 deletions daras_ai_v2/bots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import mimetypes
import typing
from datetime import datetime
from types import SimpleNamespace

import gooey_gui as gui
from django.db import transaction
Expand Down Expand Up @@ -199,9 +200,7 @@ def get_input_documents(self) -> list[str] | None:
def get_interactive_msg_info(self) -> ButtonPressed:
raise NotImplementedError("This bot does not support interactive messages.")

def on_run_created(
self, page: BasePage, result: "celery.result.AsyncResult", run_id: str, uid: str
):
def on_run_created(self, sr: "SavedRun"):
pass

def send_run_status(self, update_msg_id: str | None) -> str | None:
Expand Down Expand Up @@ -376,13 +375,13 @@ def _process_and_send_msg(
variables.update(bot.request_overrides["variables"])
except KeyError:
pass
page, result, run_id, uid = submit_api_call(
result, sr = submit_api_call(
page_cls=bot.page_cls,
user=billing_account_user,
request_body=body,
query_params=bot.query_params,
current_user=billing_account_user,
request_body=body,
)
bot.on_run_created(page, result, run_id, uid)
bot.on_run_created(sr)

if bot.show_feedback_buttons:
buttons = _feedback_start_buttons()
Expand All @@ -394,10 +393,10 @@ def _process_and_send_msg(
last_idx = 0 # this is the last index of the text sent to the user
if bot.streaming_enabled:
# subscribe to the realtime channel for updates
channel = page.realtime_channel_name(run_id, uid)
channel = bot.page_cls.realtime_channel_name(sr.run_id, sr.uid)
with gui.realtime_subscribe(channel) as realtime_gen:
for state in realtime_gen:
bot.recipe_run_state = page.get_run_state(state)
bot.recipe_run_state = bot.page_cls.get_run_state(state)
bot.run_status = state.get(StateKeys.run_status) or ""
# check for errors
if bot.recipe_run_state == RecipeRunState.failed:
Expand Down Expand Up @@ -438,11 +437,10 @@ def _process_and_send_msg(
break # we're done streaming, stop the loop

# wait for the celery task to finish
get_celery_result_db_safe(result)
sr.wait_for_celery_result(result)
# get the final state from db
sr = page.current_sr
state = sr.to_dict()
bot.recipe_run_state = page.get_run_state(state)
bot.recipe_run_state = bot.page_cls.get_run_state(state)
bot.run_status = state.get(StateKeys.run_status) or ""
# check for errors
err_msg = state.get(StateKeys.error_msg)
Expand Down
3 changes: 1 addition & 2 deletions daras_ai_v2/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def safety_checker_text(text_input: str):
)

# wait for checker
get_celery_result_db_safe(result)
sr.refresh_from_db()
sr.wait_for_celery_result(result)
# if checker failed, raise error
if sr.error_msg:
raise RuntimeError(sr.error_msg)
Expand Down
3 changes: 1 addition & 2 deletions functions/recipe_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def call_recipe_functions(
# wait for the result if its a pre request function
if trigger == FunctionTrigger.post:
continue
get_celery_result_db_safe(result)
sr.refresh_from_db()
sr.wait_for_celery_result(result)
# if failed, raise error
if sr.error_msg:
raise RuntimeError(sr.error_msg)
Expand Down
12 changes: 6 additions & 6 deletions recipes/BulkRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import typing
import uuid

import gooey_gui as gui
from furl import furl
from pydantic import BaseModel, Field

import gooey_gui as gui
from bots.models import Workflow, SavedRun
from daras_ai.image_input import upload_file_from_bytes
from daras_ai_v2 import icons
Expand Down Expand Up @@ -322,8 +322,7 @@ def run_v2(
request_body=request_body,
parent_pr=pr,
)
get_celery_result_db_safe(result)
sr.refresh_from_db()
sr.wait_for_celery_result(result)

run_time = datetime.timedelta(
seconds=int(sr.run_time.total_seconds())
Expand Down Expand Up @@ -390,10 +389,11 @@ def run_v2(
documents=response.output_documents
).dict(exclude_unset=True)
result, sr = sr.submit_api_call(
current_user=self.request.user, request_body=request_body, parent_pr=pr
current_user=self.request.user,
request_body=request_body,
parent_pr=pr,
)
get_celery_result_db_safe(result)
sr.refresh_from_db()
sr.wait_for_celery_result(result)
response.eval_runs.append(sr.get_app_url())

def preview_description(self, state: dict) -> str:
Expand Down
72 changes: 35 additions & 37 deletions routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from app_users.models import AppUser
from auth.token_authentication import api_auth_header
from bots.models import RetentionPolicy
from bots.models import RetentionPolicy, Workflow
from daras_ai.image_input import upload_file_from_bytes
from daras_ai_v2 import settings
from daras_ai_v2.all_pages import all_api_pages
Expand All @@ -41,9 +41,12 @@
)
from daras_ai_v2.fastapi_tricks import fastapi_request_form
from functions.models import CalledFunctionResponse
from gooeysite.bg_db_conn import get_celery_result_db_safe
from routers.custom_api_router import CustomAPIRouter

if typing.TYPE_CHECKING:
from bots.models import SavedRun
import celery.result

app = CustomAPIRouter()


Expand Down Expand Up @@ -117,7 +120,7 @@ class RunSettings(BaseModel):


def script_to_api(page_cls: typing.Type[BasePage]):
endpoint = page_cls().endpoint.rstrip("/")
endpoint = page_cls.api_endpoint().rstrip("/")
# add the common settings to the request model
request_model = create_model(
page_cls.__name__ + "Request",
Expand Down Expand Up @@ -156,15 +159,15 @@ def run_api_json(
page_request: request_model,
user: AppUser = Depends(api_auth_header),
):
page, result, run_id, uid = submit_api_call(
result, sr = submit_api_call(
page_cls=page_cls,
user=user,
request_body=page_request.dict(exclude_unset=True),
query_params=dict(request.query_params),
retention_policy=RetentionPolicy[page_request.settings.retention_policy],
current_user=user,
request_body=page_request.dict(exclude_unset=True),
enable_rate_limits=True,
)
return build_sync_api_response(page=page, result=result, run_id=run_id, uid=uid)
return build_sync_api_response(result, sr)

@app.post(
os.path.join(endpoint, "form"),
Expand Down Expand Up @@ -205,15 +208,15 @@ def run_api_json_async(
page_request: request_model,
user: AppUser = Depends(api_auth_header),
):
page, _, run_id, uid = submit_api_call(
result, sr = submit_api_call(
page_cls=page_cls,
user=user,
request_body=page_request.dict(exclude_unset=True),
query_params=dict(request.query_params),
retention_policy=RetentionPolicy[page_request.settings.retention_policy],
current_user=user,
request_body=page_request.dict(exclude_unset=True),
enable_rate_limits=True,
)
ret = build_async_api_response(page=page, run_id=run_id, uid=uid)
ret = build_async_api_response(sr)
response.headers["Location"] = ret["status_url"]
response.headers["Access-Control-Expose-Headers"] = "Location"
return ret
Expand Down Expand Up @@ -332,19 +335,21 @@ def _parse_form_data(
def submit_api_call(
*,
page_cls: typing.Type[BasePage],
request_body: dict,
user: AppUser,
query_params: dict,
retention_policy: RetentionPolicy = None,
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
) -> tuple[BasePage, "celery.result.AsyncResult", str, str]:
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
# init a new page for every request
query_params.setdefault("uid", user.uid)
self = page_cls(request=SimpleNamespace(user=user, query_params=query_params))
query_params.setdefault("uid", current_user.uid)
page = page_cls(
request=SimpleNamespace(user=current_user, query_params=query_params)
)

# get saved state from db
state = self.current_sr_to_session_state()
state = page.current_sr_to_session_state()
# load request data
state.update(request_body)

Expand All @@ -353,46 +358,39 @@ def submit_api_call(

# create a new run
try:
sr = self.create_new_run(
sr = page.create_new_run(
enable_rate_limits=enable_rate_limits,
is_api_call=True,
retention_policy=retention_policy or RetentionPolicy.keep,
)
except ValidationError as e:
raise RequestValidationError(e.raw_errors, body=gui.session_state) from e
# submit the task
result = self.call_runner_task(sr, deduct_credits=deduct_credits)
return self, result, sr.run_id, sr.uid
result = page.call_runner_task(sr, deduct_credits=deduct_credits)
return result, sr


def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict:
web_url = page.app_url(run_id=run_id, uid=uid)
def build_async_api_response(sr: "SavedRun") -> dict:
web_url = sr.get_app_url()
status_url = str(
furl(settings.API_BASE_URL, query_params=dict(run_id=run_id))
/ page.endpoint.replace("v2", "v3")
furl(settings.API_BASE_URL, query_params=dict(run_id=sr.run_id))
/ Workflow(sr.workflow).page_cls.api_endpoint().replace("v2", "v3")
/ "status/"
)
sr = page.current_sr
return dict(
run_id=run_id,
run_id=sr.run_id,
web_url=web_url,
created_at=sr.created_at.isoformat(),
status_url=status_url,
)


def build_sync_api_response(
*,
page: BasePage,
result: "celery.result.AsyncResult",
run_id: str,
uid: str,
result: "celery.result.AsyncResult", sr: "SavedRun"
) -> JSONResponse:
web_url = page.app_url(run_id=run_id, uid=uid)
web_url = sr.get_app_url()
# wait for the result
get_celery_result_db_safe(result)
sr = page.current_sr
sr.refresh_from_db()
sr.wait_for_celery_result(result)
if sr.retention_policy == RetentionPolicy.delete:
sr.state = {}
sr.save(update_fields=["state"])
Expand All @@ -401,7 +399,7 @@ def build_sync_api_response(
return JSONResponse(
dict(
detail=dict(
id=run_id,
id=sr.run_id,
url=web_url,
created_at=sr.created_at.isoformat(),
error=sr.error_msg,
Expand All @@ -414,7 +412,7 @@ def build_sync_api_response(
return JSONResponse(
jsonable_encoder(
dict(
id=run_id,
id=sr.run_id,
url=web_url,
created_at=sr.created_at.isoformat(),
output=sr.api_output(),
Expand Down
Loading

0 comments on commit c5b9bae

Please sign in to comment.