From ba66c0240125657832c0e227dfbfd916385a180c Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 14 Aug 2024 20:38:33 +0530 Subject: [PATCH] Add deduct_credits flag to control credit deduction after task execution - Don't deduct credits when running called functions, safety checker and admin re-runs --- bots/admin.py | 2 +- bots/models.py | 4 ++++ celeryapp/tasks.py | 4 +++- daras_ai_v2/base.py | 3 ++- daras_ai_v2/safety_checker.py | 1 + functions/recipe_functions.py | 1 + recipes/Functions.py | 10 +++++----- routers/api.py | 3 ++- 8 files changed, 19 insertions(+), 9 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index 4b6a731c8..0b8b28d10 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -438,7 +438,7 @@ def rerun_tasks(self, request, queryset): page = Workflow(sr.workflow).page_cls( request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid)) ) - page.call_runner_task(sr) + page.call_runner_task(sr, deduct_credits=False) self.message_user( request, f"Started re-running {queryset.count()} tasks in the background.", diff --git a/bots/models.py b/bots/models.py index 0407569d5..51eaebc59 100644 --- a/bots/models.py +++ b/bots/models.py @@ -363,6 +363,7 @@ def submit_api_call( request_body: dict, enable_rate_limits: bool = False, parent_pr: "PublishedRun" = None, + deduct_credits: bool = True, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: from routers.api import submit_api_call @@ -384,6 +385,7 @@ def submit_api_call( user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, + deduct_credits=deduct_credits, ), ) @@ -1818,12 +1820,14 @@ def submit_api_call( current_user: AppUser, request_body: dict, enable_rate_limits: bool = False, + deduct_credits: bool = True, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: return self.saved_run.submit_api_call( current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, parent_pr=self, + deduct_credits=deduct_credits, ) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 794b5a061..2651fd80c 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -41,6 +41,7 @@ def runner_task( uid: str, channel: str, unsaved_state: dict[str, typing.Any] = None, + deduct_credits: bool = True, ) -> int: start_time = time() error_msg = None @@ -107,7 +108,8 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # run completed successfully, deduct credits else: - sr.transaction, sr.price = page.deduct_credits(gui.session_state) + if deduct_credits: + sr.transaction, sr.price = page.deduct_credits(gui.session_state) # save everything, mark run as completed finally: diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index b22e22e8b..7ee73dcfc 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1686,7 +1686,7 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun): } ) - def call_runner_task(self, sr: SavedRun): + def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True): from celeryapp.tasks import runner_task, post_runner_tasks chain = ( @@ -1697,6 +1697,7 @@ def call_runner_task(self, sr: SavedRun): uid=sr.uid, channel=self.realtime_channel_name(sr.run_id, sr.uid), unsaved_state=self._unsaved_state(), + deduct_credits=deduct_credits, ) | post_runner_tasks.s() ) diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 841817d94..338d22614 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -33,6 +33,7 @@ def safety_checker_text(text_input: str): .submit_api_call( current_user=billing_account, request_body=dict(variables=dict(input=text_input)), + deduct_credits=False, ) ) diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py index 5d4c07b87..b7fd36fdb 100644 --- a/functions/recipe_functions.py +++ b/functions/recipe_functions.py @@ -53,6 +53,7 @@ def call_recipe_functions( request_body=dict( variables=sr.state.get("variables", {}) | variables | fn_vars, ), + deduct_credits=False, ) CalledFunction.objects.create( diff --git a/recipes/Functions.py b/recipes/Functions.py index 5d415ad89..356381343 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -86,12 +86,12 @@ def render_form_v2(self): ) def get_price_roundoff(self, state: dict) -> float: - try: - # called from another workflow don't charge any credits - CalledFunction.objects.get(function_run=self.get_current_sr()) + if CalledFunction.objects.filter(function_run=self.get_current_sr()).exists(): return 0 - except CalledFunction.DoesNotExist: - return self.price + return super().get_price_roundoff(state) + + def additional_notes(self): + return "\nFunctions are free if called from another workflow." def render_variables(self): variables_input( diff --git a/routers/api.py b/routers/api.py index 171daa34c..9b795d426 100644 --- a/routers/api.py +++ b/routers/api.py @@ -332,6 +332,7 @@ def submit_api_call( query_params: dict, retention_policy: RetentionPolicy = None, enable_rate_limits: bool = False, + deduct_credits: bool = True, ) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: # init a new page for every request self = page_cls(request=SimpleNamespace(user=user)) @@ -357,7 +358,7 @@ def submit_api_call( except ValidationError as e: raise RequestValidationError(e.raw_errors, body=gui.session_state) from e # submit the task - result = self.call_runner_task(sr) + result = self.call_runner_task(sr, deduct_credits=deduct_credits) return self, result, sr.run_id, sr.uid