diff --git a/bots/models.py b/bots/models.py index 3a78c46b3..fbb124726 100644 --- a/bots/models.py +++ b/bots/models.py @@ -17,6 +17,7 @@ from bots.custom_fields import PostgresJSONEncoder, CustomURLField from daras_ai_v2.crypto import get_random_doc_id from daras_ai_v2.language_model import format_chat_entry +from gooeysite.custom_create import get_or_create_lazy if typing.TYPE_CHECKING: from daras_ai_v2.base import BasePage @@ -117,42 +118,43 @@ def page_cls(self) -> typing.Type["BasePage"]: return workflow_map[self] def get_or_create_metadata(self) -> "WorkflowMetadata": - metadata, _created = WorkflowMetadata.objects.get_or_create( + return get_or_create_lazy( + WorkflowMetadata, workflow=self, - defaults=dict( - short_title=lambda: ( + create=lambda **kwargs: WorkflowMetadata.objects.create( + **kwargs, + short_title=( self.page_cls.get_root_published_run().title or self.page_cls.title ), default_image=self.page_cls.explore_image or "", - meta_title=lambda: ( + meta_title=( self.page_cls.get_root_published_run().title or self.page_cls.title ), - meta_description=lambda: ( + meta_description=( self.page_cls().preview_description(state={}) or self.page_cls.get_root_published_run().notes ), meta_image=self.page_cls.explore_image or "", ), - ) - return metadata + )[0] class WorkflowMetadata(models.Model): workflow = models.IntegerField(choices=Workflow.choices, unique=True) - short_title = models.TextField() - help_url = models.URLField(blank=True, default="") - # TODO: support the below fields + short_title = models.TextField(help_text="Title used in breadcrumbs") default_image = models.URLField( - blank=True, default="", help_text="(not implemented)" + blank=True, default="", help_text="Image shown on explore page" ) meta_title = models.TextField() meta_description = models.TextField(blank=True, default="") meta_image = CustomURLField(default="", blank=True) + meta_keywords = models.JSONField( - default=list, blank=True, help_text="(not implemented)" + default=list, blank=True, help_text="(Not implemented)" ) + help_url = models.URLField(blank=True, default="", help_text="(Not implemented)") created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -1442,34 +1444,58 @@ class FeedbackComment(models.Model): class PublishedRunQuerySet(models.QuerySet): - def create_published_run( + def get_or_create_with_version( self, *, workflow: Workflow, published_run_id: str, saved_run: SavedRun, - user: AppUser, + user: AppUser | None, + title: str, + notes: str, + visibility: PublishedRunVisibility, + ): + return get_or_create_lazy( + PublishedRun, + workflow=workflow, + published_run_id=published_run_id, + create=lambda **kwargs: self.create_with_version( + **kwargs, + saved_run=saved_run, + user=user, + title=title, + notes=notes, + visibility=visibility, + ), + ) + + def create_with_version( + self, + *, + workflow: Workflow, + published_run_id: str, + saved_run: SavedRun, + user: AppUser | None, title: str, notes: str, visibility: PublishedRunVisibility, ): with transaction.atomic(): - published_run = PublishedRun( + pr = self.create( workflow=workflow, published_run_id=published_run_id, created_by=user, last_edited_by=user, title=title, ) - published_run.save() - published_run.add_version( + pr.add_version( user=user, saved_run=saved_run, title=title, visibility=visibility, notes=notes, ) - return published_run + return pr class PublishedRun(models.Model): @@ -1571,7 +1597,7 @@ def duplicate( notes: str, visibility: PublishedRunVisibility, ) -> "PublishedRun": - return PublishedRun.objects.create_published_run( + return PublishedRun.objects.create_with_version( workflow=Workflow(self.workflow), published_run_id=get_random_doc_id(), saved_run=self.saved_run, @@ -1589,7 +1615,7 @@ def get_app_url(self): def add_version( self, *, - user: AppUser, + user: AppUser | None, saved_run: SavedRun, visibility: PublishedRunVisibility, title: str, diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 3f582abba..6b641314d 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -796,11 +796,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR @classmethod def get_recipe_title(cls) -> str: - return ( - cls.get_or_create_root_published_run().title - or cls.title - or cls.workflow.label - ) + return cls.get_root_published_run().title or cls.title or cls.workflow.label def get_explore_image(self) -> str: meta = self.workflow.get_or_create_metadata() @@ -1076,10 +1072,6 @@ def get_pr_from_query_params( else: return cls.get_root_published_run() - @classmethod - def get_root_published_run(cls) -> PublishedRun: - return cls.get_published_run(published_run_id="") - @classmethod def get_published_run(cls, *, published_run_id: str): return PublishedRun.objects.get( @@ -1120,37 +1112,10 @@ def get_total_runs(cls) -> int: # TODO: fix to also handle published run case return SavedRun.objects.filter(workflow=cls.workflow).count() - @classmethod - def get_or_create_root_published_run(cls) -> PublishedRun: - def get_defaults(): - return dict( - saved_run=( - SavedRun.objects.get_or_create( - example_id="", - workflow=cls.workflow, - defaults=dict(state=cls.load_state_defaults({})), - )[0] - ), - created_by=None, - last_edited_by=None, - title=cls.title, - notes=cls().preview_description(state=cls.sane_defaults), - visibility=PublishedRunVisibility(PublishedRunVisibility.PUBLIC), - is_approved_example=True, - ) - - published_run, _ = get_or_create_lazy( - PublishedRun, - workflow=cls.workflow, - published_run_id="", - get_defaults=get_defaults, - ) - return published_run - @classmethod def recipe_doc_sr(cls, create: bool = True) -> SavedRun: if create: - return cls.get_or_create_root_published_run().saved_run + return cls.get_root_published_run().saved_run else: return cls.get_root_published_run().saved_run @@ -1168,18 +1133,34 @@ def run_doc_sr( else: return SavedRun.objects.get(**config) + @classmethod + def get_root_published_run(cls) -> PublishedRun: + return PublishedRun.objects.get_or_create_with_version( + workflow=cls.workflow, + published_run_id="", + saved_run=SavedRun.objects.get_or_create( + example_id="", + workflow=cls.workflow, + defaults=dict(state=cls.load_state_defaults({})), + )[0], + user=None, + title=cls.title, + notes=cls().preview_description(state=cls.sane_defaults), + visibility=PublishedRunVisibility.PUBLIC, + )[0] + @classmethod def create_published_run( cls, *, published_run_id: str, saved_run: SavedRun, - user: AppUser, + user: AppUser | None, title: str, notes: str, visibility: PublishedRunVisibility, ): - return PublishedRun.objects.create_published_run( + return PublishedRun.objects.create_with_version( workflow=cls.workflow, published_run_id=published_run_id, saved_run=saved_run, diff --git a/explore.py b/explore.py index 2bf7279fe..47bf15ffa 100644 --- a/explore.py +++ b/explore.py @@ -85,7 +85,7 @@ def render_description(page: BasePage): with gui.link(to=page.app_url()): gui.markdown(f"#### {page.get_recipe_title()}") - root_pr = page.get_or_create_root_published_run() + root_pr = page.get_root_published_run() notes = root_pr.notes or page.preview_description(state=page.sane_defaults) with gui.tag("p", style={"marginBottom": "25px"}): gui.write(notes, line_clamp=4) diff --git a/gooeysite/custom_create.py b/gooeysite/custom_create.py index d1a328776..3668eb8fa 100644 --- a/gooeysite/custom_create.py +++ b/gooeysite/custom_create.py @@ -2,11 +2,10 @@ from django.db import transaction, IntegrityError from django.db.models import Model -from django.db.models.utils import resolve_callables def get_or_create_lazy( - model: typing.Type[Model], get_defaults: typing.Callable[..., dict] = None, **kwargs + model: typing.Type[Model], create: typing.Callable[..., dict] = None, **kwargs ): """ Look up an object with the given kwargs, creating one if necessary. @@ -20,13 +19,10 @@ def get_or_create_lazy( try: return self.get(**kwargs), False except self.model.DoesNotExist: - defaults = get_defaults and get_defaults() - params = self._extract_model_params(defaults, **kwargs) # Try to create an object using passed params. try: with transaction.atomic(using=self.db): - params = dict(resolve_callables(params)) - return self.create(**params), True + return create(**kwargs), True except IntegrityError: try: return self.get(**kwargs), False