Skip to content

Commit

Permalink
fix published run init for new recipes
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Jul 3, 2024
1 parent 92bd027 commit 3f67698
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 66 deletions.
66 changes: 46 additions & 20 deletions bots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from bots.custom_fields import PostgresJSONEncoder, CustomURLField
from daras_ai_v2.crypto import get_random_doc_id
from daras_ai_v2.language_model import format_chat_entry
from gooeysite.custom_create import get_or_create_lazy

if typing.TYPE_CHECKING:
from daras_ai_v2.base import BasePage
Expand Down Expand Up @@ -117,42 +118,43 @@ def page_cls(self) -> typing.Type["BasePage"]:
return workflow_map[self]

def get_or_create_metadata(self) -> "WorkflowMetadata":
metadata, _created = WorkflowMetadata.objects.get_or_create(
return get_or_create_lazy(
WorkflowMetadata,
workflow=self,
defaults=dict(
short_title=lambda: (
create=lambda **kwargs: WorkflowMetadata.objects.create(
**kwargs,
short_title=(
self.page_cls.get_root_published_run().title or self.page_cls.title
),
default_image=self.page_cls.explore_image or "",
meta_title=lambda: (
meta_title=(
self.page_cls.get_root_published_run().title or self.page_cls.title
),
meta_description=lambda: (
meta_description=(
self.page_cls().preview_description(state={})
or self.page_cls.get_root_published_run().notes
),
meta_image=self.page_cls.explore_image or "",
),
)
return metadata
)[0]


class WorkflowMetadata(models.Model):
workflow = models.IntegerField(choices=Workflow.choices, unique=True)
short_title = models.TextField()
help_url = models.URLField(blank=True, default="")

# TODO: support the below fields
short_title = models.TextField(help_text="Title used in breadcrumbs")
default_image = models.URLField(
blank=True, default="", help_text="(not implemented)"
blank=True, default="", help_text="Image shown on explore page"
)

meta_title = models.TextField()
meta_description = models.TextField(blank=True, default="")
meta_image = CustomURLField(default="", blank=True)

meta_keywords = models.JSONField(
default=list, blank=True, help_text="(not implemented)"
default=list, blank=True, help_text="(Not implemented)"
)
help_url = models.URLField(blank=True, default="", help_text="(Not implemented)")

created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
Expand Down Expand Up @@ -1442,34 +1444,58 @@ class FeedbackComment(models.Model):


class PublishedRunQuerySet(models.QuerySet):
def create_published_run(
def get_or_create_with_version(
self,
*,
workflow: Workflow,
published_run_id: str,
saved_run: SavedRun,
user: AppUser,
user: AppUser | None,
title: str,
notes: str,
visibility: PublishedRunVisibility,
):
return get_or_create_lazy(
PublishedRun,
workflow=workflow,
published_run_id=published_run_id,
create=lambda **kwargs: self.create_with_version(
**kwargs,
saved_run=saved_run,
user=user,
title=title,
notes=notes,
visibility=visibility,
),
)

def create_with_version(
self,
*,
workflow: Workflow,
published_run_id: str,
saved_run: SavedRun,
user: AppUser | None,
title: str,
notes: str,
visibility: PublishedRunVisibility,
):
with transaction.atomic():
published_run = PublishedRun(
pr = self.create(
workflow=workflow,
published_run_id=published_run_id,
created_by=user,
last_edited_by=user,
title=title,
)
published_run.save()
published_run.add_version(
pr.add_version(
user=user,
saved_run=saved_run,
title=title,
visibility=visibility,
notes=notes,
)
return published_run
return pr


class PublishedRun(models.Model):
Expand Down Expand Up @@ -1571,7 +1597,7 @@ def duplicate(
notes: str,
visibility: PublishedRunVisibility,
) -> "PublishedRun":
return PublishedRun.objects.create_published_run(
return PublishedRun.objects.create_with_version(
workflow=Workflow(self.workflow),
published_run_id=get_random_doc_id(),
saved_run=self.saved_run,
Expand All @@ -1589,7 +1615,7 @@ def get_app_url(self):
def add_version(
self,
*,
user: AppUser,
user: AppUser | None,
saved_run: SavedRun,
visibility: PublishedRunVisibility,
title: str,
Expand Down
59 changes: 20 additions & 39 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR

@classmethod
def get_recipe_title(cls) -> str:
return (
cls.get_or_create_root_published_run().title
or cls.title
or cls.workflow.label
)
return cls.get_root_published_run().title or cls.title or cls.workflow.label

def get_explore_image(self) -> str:
meta = self.workflow.get_or_create_metadata()
Expand Down Expand Up @@ -1076,10 +1072,6 @@ def get_pr_from_query_params(
else:
return cls.get_root_published_run()

@classmethod
def get_root_published_run(cls) -> PublishedRun:
return cls.get_published_run(published_run_id="")

@classmethod
def get_published_run(cls, *, published_run_id: str):
return PublishedRun.objects.get(
Expand Down Expand Up @@ -1120,37 +1112,10 @@ def get_total_runs(cls) -> int:
# TODO: fix to also handle published run case
return SavedRun.objects.filter(workflow=cls.workflow).count()

@classmethod
def get_or_create_root_published_run(cls) -> PublishedRun:
def get_defaults():
return dict(
saved_run=(
SavedRun.objects.get_or_create(
example_id="",
workflow=cls.workflow,
defaults=dict(state=cls.load_state_defaults({})),
)[0]
),
created_by=None,
last_edited_by=None,
title=cls.title,
notes=cls().preview_description(state=cls.sane_defaults),
visibility=PublishedRunVisibility(PublishedRunVisibility.PUBLIC),
is_approved_example=True,
)

published_run, _ = get_or_create_lazy(
PublishedRun,
workflow=cls.workflow,
published_run_id="",
get_defaults=get_defaults,
)
return published_run

@classmethod
def recipe_doc_sr(cls, create: bool = True) -> SavedRun:
if create:
return cls.get_or_create_root_published_run().saved_run
return cls.get_root_published_run().saved_run
else:
return cls.get_root_published_run().saved_run

Expand All @@ -1168,18 +1133,34 @@ def run_doc_sr(
else:
return SavedRun.objects.get(**config)

@classmethod
def get_root_published_run(cls) -> PublishedRun:
return PublishedRun.objects.get_or_create_with_version(
workflow=cls.workflow,
published_run_id="",
saved_run=SavedRun.objects.get_or_create(
example_id="",
workflow=cls.workflow,
defaults=dict(state=cls.load_state_defaults({})),
)[0],
user=None,
title=cls.title,
notes=cls().preview_description(state=cls.sane_defaults),
visibility=PublishedRunVisibility.PUBLIC,
)[0]

@classmethod
def create_published_run(
cls,
*,
published_run_id: str,
saved_run: SavedRun,
user: AppUser,
user: AppUser | None,
title: str,
notes: str,
visibility: PublishedRunVisibility,
):
return PublishedRun.objects.create_published_run(
return PublishedRun.objects.create_with_version(
workflow=cls.workflow,
published_run_id=published_run_id,
saved_run=saved_run,
Expand Down
2 changes: 1 addition & 1 deletion explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def render_description(page: BasePage):
with gui.link(to=page.app_url()):
gui.markdown(f"#### {page.get_recipe_title()}")

root_pr = page.get_or_create_root_published_run()
root_pr = page.get_root_published_run()
notes = root_pr.notes or page.preview_description(state=page.sane_defaults)
with gui.tag("p", style={"marginBottom": "25px"}):
gui.write(notes, line_clamp=4)
Expand Down
8 changes: 2 additions & 6 deletions gooeysite/custom_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from django.db import transaction, IntegrityError
from django.db.models import Model
from django.db.models.utils import resolve_callables


def get_or_create_lazy(
model: typing.Type[Model], get_defaults: typing.Callable[..., dict] = None, **kwargs
model: typing.Type[Model], create: typing.Callable[..., dict] = None, **kwargs
):
"""
Look up an object with the given kwargs, creating one if necessary.
Expand All @@ -20,13 +19,10 @@ def get_or_create_lazy(
try:
return self.get(**kwargs), False
except self.model.DoesNotExist:
defaults = get_defaults and get_defaults()
params = self._extract_model_params(defaults, **kwargs)
# Try to create an object using passed params.
try:
with transaction.atomic(using=self.db):
params = dict(resolve_callables(params))
return self.create(**params), True
return create(**kwargs), True
except IntegrityError:
try:
return self.get(**kwargs), False
Expand Down

0 comments on commit 3f67698

Please sign in to comment.