Skip to content

Commit

Permalink
fix parent published run not set when calling submit_api_call()
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Jul 5, 2024
1 parent bb2b75a commit e2b21a0
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 10 deletions.
28 changes: 22 additions & 6 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,17 +342,19 @@ def submit_api_call(
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
parent_pr: "PublishedRun" = None,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
from routers.api import submit_api_call

# run in a thread to avoid messing up threadlocals
with ThreadPool(1) as pool:
pr = self.parent_published_run()
query_params = dict(
example_id=(pr and pr.published_run_id) or self.example_id,
run_id=self.run_id,
uid=self.uid,
)
if parent_pr and parent_pr.saved_run == self:
# avoid passing run_id and uid for examples
query_params = dict(example_id=parent_pr.published_run_id)
else:
query_params = dict(
example_id=self.example_id, run_id=self.run_id, uid=self.uid
)
page, result, run_id, uid = pool.apply(
submit_api_call,
kwds=dict(
Expand Down Expand Up @@ -1678,6 +1680,20 @@ def get_run_count(self):
or 0
)

def submit_api_call(
self,
*,
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
return self.saved_run.submit_api_call(
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
parent_pr=self,
)


class PublishedRunVersion(models.Model):
version_id = models.CharField(max_length=128, unique=True)
Expand Down
4 changes: 3 additions & 1 deletion bots/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None):

# make the api call
result, sr = analysis_sr.submit_api_call(
current_user=billing_account, request_body=dict(variables=variables)
current_user=billing_account,
request_body=dict(variables=variables),
parent_pr=anal.published_run,
)

# save the run before the result is ready
Expand Down
2 changes: 1 addition & 1 deletion daras_ai_v2/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def safety_checker_text(text_input: str):
result, sr = (
CompareLLMPage()
.get_published_run(published_run_id=settings.SAFTY_CHECKER_EXAMPLE_ID)
.saved_run.submit_api_call(
.submit_api_call(
current_user=billing_account,
request_body=dict(variables=dict(input=text_input)),
)
Expand Down
1 change: 1 addition & 0 deletions functions/recipe_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def call_recipe_functions(
page_cls, sr, pr = url_to_runs(fun.url)
result, sr = sr.submit_api_call(
current_user=current_user,
parent_pr=pr,
request_body=dict(
variables=sr.state.get("variables", {})
| variables
Expand Down
6 changes: 4 additions & 2 deletions recipes/BulkRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ def run_v2(
yield f"{progress}%"

result, sr = sr.submit_api_call(
current_user=self.request.user, request_body=request_body
current_user=self.request.user,
request_body=request_body,
parent_pr=pr,
)
get_celery_result_db_safe(result)
sr.refresh_from_db()
Expand Down Expand Up @@ -388,7 +390,7 @@ def run_v2(
documents=response.output_documents
).dict(exclude_unset=True)
result, sr = sr.submit_api_call(
current_user=self.request.user, request_body=request_body
current_user=self.request.user, request_body=request_body, parent_pr=pr
)
get_celery_result_db_safe(result)
sr.refresh_from_db()
Expand Down

0 comments on commit e2b21a0

Please sign in to comment.