From c5b9baed3de7d98ed0620e1ade7c6f1a35a24a8d Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 6 Sep 2024 14:00:12 +0530 Subject: [PATCH] Refactor consistent usage of submit_api_call() - Introduce `SavedRun.wait_for_celery_result` to encapsulate common logic. - Change BasePage `endpoint` to method `api_endpoint`. - Update `submit_api_call`, `build_sync_api_response`, and `build_async_api_response` signatures. --- bots/models.py | 13 ++++--- bots/tasks.py | 4 +- daras_ai_v2/base.py | 7 ++-- daras_ai_v2/bots.py | 22 +++++------ daras_ai_v2/safety_checker.py | 3 +- functions/recipe_functions.py | 3 +- recipes/BulkRunner.py | 12 +++--- routers/api.py | 72 +++++++++++++++++------------------ routers/bots_api.py | 14 +++---- routers/twilio_api.py | 3 +- 10 files changed, 72 insertions(+), 81 deletions(-) diff --git a/bots/models.py b/bots/models.py index 3ace6bc33..6c3c1d92b 100644 --- a/bots/models.py +++ b/bots/models.py @@ -18,6 +18,7 @@ from daras_ai_v2.crypto import get_random_doc_id from daras_ai_v2.language_model import format_chat_entry from functions.models import CalledFunction, CalledFunctionResponse +from gooeysite.bg_db_conn import get_celery_result_db_safe from gooeysite.custom_create import get_or_create_lazy if typing.TYPE_CHECKING: @@ -358,8 +359,8 @@ def submit_api_call( current_user: AppUser, request_body: dict, enable_rate_limits: bool = False, - parent_pr: "PublishedRun" = None, deduct_credits: bool = True, + parent_pr: "PublishedRun" = None, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: from routers.api import submit_api_call @@ -373,19 +374,21 @@ def submit_api_call( query_params = page_cls.clean_query_params( example_id=self.example_id, run_id=self.run_id, uid=self.uid ) - page, result, run_id, uid = pool.apply( + return pool.apply( submit_api_call, kwds=dict( page_cls=page_cls, query_params=query_params, - user=current_user, + current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, deduct_credits=deduct_credits, ), ) - return result, page.current_sr + def wait_for_celery_result(self, result: "celery.result.AsyncResult"): + get_celery_result_db_safe(result) + self.refresh_from_db() def get_creator(self) -> AppUser | None: if self.uid: @@ -1839,8 +1842,8 @@ def submit_api_call( current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, - parent_pr=self, deduct_credits=deduct_credits, + parent_pr=self, ) diff --git a/bots/tasks.py b/bots/tasks.py index 70f381bf1..0a76e42bc 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -97,9 +97,7 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None): # save the run before the result is ready Message.objects.filter(id=msg_id).update(analysis_run=sr) - # wait for the result - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index e2f05a66f..90432553c 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -164,8 +164,7 @@ def __init__( self.run_user = run_user @classmethod - @property - def endpoint(cls) -> str: + def api_endpoint(cls) -> str: return f"/v2/{cls.slug_versions[0]}" def current_app_url( @@ -241,7 +240,9 @@ def api_url( query_params = dict(run_id=run_id, uid=uid) elif example_id: query_params = dict(example_id=example_id) - return furl(settings.API_BASE_URL, query_params=query_params) / cls.endpoint + return ( + furl(settings.API_BASE_URL, query_params=query_params) / cls.api_endpoint() + ) @classmethod def clean_query_params(cls, *, example_id, run_id, uid) -> dict: diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index eea020936..e8fb61c6b 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -1,6 +1,7 @@ import mimetypes import typing from datetime import datetime +from types import SimpleNamespace import gooey_gui as gui from django.db import transaction @@ -199,9 +200,7 @@ def get_input_documents(self) -> list[str] | None: def get_interactive_msg_info(self) -> ButtonPressed: raise NotImplementedError("This bot does not support interactive messages.") - def on_run_created( - self, page: BasePage, result: "celery.result.AsyncResult", run_id: str, uid: str - ): + def on_run_created(self, sr: "SavedRun"): pass def send_run_status(self, update_msg_id: str | None) -> str | None: @@ -376,13 +375,13 @@ def _process_and_send_msg( variables.update(bot.request_overrides["variables"]) except KeyError: pass - page, result, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=bot.page_cls, - user=billing_account_user, - request_body=body, query_params=bot.query_params, + current_user=billing_account_user, + request_body=body, ) - bot.on_run_created(page, result, run_id, uid) + bot.on_run_created(sr) if bot.show_feedback_buttons: buttons = _feedback_start_buttons() @@ -394,10 +393,10 @@ def _process_and_send_msg( last_idx = 0 # this is the last index of the text sent to the user if bot.streaming_enabled: # subscribe to the realtime channel for updates - channel = page.realtime_channel_name(run_id, uid) + channel = bot.page_cls.realtime_channel_name(sr.run_id, sr.uid) with gui.realtime_subscribe(channel) as realtime_gen: for state in realtime_gen: - bot.recipe_run_state = page.get_run_state(state) + bot.recipe_run_state = bot.page_cls.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" # check for errors if bot.recipe_run_state == RecipeRunState.failed: @@ -438,11 +437,10 @@ def _process_and_send_msg( break # we're done streaming, stop the loop # wait for the celery task to finish - get_celery_result_db_safe(result) + sr.wait_for_celery_result(result) # get the final state from db - sr = page.current_sr state = sr.to_dict() - bot.recipe_run_state = page.get_run_state(state) + bot.recipe_run_state = bot.page_cls.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" # check for errors err_msg = state.get(StateKeys.error_msg) diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 8a8d67336..7faf6b66b 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -38,8 +38,7 @@ def safety_checker_text(text_input: str): ) # wait for checker - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if checker failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py index b7fd36fdb..21d3fa185 100644 --- a/functions/recipe_functions.py +++ b/functions/recipe_functions.py @@ -63,8 +63,7 @@ def call_recipe_functions( # wait for the result if its a pre request function if trigger == FunctionTrigger.post: continue - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index f700effac..7ad9e67e9 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -2,10 +2,10 @@ import typing import uuid +import gooey_gui as gui from furl import furl from pydantic import BaseModel, Field -import gooey_gui as gui from bots.models import Workflow, SavedRun from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import icons @@ -322,8 +322,7 @@ def run_v2( request_body=request_body, parent_pr=pr, ) - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) run_time = datetime.timedelta( seconds=int(sr.run_time.total_seconds()) @@ -390,10 +389,11 @@ def run_v2( documents=response.output_documents ).dict(exclude_unset=True) result, sr = sr.submit_api_call( - current_user=self.request.user, request_body=request_body, parent_pr=pr + current_user=self.request.user, + request_body=request_body, + parent_pr=pr, ) - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) response.eval_runs.append(sr.get_app_url()) def preview_description(self, state: dict) -> str: diff --git a/routers/api.py b/routers/api.py index d1878bd53..68ad4c2b6 100644 --- a/routers/api.py +++ b/routers/api.py @@ -31,7 +31,7 @@ from app_users.models import AppUser from auth.token_authentication import api_auth_header -from bots.models import RetentionPolicy +from bots.models import RetentionPolicy, Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings from daras_ai_v2.all_pages import all_api_pages @@ -41,9 +41,12 @@ ) from daras_ai_v2.fastapi_tricks import fastapi_request_form from functions.models import CalledFunctionResponse -from gooeysite.bg_db_conn import get_celery_result_db_safe from routers.custom_api_router import CustomAPIRouter +if typing.TYPE_CHECKING: + from bots.models import SavedRun + import celery.result + app = CustomAPIRouter() @@ -117,7 +120,7 @@ class RunSettings(BaseModel): def script_to_api(page_cls: typing.Type[BasePage]): - endpoint = page_cls().endpoint.rstrip("/") + endpoint = page_cls.api_endpoint().rstrip("/") # add the common settings to the request model request_model = create_model( page_cls.__name__ + "Request", @@ -156,15 +159,15 @@ def run_api_json( page_request: request_model, user: AppUser = Depends(api_auth_header), ): - page, result, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=page_cls, - user=user, - request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), retention_policy=RetentionPolicy[page_request.settings.retention_policy], + current_user=user, + request_body=page_request.dict(exclude_unset=True), enable_rate_limits=True, ) - return build_sync_api_response(page=page, result=result, run_id=run_id, uid=uid) + return build_sync_api_response(result, sr) @app.post( os.path.join(endpoint, "form"), @@ -205,15 +208,15 @@ def run_api_json_async( page_request: request_model, user: AppUser = Depends(api_auth_header), ): - page, _, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=page_cls, - user=user, - request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), retention_policy=RetentionPolicy[page_request.settings.retention_policy], + current_user=user, + request_body=page_request.dict(exclude_unset=True), enable_rate_limits=True, ) - ret = build_async_api_response(page=page, run_id=run_id, uid=uid) + ret = build_async_api_response(sr) response.headers["Location"] = ret["status_url"] response.headers["Access-Control-Expose-Headers"] = "Location" return ret @@ -332,19 +335,21 @@ def _parse_form_data( def submit_api_call( *, page_cls: typing.Type[BasePage], - request_body: dict, - user: AppUser, query_params: dict, retention_policy: RetentionPolicy = None, + current_user: AppUser, + request_body: dict, enable_rate_limits: bool = False, deduct_credits: bool = True, -) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: +) -> tuple["celery.result.AsyncResult", "SavedRun"]: # init a new page for every request - query_params.setdefault("uid", user.uid) - self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) + query_params.setdefault("uid", current_user.uid) + page = page_cls( + request=SimpleNamespace(user=current_user, query_params=query_params) + ) # get saved state from db - state = self.current_sr_to_session_state() + state = page.current_sr_to_session_state() # load request data state.update(request_body) @@ -353,7 +358,7 @@ def submit_api_call( # create a new run try: - sr = self.create_new_run( + sr = page.create_new_run( enable_rate_limits=enable_rate_limits, is_api_call=True, retention_policy=retention_policy or RetentionPolicy.keep, @@ -361,20 +366,19 @@ def submit_api_call( except ValidationError as e: raise RequestValidationError(e.raw_errors, body=gui.session_state) from e # submit the task - result = self.call_runner_task(sr, deduct_credits=deduct_credits) - return self, result, sr.run_id, sr.uid + result = page.call_runner_task(sr, deduct_credits=deduct_credits) + return result, sr -def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: - web_url = page.app_url(run_id=run_id, uid=uid) +def build_async_api_response(sr: "SavedRun") -> dict: + web_url = sr.get_app_url() status_url = str( - furl(settings.API_BASE_URL, query_params=dict(run_id=run_id)) - / page.endpoint.replace("v2", "v3") + furl(settings.API_BASE_URL, query_params=dict(run_id=sr.run_id)) + / Workflow(sr.workflow).page_cls.api_endpoint().replace("v2", "v3") / "status/" ) - sr = page.current_sr return dict( - run_id=run_id, + run_id=sr.run_id, web_url=web_url, created_at=sr.created_at.isoformat(), status_url=status_url, @@ -382,17 +386,11 @@ def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: def build_sync_api_response( - *, - page: BasePage, - result: "celery.result.AsyncResult", - run_id: str, - uid: str, + result: "celery.result.AsyncResult", sr: "SavedRun" ) -> JSONResponse: - web_url = page.app_url(run_id=run_id, uid=uid) + web_url = sr.get_app_url() # wait for the result - get_celery_result_db_safe(result) - sr = page.current_sr - sr.refresh_from_db() + sr.wait_for_celery_result(result) if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) @@ -401,7 +399,7 @@ def build_sync_api_response( return JSONResponse( dict( detail=dict( - id=run_id, + id=sr.run_id, url=web_url, created_at=sr.created_at.isoformat(), error=sr.error_msg, @@ -414,7 +412,7 @@ def build_sync_api_response( return JSONResponse( jsonable_encoder( dict( - id=run_id, + id=sr.run_id, url=web_url, created_at=sr.created_at.isoformat(), output=sr.api_output(), diff --git a/routers/bots_api.py b/routers/bots_api.py index 64285cbd2..021caae4b 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from starlette.responses import StreamingResponse, Response -from bots.models import Platform, Conversation, BotIntegration, Message +from bots.models import Platform, Conversation, BotIntegration, Message, SavedRun from celeryapp.tasks import err_msg_for_exc from daras_ai_v2 import settings from daras_ai_v2.base import RecipeRunState, BasePage, StateKeys @@ -320,14 +320,10 @@ def runner(self): finally: self.queue.put(None) - def on_run_created( - self, page: BasePage, result: "celery.result.AsyncResult", run_id: str, uid: str - ): - self.run_id = run_id - self.uid = uid - self.queue.put( - RunStart(**build_async_api_response(page=page, run_id=run_id, uid=uid)) - ) + def on_run_created(self, sr: SavedRun): + self.run_id = sr.run_id + self.uid = sr.uid + self.queue.put(RunStart(**build_async_api_response(sr))) def send_run_status(self, update_msg_id: str | None) -> str | None: self.queue.put( diff --git a/routers/twilio_api.py b/routers/twilio_api.py index 0223636a9..108f3e47b 100644 --- a/routers/twilio_api.py +++ b/routers/twilio_api.py @@ -267,8 +267,7 @@ def resp_say_or_tts_play( request_body=tts_state, ) # wait for the TTS to finish - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # check for errors if sr.error_msg: raise RuntimeError(sr.error_msg)