Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor BasePage methods to consolidate SavedRun and PublishedRun retrieval logic into get_sr_pr + caching #451

Merged
merged 4 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bots/admin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime
import json
from types import SimpleNamespace

import django.db.models
from django import forms
Expand Down Expand Up @@ -439,7 +438,8 @@ def rerun_tasks(self, request, queryset):
sr: SavedRun
for sr in queryset.all():
page = Workflow(sr.workflow).page_cls(
request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid))
user=AppUser.objects.get(uid=sr.uid),
query_params=dict(run_id=sr.run_id, uid=sr.uid),
)
page.call_runner_task(sr, deduct_credits=False)
self.message_user(
Expand Down
23 changes: 11 additions & 12 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
from functions.models import CalledFunction, CalledFunctionResponse
from gooeysite.bg_db_conn import get_celery_result_db_safe
from gooeysite.custom_create import get_or_create_lazy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -127,16 +128,12 @@ def get_or_create_metadata(self) -> "WorkflowMetadata":
workflow=self,
create=lambda **kwargs: WorkflowMetadata.objects.create(
**kwargs,
short_title=(
self.page_cls.get_root_published_run().title or self.page_cls.title
),
short_title=(self.page_cls.get_root_pr().title or self.page_cls.title),
default_image=self.page_cls.explore_image or "",
meta_title=(
self.page_cls.get_root_published_run().title or self.page_cls.title
),
meta_title=(self.page_cls.get_root_pr().title or self.page_cls.title),
meta_description=(
self.page_cls().preview_description(state={})
or self.page_cls.get_root_published_run().notes
or self.page_cls.get_root_pr().notes
),
meta_image=self.page_cls.explore_image or "",
),
Expand Down Expand Up @@ -362,8 +359,8 @@ def submit_api_call(
current_user: AppUser,
request_body: dict,
enable_rate_limits: bool = False,
parent_pr: "PublishedRun" = None,
deduct_credits: bool = True,
parent_pr: "PublishedRun" = None,
) -> tuple["celery.result.AsyncResult", "SavedRun"]:
from routers.api import submit_api_call

Expand All @@ -377,19 +374,21 @@ def submit_api_call(
query_params = page_cls.clean_query_params(
example_id=self.example_id, run_id=self.run_id, uid=self.uid
)
page, result, run_id, uid = pool.apply(
return pool.apply(
submit_api_call,
kwds=dict(
page_cls=page_cls,
query_params=query_params,
user=current_user,
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
deduct_credits=deduct_credits,
),
)

return result, page.run_doc_sr(run_id, uid)
def wait_for_celery_result(self, result: "celery.result.AsyncResult"):
get_celery_result_db_safe(result)
self.refresh_from_db()

def get_creator(self) -> AppUser | None:
if self.uid:
Expand Down Expand Up @@ -1843,8 +1842,8 @@ def submit_api_call(
current_user=current_user,
request_body=request_body,
enable_rate_limits=enable_rate_limits,
parent_pr=self,
deduct_credits=deduct_credits,
parent_pr=self,
)


Expand Down
4 changes: 1 addition & 3 deletions bots/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None):
# save the run before the result is ready
Message.objects.filter(id=msg_id).update(analysis_run=sr)

# wait for the result
get_celery_result_db_safe(result)
sr.refresh_from_db()
sr.wait_for_celery_result(result)
# if failed, raise error
if sr.error_msg:
raise RuntimeError(sr.error_msg)
Expand Down
21 changes: 16 additions & 5 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import datetime
import html
import threading
import traceback
import typing
from time import time
from types import SimpleNamespace

import gooey_gui as gui
import requests
Expand Down Expand Up @@ -31,6 +31,15 @@

DEFAULT_RUN_STATUS = "Running..."

threadlocal = threading.local()


def get_running_saved_run() -> SavedRun | None:
try:
return threadlocal.saved_run
except AttributeError:
return None


@app.task
def runner_task(
Expand Down Expand Up @@ -81,12 +90,13 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# save to db
page.dump_state_to_sr(gui.session_state | output, sr)

user = AppUser.objects.get(id=user_id)
page = page_cls(request=SimpleNamespace(user=user))
page = page_cls(
user=AppUser.objects.get(id=user_id), query_params=dict(run_id=run_id, uid=uid)
)
page.setup_sentry()
sr = page.run_doc_sr(run_id, uid)
sr = page.current_sr
threadlocal.saved_run = sr
gui.set_session_state(sr.to_dict() | (unsaved_state or {}))
gui.set_query_params(dict(run_id=run_id, uid=uid))

try:
save_on_step()
Expand Down Expand Up @@ -114,6 +124,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# save everything, mark run as completed
finally:
save_on_step(done=True)
threadlocal.saved_run = None

post_runner_tasks.delay(sr.id)

Expand Down
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def mock_celery_tasks():
def _mock_runner_task(
*, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
):
sr = page_cls.run_doc_sr(run_id, uid)
sr = page_cls.get_sr_from_ids(run_id, uid)
sr.set(sr.parent.to_dict())
sr.save()
channel = page_cls().realtime_channel_name(run_id, uid)
channel = page_cls.realtime_channel_name(run_id, uid)
_mock_realtime_push(channel, sr.to_dict())


Expand Down
Loading
Loading