Skip to content

Commit

Permalink
Refactor BasePage methods to consolidate SavedRun and `PublishedR…
Browse files Browse the repository at this point in the history
…un` retrieval logic into `get_sr_pr`

Remove usage of global gui.get_query_params
  • Loading branch information
devxpy committed Aug 29, 2024
1 parent 904e836 commit 13a4e5c
Show file tree
Hide file tree
Showing 21 changed files with 192 additions and 248 deletions.
5 changes: 4 additions & 1 deletion bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def rerun_tasks(self, request, queryset):
sr: SavedRun
for sr in queryset.all():
page = Workflow(sr.workflow).page_cls(
request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid))
request=SimpleNamespace(
user=AppUser.objects.get(uid=sr.uid),
query_params=dict(run_id=sr.run_id, uid=sr.uid),
)
)
page.call_runner_task(sr, deduct_credits=False)
self.message_user(
Expand Down
2 changes: 1 addition & 1 deletion bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def submit_api_call(
),
)

return result, page.run_doc_sr(run_id, uid)
return result, page.get_sr_pr()[0]

def get_creator(self) -> AppUser | None:
if self.uid:
Expand Down
23 changes: 19 additions & 4 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import html
import threading
import traceback
import typing
from time import time
Expand Down Expand Up @@ -31,6 +32,15 @@

DEFAULT_RUN_STATUS = "Running..."

threadlocal = threading.local()


def get_current_saved_run() -> SavedRun | None:
try:
return threadlocal.saved_run
except AttributeError:
return None


@app.task
def runner_task(
Expand Down Expand Up @@ -81,12 +91,16 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# save to db
page.dump_state_to_sr(gui.session_state | output, sr)

user = AppUser.objects.get(id=user_id)
page = page_cls(request=SimpleNamespace(user=user))
page = page_cls(
request=SimpleNamespace(
user=AppUser.objects.get(id=user_id),
query_params=dict(run_id=run_id, uid=uid),
),
)
page.setup_sentry()
sr = page.run_doc_sr(run_id, uid)
sr = page.get_sr_pr()[0]
threadlocal.saved_run = sr
gui.set_session_state(sr.to_dict() | (unsaved_state or {}))
gui.set_query_params(dict(run_id=run_id, uid=uid))

try:
save_on_step()
Expand Down Expand Up @@ -114,6 +128,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# save everything, mark run as completed
finally:
save_on_step(done=True)
threadlocal.saved_run = None

post_runner_tasks.delay(sr.id)

Expand Down
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def mock_celery_tasks():
def _mock_runner_task(
*, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
):
sr = page_cls.run_doc_sr(run_id, uid)
sr = page_cls.get_saved_run(run_id, uid)
sr.set(sr.parent.to_dict())
sr.save()
channel = page_cls().realtime_channel_name(run_id, uid)
channel = page_cls.realtime_channel_name(run_id, uid)
_mock_realtime_push(channel, sr.to_dict())


Expand Down
Loading

0 comments on commit 13a4e5c

Please sign in to comment.