Skip to content

Commit

Permalink
Refactor BasePage methods to consolidate SavedRun and `PublishedR…
Browse files Browse the repository at this point in the history
…un` retrieval logic into `get_sr_pr`

Remove usage of global gui.get_query_params
  • Loading branch information
devxpy committed Aug 31, 2024
1 parent 7dfbe6c commit 19e4f97
Show file tree
Hide file tree
Showing 25 changed files with 237 additions and 307 deletions.
5 changes: 4 additions & 1 deletion bots/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,10 @@ def rerun_tasks(self, request, queryset):
sr: SavedRun
for sr in queryset.all():
page = Workflow(sr.workflow).page_cls(
request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid))
request=SimpleNamespace(
user=AppUser.objects.get(uid=sr.uid),
query_params=dict(run_id=sr.run_id, uid=sr.uid),
)
)
page.call_runner_task(sr, deduct_credits=False)
self.message_user(
Expand Down
12 changes: 4 additions & 8 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,12 @@ def get_or_create_metadata(self) -> "WorkflowMetadata":
workflow=self,
create=lambda **kwargs: WorkflowMetadata.objects.create(
**kwargs,
short_title=(
self.page_cls.get_root_published_run().title or self.page_cls.title
),
short_title=(self.page_cls.get_root_pr().title or self.page_cls.title),
default_image=self.page_cls.explore_image or "",
meta_title=(
self.page_cls.get_root_published_run().title or self.page_cls.title
),
meta_title=(self.page_cls.get_root_pr().title or self.page_cls.title),
meta_description=(
self.page_cls().preview_description(state={})
or self.page_cls.get_root_published_run().notes
or self.page_cls.get_root_pr().notes
),
meta_image=self.page_cls.explore_image or "",
),
Expand Down Expand Up @@ -389,7 +385,7 @@ def submit_api_call(
),
)

return result, page.run_doc_sr(run_id, uid)
return result, page.current_sr

def get_creator(self) -> AppUser | None:
if self.uid:
Expand Down
23 changes: 19 additions & 4 deletions celeryapp/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import html
import threading
import traceback
import typing
from time import time
Expand Down Expand Up @@ -31,6 +32,15 @@

DEFAULT_RUN_STATUS = "Running..."

threadlocal = threading.local()


def get_running_saved_run() -> SavedRun | None:
try:
return threadlocal.saved_run
except AttributeError:
return None


@app.task
def runner_task(
Expand Down Expand Up @@ -81,12 +91,16 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# save to db
page.dump_state_to_sr(gui.session_state | output, sr)

user = AppUser.objects.get(id=user_id)
page = page_cls(request=SimpleNamespace(user=user))
page = page_cls(
request=SimpleNamespace(
user=AppUser.objects.get(id=user_id),
query_params=dict(run_id=run_id, uid=uid),
),
)
page.setup_sentry()
sr = page.run_doc_sr(run_id, uid)
sr = page.current_sr
threadlocal.saved_run = sr
gui.set_session_state(sr.to_dict() | (unsaved_state or {}))
gui.set_query_params(dict(run_id=run_id, uid=uid))

try:
save_on_step()
Expand Down Expand Up @@ -114,6 +128,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
# save everything, mark run as completed
finally:
save_on_step(done=True)
threadlocal.saved_run = None

post_runner_tasks.delay(sr.id)

Expand Down
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def mock_celery_tasks():
def _mock_runner_task(
*, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
):
sr = page_cls.run_doc_sr(run_id, uid)
sr = page_cls.get_sr_from_ids(run_id, uid)
sr.set(sr.parent.to_dict())
sr.save()
channel = page_cls().realtime_channel_name(run_id, uid)
channel = page_cls.realtime_channel_name(run_id, uid)
_mock_realtime_push(channel, sr.to_dict())


Expand Down
Loading

0 comments on commit 19e4f97

Please sign in to comment.