Skip to content

Commit

Permalink
Add deduct_credits flag to control credit deduction after task execution
Browse files Browse the repository at this point in the history
- Don't deduct credits when running called functions, safety checker and admin re-runs
  • Loading branch information
devxpy authored and anish-work committed Aug 14, 2024
1 parent a513df1 commit ba66c02
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def rerun_tasks(self, request, queryset):
page = Workflow(sr.workflow).page_cls(
request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid))
)
page.call_runner_task(sr)
page.call_runner_task(sr, deduct_credits=False)
self.message_user(
request,
f"Started re-running {queryset.count()} tasks in the background.",
Expand Down
4 changes: 4 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def submit_api_call(
request_body: dict,
enable_rate_limits: bool = False,
parent_pr: "PublishedRun" = None,
deduct_credits: bool = True,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
from routers.api import submit_api_call

Expand All @@ -384,6 +385,7 @@ def submit_api_call(
user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
deduct_credits=deduct_credits,
),
)

Expand Down Expand Up @@ -1818,12 +1820,14 @@ def submit_api_call(
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
return self.saved_run.submit_api_call(
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
parent_pr=self,
deduct_credits=deduct_credits,
)


Expand Down
4 changes: 3 additions & 1 deletion celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def runner_task(
uid: str,
channel: str,
unsaved_state: dict[str, typing.Any] = None,
deduct_credits: bool = True,
) -> int:
start_time = time()
error_msg = None
Expand Down Expand Up @@ -107,7 +108,8 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False

# run completed successfully, deduct credits
else:
sr.transaction, sr.price = page.deduct_credits(gui.session_state)
if deduct_credits:
sr.transaction, sr.price = page.deduct_credits(gui.session_state)

# save everything, mark run as completed
finally:
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun):
}
)

def call_runner_task(self, sr: SavedRun):
def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True):
from celeryapp.tasks import runner_task, post_runner_tasks

chain = (
Expand All @@ -1697,6 +1697,7 @@ def call_runner_task(self, sr: SavedRun):
uid=sr.uid,
channel=self.realtime_channel_name(sr.run_id, sr.uid),
unsaved_state=self._unsaved_state(),
deduct_credits=deduct_credits,
)
| post_runner_tasks.s()
)
Expand Down
1 change: 1 addition & 0 deletions daras_ai_v2/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def safety_checker_text(text_input: str):
.submit_api_call(
current_user=billing_account,
request_body=dict(variables=dict(input=text_input)),
deduct_credits=False,
)
)

Expand Down
1 change: 1 addition & 0 deletions functions/recipe_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def call_recipe_functions(
request_body=dict(
variables=sr.state.get("variables", {}) | variables | fn_vars,
),
deduct_credits=False,
)

CalledFunction.objects.create(
Expand Down
10 changes: 5 additions & 5 deletions recipes/Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def render_form_v2(self):
)

def get_price_roundoff(self, state: dict) -> float:
try:
# called from another workflow don't charge any credits
CalledFunction.objects.get(function_run=self.get_current_sr())
if CalledFunction.objects.filter(function_run=self.get_current_sr()).exists():
return 0
except CalledFunction.DoesNotExist:
return self.price
return super().get_price_roundoff(state)

def additional_notes(self):
return "\nFunctions are free if called from another workflow."

def render_variables(self):
variables_input(
Expand Down
3 changes: 2 additions & 1 deletion routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def submit_api_call(
query_params: dict,
retention_policy: RetentionPolicy = None,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
) -> tuple[BasePage, "celery.result.AsyncResult", str, str]:
# init a new page for every request
self = page_cls(request=SimpleNamespace(user=user))
Expand All @@ -357,7 +358,7 @@ def submit_api_call(
except ValidationError as e:
raise RequestValidationError(e.raw_errors, body=gui.session_state) from e
# submit the task
result = self.call_runner_task(sr)
result = self.call_runner_task(sr, deduct_credits=deduct_credits)
return self, result, sr.run_id, sr.uid


Expand Down

0 comments on commit ba66c02

Please sign in to comment.