Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functions: charge 1 credit in /functions and no credit from another workflow #436

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def rerun_tasks(self, request, queryset):
page = Workflow(sr.workflow).page_cls(
request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid))
)
page.call_runner_task(sr)
page.call_runner_task(sr, deduct_credits=False)
self.message_user(
request,
f"Started re-running {queryset.count()} tasks in the background.",
Expand Down
4 changes: 4 additions & 0 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def submit_api_call(
request_body: dict,
enable_rate_limits: bool = False,
parent_pr: "PublishedRun" = None,
deduct_credits: bool = True,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
from routers.api import submit_api_call

Expand All @@ -384,6 +385,7 @@ def submit_api_call(
user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
deduct_credits=deduct_credits,
),
)

Expand Down Expand Up @@ -1818,12 +1820,14 @@ def submit_api_call(
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
return self.saved_run.submit_api_call(
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
parent_pr=self,
deduct_credits=deduct_credits,
)


Expand Down
4 changes: 3 additions & 1 deletion celeryapp/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def runner_task(
uid: str,
channel: str,
unsaved_state: dict[str, typing.Any] = None,
deduct_credits: bool = True,
) -> int:
start_time = time()
error_msg = None
Expand Down Expand Up @@ -107,7 +108,8 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False

# run completed successfully, deduct credits
else:
sr.transaction, sr.price = page.deduct_credits(gui.session_state)
if deduct_credits:
sr.transaction, sr.price = page.deduct_credits(gui.session_state)

# save everything, mark run as completed
finally:
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun):
}
)

def call_runner_task(self, sr: SavedRun):
def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True):
from celeryapp.tasks import runner_task, post_runner_tasks

chain = (
Expand All @@ -1697,6 +1697,7 @@ def call_runner_task(self, sr: SavedRun):
uid=sr.uid,
channel=self.realtime_channel_name(sr.run_id, sr.uid),
unsaved_state=self._unsaved_state(),
deduct_credits=deduct_credits,
)
| post_runner_tasks.s()
)
Expand Down
1 change: 1 addition & 0 deletions daras_ai_v2/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def safety_checker_text(text_input: str):
.submit_api_call(
current_user=billing_account,
request_body=dict(variables=dict(input=text_input)),
deduct_credits=False,
)
)

Expand Down
1 change: 1 addition & 0 deletions functions/recipe_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def call_recipe_functions(
request_body=dict(
variables=sr.state.get("variables", {}) | variables | fn_vars,
),
deduct_credits=False,
)

CalledFunction.objects.create(
Expand Down
10 changes: 10 additions & 0 deletions recipes/Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.field_render import field_title_desc
from daras_ai_v2.prompt_vars import variables_input
from functions.models import CalledFunction


class ConsoleLogs(BaseModel):
Expand All @@ -22,6 +23,7 @@ class FunctionsPage(BasePage):
workflow = Workflow.FUNCTIONS
slug_versions = ["functions", "tools", "function", "fn", "functions"]
show_settings = False
price = 1

class RequestModel(BaseModel):
code: str = Field(
Expand Down Expand Up @@ -83,6 +85,14 @@ def render_form_v2(self):
height=300,
)

def get_price_roundoff(self, state: dict) -> float:
devxpy marked this conversation as resolved.
Show resolved Hide resolved
if CalledFunction.objects.filter(function_run=self.get_current_sr()).exists():
return 0
return super().get_price_roundoff(state)

def additional_notes(self):
return "\nFunctions are free if called from another workflow."

def render_variables(self):
variables_input(
template_keys=["code"],
Expand Down
3 changes: 2 additions & 1 deletion routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def submit_api_call(
query_params: dict,
retention_policy: RetentionPolicy = None,
enable_rate_limits: bool = False,
deduct_credits: bool = True,
) -> tuple[BasePage, "celery.result.AsyncResult", str, str]:
# init a new page for every request
self = page_cls(request=SimpleNamespace(user=user))
Expand All @@ -357,7 +358,7 @@ def submit_api_call(
except ValidationError as e:
raise RequestValidationError(e.raw_errors, body=gui.session_state) from e
# submit the task
result = self.call_runner_task(sr)
result = self.call_runner_task(sr, deduct_credits=deduct_credits)
return self, result, sr.run_id, sr.uid


Expand Down
Loading