From e2b21a0ee21e3405b74113720cacdba73c46d140 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sat, 6 Jul 2024 01:46:31 +0530 Subject: [PATCH] fix parent published run not set when calling submit_api_call() --- bots/models.py | 28 ++++++++++++++++++++++------ bots/tasks.py | 4 +++- daras_ai_v2/safety_checker.py | 2 +- functions/recipe_functions.py | 1 + recipes/BulkRunner.py | 6 ++++-- 5 files changed, 31 insertions(+), 10 deletions(-) diff --git a/bots/models.py b/bots/models.py index 02512dd2a..6ff6674ff 100644 --- a/bots/models.py +++ b/bots/models.py @@ -342,17 +342,19 @@ def submit_api_call( current_user: AppUser, request_body: dict, enable_rate_limits: bool = False, + parent_pr: "PublishedRun" = None, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: from routers.api import submit_api_call # run in a thread to avoid messing up threadlocals with ThreadPool(1) as pool: - pr = self.parent_published_run() - query_params = dict( - example_id=(pr and pr.published_run_id) or self.example_id, - run_id=self.run_id, - uid=self.uid, - ) + if parent_pr and parent_pr.saved_run == self: + # avoid passing run_id and uid for examples + query_params = dict(example_id=parent_pr.published_run_id) + else: + query_params = dict( + example_id=self.example_id, run_id=self.run_id, uid=self.uid + ) page, result, run_id, uid = pool.apply( submit_api_call, kwds=dict( @@ -1678,6 +1680,20 @@ def get_run_count(self): or 0 ) + def submit_api_call( + self, + *, + current_user: AppUser, + request_body: dict, + enable_rate_limits: bool = False, + ) -> tuple["celery.result.AsyncResult", "SavedRun"]: + return self.saved_run.submit_api_call( + current_user=current_user, + request_body=request_body, + enable_rate_limits=enable_rate_limits, + parent_pr=self, + ) + class PublishedRunVersion(models.Model): version_id = models.CharField(max_length=128, unique=True) diff --git a/bots/tasks.py b/bots/tasks.py index aca2ec369..d49abd6bc 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -92,7 +92,9 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None): # make the api call result, sr = analysis_sr.submit_api_call( - current_user=billing_account, request_body=dict(variables=variables) + current_user=billing_account, + request_body=dict(variables=variables), + parent_pr=anal.published_run, ) # save the run before the result is ready diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 8ae962cc4..841817d94 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -30,7 +30,7 @@ def safety_checker_text(text_input: str): result, sr = ( CompareLLMPage() .get_published_run(published_run_id=settings.SAFTY_CHECKER_EXAMPLE_ID) - .saved_run.submit_api_call( + .submit_api_call( current_user=billing_account, request_body=dict(variables=dict(input=text_input)), ) diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py index ff3c2f734..ae34ca54f 100644 --- a/functions/recipe_functions.py +++ b/functions/recipe_functions.py @@ -44,6 +44,7 @@ def call_recipe_functions( page_cls, sr, pr = url_to_runs(fun.url) result, sr = sr.submit_api_call( current_user=current_user, + parent_pr=pr, request_body=dict( variables=sr.state.get("variables", {}) | variables diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index 6fb6dd5d9..87b218954 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -318,7 +318,9 @@ def run_v2( yield f"{progress}%" result, sr = sr.submit_api_call( - current_user=self.request.user, request_body=request_body + current_user=self.request.user, + request_body=request_body, + parent_pr=pr, ) get_celery_result_db_safe(result) sr.refresh_from_db() @@ -388,7 +390,7 @@ def run_v2( documents=response.output_documents ).dict(exclude_unset=True) result, sr = sr.submit_api_call( - current_user=self.request.user, request_body=request_body + current_user=self.request.user, request_body=request_body, parent_pr=pr ) get_celery_result_db_safe(result) sr.refresh_from_db()