diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a1786492..65af59112 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: - - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml - repo: local diff --git a/Dockerfile b/Dockerfile index 38576f406..b3b015f4c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,6 +42,8 @@ RUN pip install --no-cache-dir -U poetry pip && poetry install --no-cache --only # install nltk stopwords RUN poetry run python -c 'import nltk; nltk.download("stopwords")' +# install playwright +RUN poetry run playwright install-deps && poetry run playwright install # copy the code into the container COPY . . diff --git a/README.md b/README.md index beb3748f9..a2fbe6a7a 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ * Install [pyenv](https://github.com/pyenv/pyenv) & install the same python version as in our [Dockerfile](Dockerfile) * Install [poetry](https://python-poetry.org/docs/) +* Clone the github repo to gooey-server (and make sure that's the folder name) * Create & activate a virtualenv (e.g. `poetry shell`) * Run `poetry install --with dev` * Install [redis](https://redis.io/docs/getting-started/installation/install-redis-on-mac-os/), [rabbitmq](https://www.rabbitmq.com/install-homebrew.html), and [postgresql](https://formulae.brew.sh/formula/postgresql@15) (e.g. 
`brew install redis rabbitmq postgresql@15`) @@ -87,7 +88,7 @@ gooey.ai (dev) App ID: 228027632918921 ``` -Create a [meta developer account](https://developers.facebook.com/docs/development/register/) & ask someone to add you to the test app [here](https://developers.facebook.com/apps/228027632918921/roles/roles/?business_id=549319917267066) +Create a [meta developer account](https://developers.facebook.com/docs/development/register/) & send admin your **facebook ID** to add you to the test app [here](https://developers.facebook.com/apps/228027632918921/roles/roles/?business_id=549319917267066) 1. start ngrok @@ -107,6 +108,12 @@ ngrok http 8080 5. Copy the temporary access token there and set env var `WHATSAPP_ACCESS_TOKEN = XXXX` +**(Optional) Use the test script to send yourself messages** + +```bash +python manage.py runscript test_wa_msg_send --script-args 104696745926402 +918764022384 +``` +Replace `+918764022384` with your number and `104696745926402` with the test number ID ## Dangerous postgres commands @@ -126,6 +133,7 @@ docker cp $cid:/app/$fname . echo $PWD/$fname ``` +**on local** ```bash # reset the database ./manage.py reset_db -c @@ -137,6 +145,7 @@ pg_restore --no-privileges --no-owner -d $PGDATABASE $fname ### create & load fixtures +**on server** ```bash # select a running container cid=$(docker ps | grep gooey-api-prod | cut -d " " -f 1 | head -1) @@ -148,6 +157,12 @@ docker cp $cid:/app/fixture.json . echo $PWD/fixture.json ``` +**on local** +```bash +# copy fixture.json from server to local +rsync -P -a @captain.us-1.gooey.ai:/home//fixture.json . 
+``` + ```bash # reset the database ./manage.py reset_db -c @@ -157,12 +172,16 @@ echo $PWD/fixture.json ./manage.py migrate # load the fixture ./manage.py loaddata fixture.json +# create a superuser to access admin +./manage.py createsuperuser ``` ### copy one postgres db to another -``` +**on server** +```bash ./manage.py reset_db createdb -T template0 $PGDATABASE pg_dump $SOURCE_DATABASE | psql -q $PGDATABASE ``` + diff --git a/app_users/admin.py b/app_users/admin.py index 5a90c302c..191197eb6 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -1,4 +1,5 @@ from django.contrib import admin +from django.contrib.admin.models import LogEntry from app_users import models from bots.admin_links import open_in_new_tab, list_related_html_url @@ -128,10 +129,40 @@ class AppUserTransactionAdmin(admin.ModelAdmin): "invoice_id", "user", "amount", - "created_at", "end_balance", + "payment_provider", + "dollar_amount", + "created_at", ] readonly_fields = ["created_at"] - list_filter = ["created_at", IsStripeFilter] + list_filter = ["created_at", IsStripeFilter, "payment_provider"] inlines = [SavedRunInline] ordering = ["-created_at"] + + @admin.display(description="Charged Amount") + def dollar_amount(self, obj: models.AppUserTransaction): + if not obj.payment_provider: + return + return f"${obj.charged_amount / 100}" + + +@admin.register(LogEntry) +class LogEntryAdmin(admin.ModelAdmin): + list_display = readonly_fields = [ + "action_time", + "user", + "action_flag", + "content_type", + "object_repr", + "object_id", + "change_message", + ] + + # to have a date-based drilldown navigation in the admin page + date_hierarchy = "action_time" + + # to filter the results by users, content types and action flags + list_filter = ["action_time", "user", "content_type", "action_flag"] + + # when searching the user will be able to search in both object_repr and change_message + search_fields = ["object_repr", "change_message"] diff --git 
a/app_users/migrations/0011_appusertransaction_charged_amount_and_more.py b/app_users/migrations/0011_appusertransaction_charged_amount_and_more.py new file mode 100644 index 000000000..f11e4c324 --- /dev/null +++ b/app_users/migrations/0011_appusertransaction_charged_amount_and_more.py @@ -0,0 +1,54 @@ +# Generated by Django 4.2.7 on 2024-01-08 13:43 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("app_users", "0010_alter_appuser_balance_alter_appuser_created_at_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="appusertransaction", + name="charged_amount", + field=models.PositiveIntegerField( + default=0, + help_text="The charged dollar amount in the currencyโ€™s smallest unit.
E.g. for 10 USD, this would be of 1000 (that is, 1000 cents).
Learn More", + ), + ), + migrations.AddField( + model_name="appusertransaction", + name="payment_provider", + field=models.IntegerField( + blank=True, + choices=[(1, "Stripe"), (2, "Paypal")], + default=None, + help_text="The payment provider used for this transaction.
If this is provided, the Charged Amount should also be provided.", + null=True, + ), + ), + migrations.AlterField( + model_name="appusertransaction", + name="amount", + field=models.IntegerField( + help_text="The amount (Gooey credits) added/deducted in this transaction.
Positive for credits added, negative for credits deducted." + ), + ), + migrations.AlterField( + model_name="appusertransaction", + name="end_balance", + field=models.IntegerField( + help_text="The end balance (Gooey credits) of the user after this transaction" + ), + ), + migrations.AlterField( + model_name="appusertransaction", + name="invoice_id", + field=models.CharField( + help_text="The Payment Provider's Invoice ID for this transaction.
For Gooey, this will be of the form 'gooey_in_{uuid}'", + max_length=255, + unique=True, + ), + ), + ] diff --git a/app_users/models.py b/app_users/models.py index 5e373f414..576ea0390 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -108,7 +108,9 @@ def first_name(self): @db_middleware @transaction.atomic - def add_balance(self, amount: int, invoice_id: str) -> "AppUserTransaction": + def add_balance( + self, amount: int, invoice_id: str, **kwargs + ) -> "AppUserTransaction": """ Used to add/deduct credits when they are bought or consumed. @@ -117,7 +119,6 @@ def add_balance(self, amount: int, invoice_id: str) -> "AppUserTransaction": When credits are deducted due to a run -- invoice_id is of the form "gooey_in_{uuid}" """ - # if an invoice entry exists try: # avoid updating twice for same invoice @@ -133,12 +134,12 @@ def add_balance(self, amount: int, invoice_id: str) -> "AppUserTransaction": user: AppUser = AppUser.objects.select_for_update().get(pk=self.pk) user.balance += amount user.save(update_fields=["balance"]) - return AppUserTransaction.objects.create( user=self, invoice_id=invoice_id, amount=amount, end_balance=user.balance, + **kwargs, ) def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser": @@ -177,7 +178,9 @@ def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser": default_balance = settings.LOGIN_USER_FREE_CREDITS if self.is_anonymous: default_balance = settings.ANON_USER_FREE_CREDITS - elif provider_list[-1].provider_id == "password": + elif ( + "+" in str(self.email) or "@gmail.com" not in str(self.email) + ) and provider_list[-1].provider_id == "password": default_balance = settings.EMAIL_USER_FREE_CREDITS self.balance = db.get_doc_field( doc_ref=db.get_user_doc_ref(user.uid), @@ -217,13 +220,45 @@ def search_stripe_customer(self) -> stripe.Customer | None: return customer +class PaymentProvider(models.IntegerChoices): + STRIPE = 1, "Stripe" + PAYPAL = 2, "Paypal" + + class 
AppUserTransaction(models.Model): user = models.ForeignKey( "AppUser", on_delete=models.CASCADE, related_name="transactions" ) - invoice_id = models.CharField(max_length=255, unique=True) - amount = models.IntegerField() - end_balance = models.IntegerField() + invoice_id = models.CharField( + max_length=255, + unique=True, + help_text="The Payment Provider's Invoice ID for this transaction.
" + "For Gooey, this will be of the form 'gooey_in_{uuid}'", + ) + + amount = models.IntegerField( + help_text="The amount (Gooey credits) added/deducted in this transaction.
" + "Positive for credits added, negative for credits deducted." + ) + end_balance = models.IntegerField( + help_text="The end balance (Gooey credits) of the user after this transaction" + ) + + payment_provider = models.IntegerField( + choices=PaymentProvider.choices, + null=True, + blank=True, + default=None, + help_text="The payment provider used for this transaction.
" + "If this is provided, the Charged Amount should also be provided.", + ) + charged_amount = models.PositiveIntegerField( + help_text="The charged dollar amount in the currencyโ€™s smallest unit.
" + "E.g. for 10 USD, this would be of 1000 (that is, 1000 cents).
" + "Learn More", + default=0, + ) + created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now) class Meta: diff --git a/bots/admin.py b/bots/admin.py index f1d924434..4f69a1081 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -11,19 +11,25 @@ from django.utils.safestring import mark_safe from django.utils.timesince import timesince +from app_users.models import AppUser from bots.admin_links import list_related_html_url, change_obj_url from bots.models import ( FeedbackComment, CHATML_ROLE_ASSISSTANT, SavedRun, + PublishedRun, + PublishedRunVersion, Message, Platform, Feedback, Conversation, BotIntegration, + MessageAttachment, + WorkflowMetadata, ) -from app_users.models import AppUser from bots.tasks import create_personal_channels_for_all_members +from daras_ai.image_input import truncate_text_words +from daras_ai_v2.base import BasePage from gooeysite.custom_actions import export_to_excel, export_to_csv from gooeysite.custom_filters import ( related_json_field_summary, @@ -91,13 +97,14 @@ class BotIntegrationAdmin(admin.ModelAdmin): "updated_at", "billing_account_uid", "saved_run", + "published_run", "analysis_run", ] list_filter = ["platform"] form = BotIntegrationAdminForm - autocomplete_fields = ["saved_run", "analysis_run"] + autocomplete_fields = ["saved_run", "published_run", "analysis_run"] readonly_fields = [ "fb_page_access_token", @@ -117,6 +124,7 @@ class BotIntegrationAdmin(admin.ModelAdmin): "fields": [ "name", "saved_run", + "published_run", "billing_account_uid", "user_language", ], @@ -160,6 +168,7 @@ class BotIntegrationAdmin(admin.ModelAdmin): "Settings", { "fields": [ + "streaming_enabled", "show_feedback_buttons", "analysis_run", "view_analysis_results", @@ -203,20 +212,66 @@ def view_analysis_results(self, bi: BotIntegration): return html +@admin.register(PublishedRun) +class PublishedRunAdmin(admin.ModelAdmin): + list_display = [ + "__str__", + "visibility", + "view_user", + "open_in_gooey", + "linked_saved_run", 
+ "view_runs", + "created_at", + "updated_at", + ] + list_filter = ["workflow", "visibility", "created_by__is_paying"] + search_fields = ["workflow", "published_run_id", "title", "notes"] + autocomplete_fields = ["saved_run", "created_by", "last_edited_by"] + readonly_fields = [ + "open_in_gooey", + "view_runs", + "created_at", + "updated_at", + ] + + def view_user(self, published_run: PublishedRun): + if published_run.created_by is None: + return None + return change_obj_url(published_run.created_by) + + view_user.short_description = "View User" + + def linked_saved_run(self, published_run: PublishedRun): + return change_obj_url(published_run.saved_run) + + linked_saved_run.short_description = "Linked Run" + + @admin.display(description="View Runs") + def view_runs(self, published_run: PublishedRun): + return list_related_html_url( + SavedRun.objects.filter(parent_version__published_run=published_run), + query_param="parent_version__published_run__id__exact", + instance_id=published_run.id, + show_add=False, + ) + + @admin.register(SavedRun) class SavedRunAdmin(admin.ModelAdmin): list_display = [ "__str__", - "example_id", "run_id", "view_user", - "created_at", + "open_in_gooey", + "view_parent_published_run", "run_time", - "updated_at", "price", + "created_at", + "updated_at", ] list_filter = ["workflow"] search_fields = ["workflow", "example_id", "run_id", "uid"] + autocomplete_fields = ["parent_version"] readonly_fields = [ "open_in_gooey", @@ -235,6 +290,11 @@ class SavedRunAdmin(admin.ModelAdmin): django.db.models.JSONField: {"widget": JSONEditorWidget}, } + def lookup_allowed(self, key, value): + if key in ["parent_version__published_run__id__exact"]: + return True + return super().lookup_allowed(key, value) + def view_user(self, saved_run: SavedRun): return change_obj_url( AppUser.objects.get(uid=saved_run.uid), @@ -248,6 +308,17 @@ def view_bots(self, saved_run: SavedRun): view_bots.short_description = "View Bots" + @admin.display(description="View 
Published Run") + def view_parent_published_run(self, saved_run: SavedRun): + pr = saved_run.parent_published_run() + return pr and change_obj_url(pr) + + +@admin.register(PublishedRunVersion) +class PublishedRunVersionAdmin(admin.ModelAdmin): + search_fields = ["id", "version_id", "published_run__published_run_id"] + autocomplete_fields = ["published_run", "saved_run", "changed_by"] + class LastActiveDeltaFilter(admin.SimpleListFilter): title = Conversation.last_active_delta.short_description @@ -360,6 +431,12 @@ class FeedbackInline(admin.TabularInline): readonly_fields = ["created_at"] +class MessageAttachmentInline(admin.TabularInline): + model = MessageAttachment + extra = 0 + readonly_fields = ["url", "metadata", "created_at"] + + class AnalysisResultFilter(admin.SimpleListFilter): title = "analysis_result" parameter_name = "analysis_result" @@ -419,7 +496,7 @@ class MessageAdmin(admin.ModelAdmin): ordering = ["created_at"] actions = [export_to_csv, export_to_excel] - inlines = [FeedbackInline] + inlines = [MessageAttachmentInline, FeedbackInline] formfield_overrides = { django.db.models.JSONField: {"widget": JSONEditorWidget}, @@ -643,3 +720,18 @@ def conversation_link(self, feedback: Feedback): ) conversation_link.short_description = "Conversation" + + +@admin.register(WorkflowMetadata) +class WorkflowMetadata(admin.ModelAdmin): + list_display = [ + "workflow", + "short_title", + "meta_title", + "meta_description", + "created_at", + "updated_at", + ] + search_fields = ["workflow", "meta_title", "meta_description"] + list_filter = ["workflow"] + readonly_fields = ["created_at", "updated_at"] diff --git a/bots/admin_links.py b/bots/admin_links.py index 4a09cbebf..94086e348 100644 --- a/bots/admin_links.py +++ b/bots/admin_links.py @@ -1,3 +1,4 @@ +import re import typing from django.db import models @@ -8,6 +9,7 @@ def open_in_new_tab(url: str, *, label: str = "", add_related_url: str = None) -> str: + label = re.sub(r"https?://", "", label) context = { 
"url": url, "label": label, diff --git a/bots/migrations/0047_messageattachment_alter_feedback_options_and_more.py b/bots/migrations/0047_messageattachment_alter_feedback_options_and_more.py new file mode 100644 index 000000000..f8e9d8910 --- /dev/null +++ b/bots/migrations/0047_messageattachment_alter_feedback_options_and_more.py @@ -0,0 +1,64 @@ +# Generated by Django 4.2.5 on 2023-11-22 13:45 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("files", "0001_initial"), + ("bots", "0046_savedrun_bots_savedr_created_cb8e09_idx_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="MessageAttachment", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("url", models.TextField()), + ("created_at", models.DateTimeField(auto_now_add=True, db_index=True)), + ], + options={ + "ordering": ["created_at"], + }, + ), + migrations.AlterModelOptions( + name="feedback", + options={"get_latest_by": "created_at", "ordering": ["-created_at"]}, + ), + migrations.AddIndex( + model_name="feedback", + index=models.Index( + fields=["-created_at"], name="bots_feedba_created_fbd16a_idx" + ), + ), + migrations.AddField( + model_name="messageattachment", + name="message", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="attachments", + to="bots.message", + ), + ), + migrations.AddField( + model_name="messageattachment", + name="metadata", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="message_attachments", + to="files.filemetadata", + ), + ), + ] diff --git a/bots/migrations/0048_alter_messageattachment_url.py b/bots/migrations/0048_alter_messageattachment_url.py new file mode 100644 index 000000000..8b5774643 --- /dev/null +++ 
b/bots/migrations/0048_alter_messageattachment_url.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.5 on 2023-11-25 12:38 + +import bots.custom_fields +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0047_messageattachment_alter_feedback_options_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="messageattachment", + name="url", + field=bots.custom_fields.CustomURLField(max_length=2048), + ), + ] diff --git a/bots/migrations/0049_publishedrun_publishedrunversion.py b/bots/migrations/0049_publishedrun_publishedrunversion.py new file mode 100644 index 000000000..51c209825 --- /dev/null +++ b/bots/migrations/0049_publishedrun_publishedrunversion.py @@ -0,0 +1,265 @@ +# Generated by Django 4.2.7 on 2023-12-05 13:39 + +from django.db import migrations, models +import django.db.models.deletion + +from bots.models import PublishedRunVisibility +from daras_ai_v2.crypto import get_random_doc_id + + +def set_field_attribute(instance, field_name, **attrs): + for field in instance._meta.local_fields: + if field.name == field_name: + for attr, value in attrs.items(): + setattr(field, attr, value) + + +def create_published_run_from_example( + *, + published_run_model, + published_run_version_model, + saved_run, + user, + published_run_id, +): + published_run = published_run_model( + workflow=saved_run.workflow, + published_run_id=published_run_id, + created_by=user, + last_edited_by=user, + saved_run=saved_run, + title=saved_run.page_title, + notes=saved_run.page_notes, + visibility=PublishedRunVisibility.PUBLIC, + is_approved_example=not saved_run.hidden, + ) + set_field_attribute(published_run, "created_at", auto_now_add=False) + set_field_attribute(published_run, "updated_at", auto_now=False) + published_run.created_at = saved_run.created_at + published_run.updated_at = saved_run.updated_at + published_run.save() + set_field_attribute(published_run, "created_at", auto_now_add=True) + 
set_field_attribute(published_run, "updated_at", auto_now=True) + + version = published_run_version_model( + published_run=published_run, + version_id=get_random_doc_id(), + saved_run=saved_run, + changed_by=user, + title=saved_run.page_title, + notes=saved_run.page_notes, + visibility=PublishedRunVisibility.PUBLIC, + ) + set_field_attribute(published_run, "created_at", auto_now_add=False) + version.created_at = saved_run.updated_at + version.save() + set_field_attribute(published_run, "created_at", auto_now_add=True) + + return published_run + + +def forwards_func(apps, schema_editor): + # if example_id is not null, create published run with + # is_approved_example to True and visibility to Public + saved_run_model = apps.get_model("bots", "SavedRun") + published_run_model = apps.get_model("bots", "PublishedRun") + published_run_version_model = apps.get_model("bots", "PublishedRunVersion") + db_alias = schema_editor.connection.alias + + # all examples + for saved_run in saved_run_model.objects.using(db_alias).filter( + example_id__isnull=False, + ): + create_published_run_from_example( + published_run_model=published_run_model, + published_run_version_model=published_run_version_model, + saved_run=saved_run, + user=None, # TODO: use gooey-support user instead? 
+ published_run_id=saved_run.example_id, + ) + + # recipe root examples + for saved_run in saved_run_model.objects.using(db_alias).filter( + example_id__isnull=True, + run_id__isnull=True, + uid__isnull=True, + ): + create_published_run_from_example( + published_run_model=published_run_model, + published_run_version_model=published_run_version_model, + saved_run=saved_run, + user=None, + published_run_id="", + ) + + +def backwards_func(apps, schema_editor): + pass + + +class Migration(migrations.Migration): + dependencies = [ + ("app_users", "0010_alter_appuser_balance_alter_appuser_created_at_and_more"), + ("bots", "0048_alter_messageattachment_url"), + ] + + operations = [ + migrations.CreateModel( + name="PublishedRun", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("published_run_id", models.CharField(blank=True, max_length=128)), + ( + "workflow", + models.IntegerField( + choices=[ + (1, "Doc Search"), + (2, "Doc Summary"), + (3, "Google GPT"), + (4, "Copilot"), + (5, "Lipysnc + TTS"), + (6, "Text to Speech"), + (7, "Speech Recognition"), + (8, "Lipsync"), + (9, "Deforum Animation"), + (10, "Compare Text2Img"), + (11, "Text2Audio"), + (12, "Img2Img"), + (13, "Face Inpainting"), + (14, "Google Image Gen"), + (15, "Compare AI Upscalers"), + (16, "SEO Summary"), + (17, "Email Face Inpainting"), + (18, "Social Lookup Email"), + (19, "Object Inpainting"), + (20, "Image Segmentation"), + (21, "Compare LLM"), + (22, "Chyron Plant"), + (23, "Letter Writer"), + (24, "Smart GPT"), + (25, "AI QR Code"), + (26, "Doc Extract"), + (27, "Related QnA Maker"), + (28, "Related QnA Maker Doc"), + (29, "Embeddings"), + (30, "Bulk Runner"), + ] + ), + ), + ("title", models.TextField(blank=True, default="")), + ("notes", models.TextField(blank=True, default="")), + ( + "visibility", + models.IntegerField( + choices=[(1, "Unlisted"), (2, "Public")], default=1 + ), + ), + 
("is_approved_example", models.BooleanField(default=False)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "created_by", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="published_runs", + to="app_users.appuser", + ), + ), + ( + "last_edited_by", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="app_users.appuser", + ), + ), + ( + "saved_run", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="published_runs", + to="bots.savedrun", + ), + ), + ], + options={ + "ordering": ["-updated_at"], + "unique_together": {("workflow", "published_run_id")}, + }, + ), + migrations.CreateModel( + name="PublishedRunVersion", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("version_id", models.CharField(max_length=128, unique=True)), + ("title", models.TextField(blank=True, default="")), + ("notes", models.TextField(blank=True, default="")), + ( + "visibility", + models.IntegerField( + choices=[(1, "Unlisted"), (2, "Public")], default=1 + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ( + "changed_by", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="app_users.appuser", + ), + ), + ( + "published_run", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="versions", + to="bots.publishedrun", + ), + ), + ( + "saved_run", + models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="published_run_versions", + to="bots.savedrun", + ), + ), + ], + options={ + "ordering": ["-created_at"], + "get_latest_by": "created_at", + "indexes": [ + models.Index( + fields=["published_run", "-created_at"], + name="bots_publis_publish_9cd246_idx", + ), + models.Index( + fields=["version_id"], 
name="bots_publis_version_c121d4_idx" + ), + ], + }, + ), + migrations.RunPython( + forwards_func, + backwards_func, + ), + ] diff --git a/bots/migrations/0050_botintegration_published_run.py b/bots/migrations/0050_botintegration_published_run.py new file mode 100644 index 000000000..2b8e2035d --- /dev/null +++ b/bots/migrations/0050_botintegration_published_run.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.7 on 2023-12-05 13:39 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0049_publishedrun_publishedrunversion"), + ] + + operations = [ + migrations.AddField( + model_name="botintegration", + name="published_run", + field=models.ForeignKey( + blank=True, + default=None, + help_text="The saved run that the bot is based on", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="botintegrations", + to="bots.publishedrun", + ), + ), + ] diff --git a/bots/migrations/0051_savedrun_parent_version.py b/bots/migrations/0051_savedrun_parent_version.py new file mode 100644 index 000000000..3c2c16b18 --- /dev/null +++ b/bots/migrations/0051_savedrun_parent_version.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.7 on 2023-12-08 10:57 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0050_botintegration_published_run"), + ] + + operations = [ + migrations.AddField( + model_name="savedrun", + name="parent_version", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="children_runs", + to="bots.publishedrunversion", + ), + ), + ] diff --git a/bots/migrations/0052_alter_publishedrun_options_and_more.py b/bots/migrations/0052_alter_publishedrun_options_and_more.py new file mode 100644 index 000000000..4d6dfe2f8 --- /dev/null +++ b/bots/migrations/0052_alter_publishedrun_options_and_more.py 
@@ -0,0 +1,42 @@ +# Generated by Django 4.2.7 on 2023-12-11 05:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0051_savedrun_parent_version"), + ] + + operations = [ + migrations.AlterModelOptions( + name="publishedrun", + options={"get_latest_by": "updated_at", "ordering": ["-updated_at"]}, + ), + migrations.AddIndex( + model_name="publishedrun", + index=models.Index( + fields=["workflow"], name="bots_publis_workflo_a0953a_idx" + ), + ), + migrations.AddIndex( + model_name="publishedrun", + index=models.Index( + fields=["workflow", "created_by"], name="bots_publis_workflo_c75a55_idx" + ), + ), + migrations.AddIndex( + model_name="publishedrun", + index=models.Index( + fields=["workflow", "published_run_id"], + name="bots_publis_workflo_87bece_idx", + ), + ), + migrations.AddIndex( + model_name="publishedrun", + index=models.Index( + fields=["workflow", "visibility", "is_approved_example"], + name="bots_publis_workflo_36a83a_idx", + ), + ), + ] diff --git a/bots/migrations/0053_alter_publishedrun_workflow_alter_savedrun_workflow_and_more.py b/bots/migrations/0053_alter_publishedrun_workflow_alter_savedrun_workflow_and_more.py new file mode 100644 index 000000000..4398af91d --- /dev/null +++ b/bots/migrations/0053_alter_publishedrun_workflow_alter_savedrun_workflow_and_more.py @@ -0,0 +1,103 @@ +# Generated by Django 4.2.7 on 2023-12-21 15:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0052_alter_publishedrun_options_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="publishedrun", + name="workflow", + field=models.IntegerField( + choices=[ + (1, "Doc Search"), + (2, "Doc Summary"), + (3, "Google GPT"), + (4, "Copilot"), + (5, "Lipysnc + TTS"), + (6, "Text to Speech"), + (7, "Speech Recognition"), + (8, "Lipsync"), + (9, "Deforum Animation"), + (10, "Compare Text2Img"), + (11, "Text2Audio"), + 
(12, "Img2Img"), + (13, "Face Inpainting"), + (14, "Google Image Gen"), + (15, "Compare AI Upscalers"), + (16, "SEO Summary"), + (17, "Email Face Inpainting"), + (18, "Social Lookup Email"), + (19, "Object Inpainting"), + (20, "Image Segmentation"), + (21, "Compare LLM"), + (22, "Chyron Plant"), + (23, "Letter Writer"), + (24, "Smart GPT"), + (25, "AI QR Code"), + (26, "Doc Extract"), + (27, "Related QnA Maker"), + (28, "Related QnA Maker Doc"), + (29, "Embeddings"), + (30, "Bulk Runner"), + (31, "Bulk Evaluator"), + ] + ), + ), + migrations.AlterField( + model_name="savedrun", + name="workflow", + field=models.IntegerField( + choices=[ + (1, "Doc Search"), + (2, "Doc Summary"), + (3, "Google GPT"), + (4, "Copilot"), + (5, "Lipysnc + TTS"), + (6, "Text to Speech"), + (7, "Speech Recognition"), + (8, "Lipsync"), + (9, "Deforum Animation"), + (10, "Compare Text2Img"), + (11, "Text2Audio"), + (12, "Img2Img"), + (13, "Face Inpainting"), + (14, "Google Image Gen"), + (15, "Compare AI Upscalers"), + (16, "SEO Summary"), + (17, "Email Face Inpainting"), + (18, "Social Lookup Email"), + (19, "Object Inpainting"), + (20, "Image Segmentation"), + (21, "Compare LLM"), + (22, "Chyron Plant"), + (23, "Letter Writer"), + (24, "Smart GPT"), + (25, "AI QR Code"), + (26, "Doc Extract"), + (27, "Related QnA Maker"), + (28, "Related QnA Maker Doc"), + (29, "Embeddings"), + (30, "Bulk Runner"), + (31, "Bulk Evaluator"), + ], + default=4, + ), + ), + migrations.AddIndex( + model_name="publishedrun", + index=models.Index( + fields=[ + "workflow", + "visibility", + "is_approved_example", + "published_run_id", + ], + name="bots_publis_workflo_d3ad4e_idx", + ), + ), + ] diff --git a/bots/migrations/0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more.py b/bots/migrations/0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more.py new file mode 100644 index 000000000..9248ee4da --- /dev/null +++ 
b/bots/migrations/0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more.py @@ -0,0 +1,57 @@ +# Generated by Django 4.2.7 on 2024-01-02 16:05 + +from django.db import migrations, models + + +def migrate_keyword_instructions(apps, schema_editor): + SavedRun = apps.get_model("bots", "SavedRun") + db_alias = schema_editor.connection.alias + objects = SavedRun.objects.using(db_alias) + + new_prompt = """ + +{{ final_search_query }} + +\ +You are a BM25 tokenizer. Extract rare terms like part numbers from the message. Return a JSON List of tokenized query terms. +""".strip() + + qs = objects.exclude( + state__keyword_instructions__isnull=True, + ).exclude( + state__keyword_instructions="", + ) + for sr in qs: + sr.state |= dict(keyword_instructions=new_prompt) + sr.save() + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0053_alter_publishedrun_workflow_alter_savedrun_workflow_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="savedrun", + name="example_id", + field=models.CharField( + blank=True, + default=None, + help_text="(Deprecated)", + max_length=128, + null=True, + ), + ), + migrations.AlterField( + model_name="savedrun", + name="page_notes", + field=models.TextField(blank=True, default="", help_text="(Deprecated)"), + ), + migrations.AlterField( + model_name="savedrun", + name="page_title", + field=models.TextField(blank=True, default="", help_text="(Deprecated)"), + ), + migrations.RunPython(migrate_keyword_instructions, migrations.RunPython.noop), + ] diff --git a/bots/migrations/0055_workflowmetadata.py b/bots/migrations/0055_workflowmetadata.py new file mode 100644 index 000000000..4279cda49 --- /dev/null +++ b/bots/migrations/0055_workflowmetadata.py @@ -0,0 +1,338 @@ +# Generated by Django 4.2.7 on 2024-01-08 23:57 + +from django.db import migrations, models + +import bots.custom_fields + + +def forwards_func(apps, schema_editor): + from bots.models import Workflow + + WorkflowMetadata = 
apps.get_model("bots", "WorkflowMetadata") + objs = [ + WorkflowMetadata( + workflow=Workflow.DOC_SEARCH, + short_title="Doc Search", + meta_title="Advanced Document Search Solution", + meta_description=""" + Easily search within PDFs, Word documents, and other formats using Gooey AI's doc-search feature. Improve efficiency and knowledge extraction with our advanced AI tools. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/bcc7aa58-93fe-11ee-a083-02420a0001c8/Search%20your%20docs.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.DOC_SUMMARY, + short_title="Summarize", + meta_title="AI Document Summarization & Transcription", + meta_description=""" + Effortlessly summarize large files and collections of PDFs, docs and audio files using AI with Gooey.AI | Gooey.AI Doc-Summary Solution. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f35796d2-93fe-11ee-b86c-02420a0001c7/Summarize%20with%20GPT.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.GOOGLE_GPT, + short_title="LLM Web Search", + meta_title="Browse the web using ChatGPT", + meta_description=""" + Like Bing + ChatGPT or perplexity.ai, this workflow queries Google and then summarizes the results (with citations!) using an editable GPT3 script. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85ed60a2-9405-11ee-9747-02420a0001ce/Web%20search%20GPT.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.VIDEO_BOTS, + short_title="Copilot", + meta_title="Advanced AI Copilot for Farming Solutions", + meta_description=""" + Discover Gooey.AI's Copilot, the most advanced AI bot offering GPT4, PaLM2, LLaAM2, knowledge base integration, conversation analysis & more for farming solutions. 
+ """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f454d64a-9457-11ee-b6d5-02420a0001cb/Copilot.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.LIPSYNC_TTS, + short_title="Lipsync + Voice", + meta_title="Lipsync Video Maker with AI Voice Generation", + meta_description=""" + Create realistic lipsync videos with custom voices. Just upload a video or image, choose or bring your own voice from EvelenLabs to generate amazing videos with the Gooey.AI Lipsync Maker. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/13b4d352-9456-11ee-8edd-02420a0001c7/Lipsync%20TTS.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.TEXT_TO_SPEECH, + short_title="Text to Speech", + meta_title="Compare Text-to-Speech AI Engines", + meta_description=""" + Experience the most powerful text-to-speech APIs with Gooey.AI. Compare and choose the best voice for podcasts, YouTube videos, websites, bots, and more. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a73181ce-9457-11ee-8edd-02420a0001c7/Voice%20generators.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.ASR, + short_title="Speech", + meta_title="Speech and AI Services", + meta_description=""" + Generate realistic audio files, lip-sync videos, and experience multilingual chatbots with Gooey.AI speech and AI-based services. Improve user experience! + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1916825c-93fa-11ee-97be-02420a0001c8/Speech.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.LIPSYNC, + short_title="Lipsync", + meta_title="Lipsync Animation Generator with Audio Input", + meta_description=""" + Achieve high-quality, realistic Lipsync animations with Gooey.AI's Lipsync - Just input a face and audio to generate your tailored animation. 
+ """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7fc4d302-9402-11ee-98dc-02420a0001ca/Lip%20Sync.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.DEFORUM_SD, + short_title="Animation", + meta_title="Animation Generator=AI-Powered Animations Simplified", + meta_description=""" + Create AI-generated animations effortlessly with Gooey.AI's Animation Generator and Stable Diffusion's Deforum technology. No complex CoLab notebooks required! + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7dc25196-93fe-11ee-9e3a-02420a0001ce/AI%20Animation%20generator.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.COMPARE_TEXT2IMG, + short_title="Image Generator", + meta_title="AI Image Generators Comparison", + meta_description=""" + Discover the most effective AI image generator for your needs by comparing different models like Stable Diffusion, Dall-E, and more at Gooey.AI. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ae7b2940-93fc-11ee-8edc-02420a0001cc/Compare%20image%20generators.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.TEXT_2_AUDIO, + short_title="Music", + meta_title="Text2Audio - AI-Driven Text-to-Sound Generator | Gooey.AI", + meta_description=""" + Transform text into realistic audio with Gooey.AI's text2audio tool. Create custom sounds using AI-driven technology for your projects and content. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85cf8ea4-9457-11ee-bd77-02420a0001ce/Text%20guided%20audio.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.IMG_2_IMG, + short_title="Photo Editor", + meta_title="AI Photo Editor for Stunning Image Transformations", + meta_description=""" + Transform your images with our AI Photo Editor utilizing the latest AI technology for incredible results. 
Enhance your photos, create unique art, and more + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/cc2804ea-9401-11ee-940a-02420a0001c7/Edit%20an%20image.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.FACE_INPAINTING, + short_title="Face Editor", + meta_title="AI Face Extraction & Generation", + meta_description=""" + Explore Gooey.AI's revolutionary face extraction and AI-generated photo technology, where you can upload, extract, and bring your desired character to life in a new image. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a146bfc0-93ff-11ee-b86c-02420a0001c7/Face%20in%20painting.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.GOOGLE_IMAGE_GEN, + short_title="SEO Renderer", + meta_title="AI Image Rendering & Generation Solution", + meta_description=""" + Discover the power of AI in image rendering with Gooey.AI's cutting-edge technology, transforming text prompts into stunning visuals for any search query. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/dcd82b68-9400-11ee-9e3a-02420a0001ce/Search%20result%20photo.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.COMPARE_UPSCALER, + short_title="Upscaler", + meta_title="AI Upscalers Comparison & Examples", + meta_description=""" + Explore the benefits of AI upscalers and discover how they enhance image quality through cutting-edge technology at Gooey.ai. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/2e8ee512-93fe-11ee-a083-02420a0001c8/Image%20upscaler.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.SEO_SUMMARY, + short_title="SEO Renderer", + meta_title="SEO Paragraph Generator for Enhanced Content", + meta_description=""" + Optimize your content with Gooey's SEO Paragraph Generator - AI powered content optimization for improved search engine rankings and increased traffic. 
+ """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/13d3ab1e-9457-11ee-98a6-02420a0001c9/SEO.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.EMAIL_FACE_INPAINTING, + short_title="Email to Image", + meta_title="AI Image from Email Lookup", + meta_description=""" + Discover the AI-based solution for generating images from email lookups, creating unique and engaging visuals using email addresses and AI-generated scenes. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6937427a-9522-11ee-b6d3-02420a0001ea/Email%20photo.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.SOCIAL_LOOKUP_EMAIL, + short_title="Emailer", + meta_title="AI-Powered Email Writer with Profile Lookup", + meta_description=""" + Enhance your outreach with Gooey.AI's Email Writer that finds public social profiles and creates personalized emails using advanced AI mail merge. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6729ea44-9457-11ee-bd77-02420a0001ce/Profile%20look%20up%20gpt%20email.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.OBJECT_INPAINTING, + short_title="Background Maker", + meta_title="Product Photo Background Generator", + meta_description=""" + Generate professional background scenery for your product photos with Gooey.AI's advanced inpainting AI technology. 
+ """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/4bca6982-9456-11ee-bc12-02420a0001cc/Product%20photo%20backgrounds.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.COMPARE_LLM, + short_title="LLM", + meta_title="Compare GPT-4, PaLM2, and LLaMA2 | Large Language Model Comparison", + meta_description=""" + Compare popular large language models like GPT-4, PaLM2, and LLaMA2 to determine which one performs best for your specific needs | Gooey.AI + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5e4f4c58-93fc-11ee-a39e-02420a0001ce/LLMs.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.SMART_GPT, + short_title="SmartGPT", + meta_title="SmartGPT - Advanced AI Language Model", + meta_description=""" + Explore powerful AI solutions with Gooey.AI's SmartGPT, a cutting-edge language model designed to transform industries and simplify your work. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/3d71b434-9457-11ee-8edd-02420a0001c7/Smart%20GPT.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.QR_CODE, + short_title="QR Code", + meta_title="AI Art QR Code Generator", + meta_description=""" + Generate AI-empowered artistic QR codes tailored to your style for impactful marketing, branding & more with Gooey.AI. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a679a410-9456-11ee-bd77-02420a0001ce/QR%20Code.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.DOC_EXTRACT, + short_title="Synthetic Data", + meta_title="Efficient YouTube Video Transcription & GPT4 Integration", + meta_description=""" + Automate YouTube video transcription, run GPT4 prompts, and save data to Google Sheets with Gooey AI's YouTube Bot. Elevate your content creation strategy! 
+ """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ddc8ffac-93fb-11ee-89fb-02420a0001cb/Youtube%20transcripts.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.RELATED_QNA_MAKER, + short_title="People Also Ask", + meta_title="Related QnA Maker API for Document Search", + meta_description=""" + Enhance your document search experience with Gooey.AI's Related QnA Maker API, leveraging advanced machine learning to deliver relevant information from your doc, pdf, or files. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/cbd2c94e-9456-11ee-a95e-02420a0001cc/People%20also%20ask.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.BULK_RUNNER, + short_title="Bulk", + meta_title="Bulk Runner", + meta_description=""" + Which AI model actually works best for your needs? Upload your own data and evaluate any Gooey.AI workflow, LLM or AI model against any other. + """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d80fd4d8-93fa-11ee-bc13-02420a0001cc/Bulk%20Runner.jpg.png", + ), + WorkflowMetadata( + workflow=Workflow.BULK_EVAL, + short_title="Eval", + meta_title="Bulk Evaluator", + meta_description=""" + Summarize and score every row of any CSV, google sheet or excel with GPT4 (or any LLM you choose). Then average every score in any column to generate automated evaluations. 
+ """.strip(), + meta_image="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/9631fb74-9a97-11ee-971f-02420a0001c4/evaluator.png.png", + ), + ] + + db_alias = schema_editor.connection.alias + objects = WorkflowMetadata.objects.using(db_alias) + objects.bulk_create(objs) + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="WorkflowMetadata", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "workflow", + models.IntegerField( + choices=[ + (1, "Doc Search"), + (2, "Doc Summary"), + (3, "Google GPT"), + (4, "Copilot"), + (5, "Lipysnc + TTS"), + (6, "Text to Speech"), + (7, "Speech Recognition"), + (8, "Lipsync"), + (9, "Deforum Animation"), + (10, "Compare Text2Img"), + (11, "Text2Audio"), + (12, "Img2Img"), + (13, "Face Inpainting"), + (14, "Google Image Gen"), + (15, "Compare AI Upscalers"), + (16, "SEO Summary"), + (17, "Email Face Inpainting"), + (18, "Social Lookup Email"), + (19, "Object Inpainting"), + (20, "Image Segmentation"), + (21, "Compare LLM"), + (22, "Chyron Plant"), + (23, "Letter Writer"), + (24, "Smart GPT"), + (25, "AI QR Code"), + (26, "Doc Extract"), + (27, "Related QnA Maker"), + (28, "Related QnA Maker Doc"), + (29, "Embeddings"), + (30, "Bulk Runner"), + (31, "Bulk Evaluator"), + ], + unique=True, + ), + ), + ("short_title", models.TextField()), + ("help_url", models.URLField(blank=True, default="")), + ( + "default_image", + models.URLField( + blank=True, default="", help_text="(not implemented)" + ), + ), + ("meta_title", models.TextField()), + ("meta_description", models.TextField(blank=True, default="")), + ( + "meta_image", + bots.custom_fields.CustomURLField( + blank=True, default="", max_length=2048 + ), + ), + ( + "meta_keywords", + models.JSONField( + blank=True, default=list, 
help_text="(not implemented)" + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ], + ), + migrations.RunPython(forwards_func, migrations.RunPython.noop), + ] diff --git a/bots/migrations/0056_botintegration_streaming_enabled.py b/bots/migrations/0056_botintegration_streaming_enabled.py new file mode 100644 index 000000000..dcf1366fa --- /dev/null +++ b/bots/migrations/0056_botintegration_streaming_enabled.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.7 on 2024-01-31 19:14 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0055_workflowmetadata"), + ] + + operations = [ + migrations.AddField( + model_name="botintegration", + name="streaming_enabled", + field=models.BooleanField( + default=False, + help_text="If set, the bot will stream messages to the frontend", + ), + ), + ] diff --git a/bots/models.py b/bots/models.py index 6e6dc2064..2f7e5676a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -2,6 +2,7 @@ import typing import uuid from multiprocessing.pool import ThreadPool +from textwrap import dedent import pytz from django.conf import settings @@ -10,15 +11,17 @@ from django.core.validators import RegexValidator from django.db import models, transaction from django.db.models import Q -from django.utils.text import Truncator +from django.utils.text import Truncator, slugify from furl import furl from phonenumber_field.modelfields import PhoneNumberField from app_users.models import AppUser from bots.admin_links import open_in_new_tab -from bots.custom_fields import PostgresJSONEncoder -from daras_ai_v2.secrets_utils import GCPSecret +from bots.custom_fields import PostgresJSONEncoder, CustomURLField +from daras_ai_v2.crypto import get_random_doc_id +from daras_ai_v2.language_model import format_chat_entry +from daras_ai_v2.secrets_utils import GCPSecret if typing.TYPE_CHECKING: from daras_ai_v2.base import BasePage 
@@ -31,6 +34,47 @@ EPOCH = datetime.datetime.utcfromtimestamp(0) +class PublishedRunVisibility(models.IntegerChoices): + UNLISTED = 1 + PUBLIC = 2 + + def help_text(self): + match self: + case PublishedRunVisibility.UNLISTED: + return "Only me + people with a link" + case PublishedRunVisibility.PUBLIC: + return "Public" + case _: + return self.label + + def get_badge_html(self): + badge_container_class = ( + "text-sm bg-light border border-dark rounded-pill px-2 py-1" + ) + + match self: + case PublishedRunVisibility.UNLISTED: + return dedent( + f"""\ + + + Private + + """ + ) + case PublishedRunVisibility.PUBLIC: + return dedent( + f"""\ + + + Public + + """ + ) + case _: + raise NotImplementedError(self) + + class Platform(models.IntegerChoices): FACEBOOK = 1 INSTAGRAM = (2, "Instagram & FB") @@ -79,12 +123,13 @@ class Workflow(models.IntegerChoices): RELATED_QNA_MAKER_DOC = (28, "Related QnA Maker Doc") EMBEDDINGS = (29, "Embeddings") BULK_RUNNER = (30, "Bulk Runner") + BULK_EVAL = (31, "Bulk Evaluator") @property def short_slug(self): return min(self.page_cls.slug_versions, key=len) - def get_app_url(self, example_id: str, run_id: str, uid: str): + def get_app_url(self, example_id: str, run_id: str, uid: str, run_slug: str = ""): """return the url to the gooey app""" query_params = {} if run_id and uid: @@ -93,7 +138,8 @@ def get_app_url(self, example_id: str, run_id: str, uid: str): query_params |= dict(example_id=example_id) return str( furl(settings.APP_BASE_URL, query_params=query_params) - / self.short_slug + / self.page_cls.slug_versions[-1] + / run_slug / "/" ) @@ -103,6 +149,50 @@ def page_cls(self) -> typing.Type["BasePage"]: return workflow_map[self] + def get_or_create_metadata(self) -> "WorkflowMetadata": + metadata, _created = WorkflowMetadata.objects.get_or_create( + workflow=self, + defaults=dict( + short_title=lambda: ( + self.page_cls.get_root_published_run().title or self.page_cls.title + ), + default_image=self.page_cls.explore_image or None, 
+ meta_title=lambda: ( + self.page_cls.get_root_published_run().title or self.page_cls.title + ), + meta_description=lambda: ( + self.page_cls().preview_description(state={}) + or self.page_cls.get_root_published_run().notes + ), + meta_image=lambda: (self.page_cls.explore_image or None), + ), + ) + return metadata + + +class WorkflowMetadata(models.Model): + workflow = models.IntegerField(choices=Workflow.choices, unique=True) + short_title = models.TextField() + help_url = models.URLField(blank=True, default="") + + # TODO: support the below fields + default_image = models.URLField( + blank=True, default="", help_text="(not implemented)" + ) + + meta_title = models.TextField() + meta_description = models.TextField(blank=True, default="") + meta_image = CustomURLField(default="", blank=True) + meta_keywords = models.JSONField( + default=list, blank=True, help_text="(not implemented)" + ) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return self.meta_title + class SavedRunQuerySet(models.QuerySet): def to_df(self, tz=pytz.timezone(settings.TIME_ZONE)) -> "pd.DataFrame": @@ -130,11 +220,17 @@ class SavedRun(models.Model): blank=True, related_name="children", ) + parent_version = models.ForeignKey( + "bots.PublishedRunVersion", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="children_runs", + ) workflow = models.IntegerField( choices=Workflow.choices, default=Workflow.VIDEO_BOTS ) - example_id = models.CharField(max_length=128, default=None, null=True, blank=True) run_id = models.CharField(max_length=128, default=None, null=True, blank=True) uid = models.CharField(max_length=128, default=None, null=True, blank=True) @@ -143,8 +239,6 @@ class SavedRun(models.Model): error_msg = models.TextField(default="", blank=True) run_time = models.DurationField(default=datetime.timedelta, blank=True) run_status = models.TextField(default="", blank=True) - page_title = 
models.TextField(default="", blank=True) - page_notes = models.TextField(default="", blank=True) hidden = models.BooleanField(default=False) is_flagged = models.BooleanField(default=False) @@ -162,6 +256,12 @@ class SavedRun(models.Model): updated_at = models.DateTimeField(auto_now=True) created_at = models.DateTimeField(auto_now_add=True) + example_id = models.CharField( + max_length=128, default=None, null=True, blank=True, help_text="(Deprecated)" + ) + page_title = models.TextField(default="", blank=True, help_text="(Deprecated)") + page_notes = models.TextField(default="", blank=True, help_text="(Deprecated)") + objects = SavedRunQuerySet.as_manager() class Meta: @@ -188,7 +288,15 @@ class Meta: ] def __str__(self): - return self.get_app_url() + from daras_ai_v2.breadcrumbs import get_title_breadcrumbs + + title = get_title_breadcrumbs( + Workflow(self.workflow).page_cls, self, self.parent_published_run() + ).h1_title + return title or self.get_app_url() + + def parent_published_run(self) -> "PublishedRun": + return self.parent_version and self.parent_version.published_run def get_app_url(self): workflow = Workflow(self.workflow) @@ -208,10 +316,6 @@ def to_dict(self) -> dict: ret[StateKeys.run_time] = self.run_time.total_seconds() if self.run_status: ret[StateKeys.run_status] = self.run_status - if self.page_title: - ret[StateKeys.page_title] = self.page_title - if self.page_notes: - ret[StateKeys.page_notes] = self.page_notes if self.hidden: ret[StateKeys.hidden] = self.hidden if self.is_flagged: @@ -242,9 +346,6 @@ def copy_from_firebase_state(self, state: dict) -> "SavedRun": seconds=state.pop(StateKeys.run_time, None) or 0 ) self.run_status = state.pop(StateKeys.run_status, None) or "" - self.page_title = state.pop(StateKeys.page_title, None) or "" - self.page_notes = state.pop(StateKeys.page_notes, None) or "" - # self.hidden = state.pop(StateKeys.hidden, False) self.is_flagged = state.pop("is_flagged", False) self.state = state @@ -273,6 +374,12 @@ def 
submit_api_call( ) return result, page.run_doc_sr(run_id, uid) + def get_creator(self) -> AppUser | None: + if self.uid: + return AppUser.objects.filter(uid=self.uid).first() + else: + return None + @admin.display(description="Open in Gooey") def open_in_gooey(self): return open_in_new_tab(self.get_app_url(), label=self.get_app_url()) @@ -345,6 +452,15 @@ class BotIntegration(models.Model): blank=True, help_text="The saved run that the bot is based on", ) + published_run = models.ForeignKey( + "bots.PublishedRun", + on_delete=models.SET_NULL, + related_name="botintegrations", + null=True, + default=None, + blank=True, + help_text="The saved run that the bot is based on", + ) billing_account_uid = models.TextField( help_text="The gooey account uid where the credits will be deducted from", db_index=True, @@ -463,6 +579,11 @@ class BotIntegration(models.Model): help_text="If provided, the message content will be analyzed for this bot using this saved run", ) + streaming_enabled = models.BooleanField( + default=False, + help_text="If set, the bot will stream messages to the frontend (Slack only)", + ) + created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -485,6 +606,14 @@ def __str__(self): else: return self.name or platform_name + def get_active_saved_run(self) -> SavedRun | None: + if self.published_run: + return self.published_run.saved_run + elif self.saved_run: + return self.saved_run + else: + return None + def get_display_name(self): return ( (self.wa_phone_number and self.wa_phone_number.as_international) @@ -507,6 +636,12 @@ class ConvoState(models.IntegerChoices): class ConversationQuerySet(models.QuerySet): + def get_unique_users(self) -> "ConversationQuerySet": + """Get unique conversations""" + return self.distinct( + "fb_page_id", "ig_account_id", "wa_phone_number", "slack_user_id" + ) + def to_df(self, tz=pytz.timezone(settings.TIME_ZONE)) -> "pd.DataFrame": import pandas as pd @@ -538,6 +673,61 @@ def 
to_df(self, tz=pytz.timezone(settings.TIME_ZONE)) -> "pd.DataFrame": df = pd.DataFrame.from_records(rows) return df + def to_df_format( + self, tz=pytz.timezone(settings.TIME_ZONE), row_limit=1000 + ) -> "pd.DataFrame": + import pandas as pd + + qs = self.all() + rows = [] + for convo in qs[:row_limit]: + convo: Conversation + row = { + "Name": convo.get_display_name(), + "Messages": convo.messages.count(), + "Correct Answers": convo.messages.filter( + analysis_result__contains={"Answered": True} + ).count(), + "Thumbs up": convo.messages.filter( + feedbacks__rating=Feedback.Rating.RATING_THUMBS_UP + ).count(), + "Thumbs down": convo.messages.filter( + feedbacks__rating=Feedback.Rating.RATING_THUMBS_DOWN + ).count(), + } + try: + first_time = ( + convo.messages.earliest() + .created_at.astimezone(tz) + .replace(tzinfo=None) + ) + last_time = ( + convo.messages.latest() + .created_at.astimezone(tz) + .replace(tzinfo=None) + ) + row |= { + "Last Sent": last_time.strftime("%b %d, %Y %I:%M %p"), + "First Sent": first_time.strftime("%b %d, %Y %I:%M %p"), + "A7": not convo.d7(), + "A30": not convo.d30(), + "R1": last_time - first_time < datetime.timedelta(days=1), + "R7": last_time - first_time < datetime.timedelta(days=7), + "R30": last_time - first_time < datetime.timedelta(days=30), + "Delta Hours": round( + convo.last_active_delta().total_seconds() / 3600 + ), + } + except Message.DoesNotExist: + pass + row |= { + "Created At": convo.created_at.astimezone(tz).replace(tzinfo=None), + "Bot": str(convo.bot_integration), + } + rows.append(row) + df = pd.DataFrame.from_records(rows) + return df + class Conversation(models.Model): bot_integration = models.ForeignKey( @@ -709,6 +899,66 @@ def to_df(self, tz=pytz.timezone(settings.TIME_ZONE)) -> "pd.DataFrame": df = pd.DataFrame.from_records(rows) return df + def to_df_format( + self, tz=pytz.timezone(settings.TIME_ZONE), row_limit=10000 + ) -> "pd.DataFrame": + import pandas as pd + + qs = 
self.all().prefetch_related("feedbacks") + rows = [] + for message in qs[:row_limit]: + message: Message + row = { + "Name": message.conversation.get_display_name(), + "Role": message.role, + "Message (EN)": message.content, + "Sent": message.created_at.astimezone(tz) + .replace(tzinfo=None) + .strftime("%b %d, %Y %I:%M %p"), + "Feedback": message.feedbacks.first().get_display_text() + if message.feedbacks.first() + else None, # only show first feedback as per Sean's request + "Analysis JSON": message.analysis_result, + } + rows.append(row) + df = pd.DataFrame.from_records(rows) + return df + + def to_df_analysis_format( + self, tz=pytz.timezone(settings.TIME_ZONE), row_limit=10000 + ) -> "pd.DataFrame": + import pandas as pd + + qs = self.filter(role=CHATML_ROLE_USER).prefetch_related("feedbacks") + rows = [] + for message in qs[:row_limit]: + message: Message + row = { + "Name": message.conversation.get_display_name(), + "Question (EN)": message.content, + "Answer (EN)": message.get_next_by_created_at().content, + "Sent": message.created_at.astimezone(tz) + .replace(tzinfo=None) + .strftime("%b %d, %Y %I:%M %p"), + "Analysis JSON": message.analysis_result, + } + rows.append(row) + df = pd.DataFrame.from_records(rows) + return df + + def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]: + msgs = self.order_by("-created_at").prefetch_related("attachments")[:limit] + entries = [None] * len(msgs) + for i, msg in enumerate(reversed(msgs)): + entries[i] = format_chat_entry( + role=msg.role, + content=msg.content, + images=msg.attachments.filter( + metadata__mime_type__startswith="image/" + ).values_list("url", flat=True), + ) + return entries + class Message(models.Model): conversation = models.ForeignKey( @@ -791,6 +1041,32 @@ def local_lang(self): return Truncator(self.display_content).words(30) +class MessageAttachment(models.Model): + message = models.ForeignKey( + "bots.Message", + on_delete=models.CASCADE, + related_name="attachments", + ) + 
url = CustomURLField() + metadata = models.ForeignKey( + "files.FileMetadata", + on_delete=models.SET_NULL, + null=True, + blank=True, + default=None, + related_name="message_attachments", + ) + created_at = models.DateTimeField(auto_now_add=True, db_index=True) + + class Meta: + ordering = ["created_at"] + + def __str__(self): + if self.metadata_id: + return f"{self.metadata.name} ({self.url})" + return self.url + + class FeedbackQuerySet(models.QuerySet): def to_df(self, tz=pytz.timezone(settings.TIME_ZONE)) -> "pd.DataFrame": import pandas as pd @@ -824,6 +1100,37 @@ def to_df(self, tz=pytz.timezone(settings.TIME_ZONE)) -> "pd.DataFrame": df = pd.DataFrame.from_records(rows) return df + def to_df_format( + self, tz=pytz.timezone(settings.TIME_ZONE), row_limit=10000 + ) -> "pd.DataFrame": + import pandas as pd + + qs = self.all().prefetch_related("message", "message__conversation") + rows = [] + for feedback in qs[:row_limit]: + feedback: Feedback + row = { + "Name": feedback.message.conversation.get_display_name(), + "Question (EN)": feedback.message.get_previous_by_created_at().content, + "Question Sent": feedback.message.get_previous_by_created_at() + .created_at.astimezone(tz) + .replace(tzinfo=None) + .strftime("%b %d, %Y %I:%M %p"), + "Answer (EN)": feedback.message.content, + "Answer Sent": feedback.message.created_at.astimezone(tz) + .replace(tzinfo=None) + .strftime("%b %d, %Y %I:%M %p"), + "Rating": Feedback.Rating(feedback.rating).label, + "Feedback (EN)": feedback.text_english, + "Feedback Sent": feedback.created_at.astimezone(tz) + .replace(tzinfo=None) + .strftime("%b %d, %Y %I:%M %p"), + "Question Answered": feedback.message.question_answered, + } + rows.append(row) + df = pd.DataFrame.from_records(rows) + return df + class Feedback(models.Model): message = models.ForeignKey( @@ -882,7 +1189,10 @@ class Status(models.IntegerChoices): objects = FeedbackQuerySet.as_manager() class Meta: - ordering = ("-created_at",) + indexes = [ + 
models.Index(fields=["-created_at"]), + ] + ordering = ["-created_at"] get_latest_by = "created_at" def __str__(self): @@ -989,3 +1299,221 @@ def get_value(self): def delete_secret(self): self.delete() GCPSecret(self.gcp_secret_id).delete() + + +class PublishedRunQuerySet(models.QuerySet): + def create_published_run( + self, + *, + workflow: Workflow, + published_run_id: str, + saved_run: SavedRun, + user: AppUser, + title: str, + notes: str, + visibility: PublishedRunVisibility, + ): + with transaction.atomic(): + published_run = PublishedRun( + workflow=workflow, + published_run_id=published_run_id, + created_by=user, + last_edited_by=user, + title=title, + ) + published_run.save() + published_run.add_version( + user=user, + saved_run=saved_run, + title=title, + visibility=visibility, + notes=notes, + ) + return published_run + + +class PublishedRun(models.Model): + # published_run_id was earlier SavedRun.example_id + published_run_id = models.CharField( + max_length=128, + blank=True, + ) + + saved_run = models.ForeignKey( + "bots.SavedRun", + on_delete=models.PROTECT, + related_name="published_runs", + null=True, + ) + workflow = models.IntegerField( + choices=Workflow.choices, + ) + title = models.TextField(blank=True, default="") + notes = models.TextField(blank=True, default="") + visibility = models.IntegerField( + choices=PublishedRunVisibility.choices, + default=PublishedRunVisibility.UNLISTED, + ) + is_approved_example = models.BooleanField(default=False) + + created_by = models.ForeignKey( + "app_users.AppUser", + on_delete=models.SET_NULL, # TODO: set to sentinel instead (e.g. github's ghost user) + null=True, + related_name="published_runs", + ) + last_edited_by = models.ForeignKey( + "app_users.AppUser", + on_delete=models.SET_NULL, # TODO: set to sentinel instead (e.g. 
github's ghost user) + null=True, + ) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + objects = PublishedRunQuerySet.as_manager() + + class Meta: + get_latest_by = "updated_at" + + ordering = ["-updated_at"] + unique_together = [ + ["workflow", "published_run_id"], + ] + + indexes = [ + models.Index(fields=["workflow"]), + models.Index(fields=["workflow", "created_by"]), + models.Index(fields=["workflow", "published_run_id"]), + models.Index(fields=["workflow", "visibility", "is_approved_example"]), + models.Index( + fields=[ + "workflow", + "visibility", + "is_approved_example", + "published_run_id", + ] + ), + ] + + def __str__(self): + return self.title or self.get_app_url() + + @admin.display(description="Open in Gooey") + def open_in_gooey(self): + return open_in_new_tab(self.get_app_url(), label=self.get_app_url()) + + def duplicate( + self, + *, + user: AppUser, + title: str, + notes: str, + visibility: PublishedRunVisibility, + ) -> "PublishedRun": + return PublishedRun.objects.create_published_run( + workflow=Workflow(self.workflow), + published_run_id=get_random_doc_id(), + saved_run=self.saved_run, + user=user, + title=title, + notes=notes, + visibility=visibility, + ) + + def get_app_url(self): + return Workflow(self.workflow).get_app_url( + example_id=self.published_run_id, + run_id="", + uid="", + run_slug=self.title and slugify(self.title), + ) + + def add_version( + self, + *, + user: AppUser, + saved_run: SavedRun, + visibility: PublishedRunVisibility, + title: str, + notes: str, + ): + assert saved_run.workflow == self.workflow + + with transaction.atomic(): + version = PublishedRunVersion( + published_run=self, + version_id=get_random_doc_id(), + saved_run=saved_run, + changed_by=user, + title=title, + notes=notes, + visibility=visibility, + ) + version.save() + self.update_fields_to_latest_version() + + def is_editor(self, user: AppUser): + return self.created_by == user + + def 
is_root(self): + return not self.published_run_id + + def update_fields_to_latest_version(self): + latest_version = self.versions.latest() + self.saved_run = latest_version.saved_run + self.last_edited_by = latest_version.changed_by + self.title = latest_version.title + self.notes = latest_version.notes + self.visibility = latest_version.visibility + + self.save() + + def get_run_count(self): + annotated_versions = self.versions.annotate( + children_runs_count=models.Count("children_runs") + ) + return ( + annotated_versions.aggregate(run_count=models.Sum("children_runs_count"))[ + "run_count" + ] + or 0 + ) + + +class PublishedRunVersion(models.Model): + version_id = models.CharField(max_length=128, unique=True) + + published_run = models.ForeignKey( + PublishedRun, + on_delete=models.CASCADE, + related_name="versions", + ) + saved_run = models.ForeignKey( + SavedRun, + on_delete=models.PROTECT, + related_name="published_run_versions", + ) + changed_by = models.ForeignKey( + "app_users.AppUser", + on_delete=models.SET_NULL, # TODO: set to sentinel instead (e.g. 
github's ghost user) + null=True, + ) + title = models.TextField(blank=True, default="") + notes = models.TextField(blank=True, default="") + visibility = models.IntegerField( + choices=PublishedRunVisibility.choices, + default=PublishedRunVisibility.UNLISTED, + ) + created_at = models.DateTimeField(auto_now_add=True) + + class Meta: + ordering = ["-created_at"] + get_latest_by = "created_at" + indexes = [ + models.Index(fields=["published_run", "-created_at"]), + models.Index(fields=["version_id"]), + ] + + def __str__(self): + return f"{self.published_run} - {self.version_id}" diff --git a/bots/tasks.py b/bots/tasks.py index 004a04408..5dda17e90 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -1,12 +1,25 @@ import json from celery import shared_task +from django.db.models import QuerySet from app_users.models import AppUser -from bots.models import Message, CHATML_ROLE_ASSISSTANT, BotIntegration +from bots.models import ( + Message, + CHATML_ROLE_ASSISSTANT, + BotIntegration, + Conversation, + Platform, +) +from daras_ai_v2.facebook_bots import WhatsappBot from daras_ai_v2.functional import flatten, map_parallel -from daras_ai_v2.slack_bot import fetch_channel_members, create_personal_channel +from daras_ai_v2.slack_bot import ( + fetch_channel_members, + create_personal_channel, + SlackBot, +) from daras_ai_v2.vector_search import references_as_prompt +from recipes.VideoBots import ReplyButton @shared_task @@ -54,3 +67,85 @@ def msg_analysis(msg_id: int): Message.objects.filter(id=msg_id).update( analysis_result=json.loads(flatten(sr.state["output_text"].values())[0]), ) + + +def send_broadcast_msgs_chunked( + *, + text: str, + audio: str, + video: str, + documents: list[str], + buttons: list[ReplyButton] = None, + convo_qs: QuerySet[Conversation], + bi: BotIntegration, +): + convo_ids = list(convo_qs.values_list("id", flat=True)) + for i in range(0, len(convo_ids), 100): + send_broadcast_msg.delay( + text=text, + audio=audio, + video=video, + buttons=buttons, 
+ documents=documents, + bi_id=bi.id, + convo_ids=convo_ids[i : i + 100], + ) + + +@shared_task +def send_broadcast_msg( + *, + text: str | None, + audio: str = None, + video: str = None, + buttons: list[ReplyButton] = None, + documents: list[str] = None, + bi_id: int, + convo_ids: list[int], +): + bi = BotIntegration.objects.get(id=bi_id) + convos = Conversation.objects.filter(id__in=convo_ids) + for convo in convos: + match bi.platform: + case Platform.WHATSAPP: + msg_id = WhatsappBot.send_msg_to( + text=text, + audio=audio, + video=video, + buttons=buttons, + documents=documents, + bot_number=bi.wa_phone_number_id, + user_number=convo.wa_phone_number.as_e164, + ) + case Platform.SLACK: + msg_id = SlackBot.send_msg_to( + text=text, + audio=audio, + video=video, + buttons=buttons, + documents=documents, + channel=convo.slack_channel_id, + channel_is_personal=convo.slack_channel_is_personal, + username=bi.name, + token=bi.slack_access_token, + )[0] + case _: + raise NotImplementedError( + f"Platform {bi.platform} doesn't support broadcasts yet" + ) + # save_broadcast_message(convo, text, msg_id) + + +## Disabled for now to prevent messing up the chat history +# def save_broadcast_message(convo: Conversation, text: str, msg_id: str | None = None): +# message = Message( +# conversation=convo, +# role=CHATML_ROLE_ASSISTANT, +# content=text, +# display_content=text, +# saved_run=None, +# ) +# if msg_id: +# message.platform_msg_id = msg_id +# message.save() +# return message diff --git a/daras_ai_v2/all_pages.py b/daras_ai_v2/all_pages.py index 68231daee..096150662 100644 --- a/daras_ai_v2/all_pages.py +++ b/daras_ai_v2/all_pages.py @@ -3,6 +3,7 @@ from bots.models import Workflow from daras_ai_v2.base import BasePage +from recipes.BulkEval import BulkEvalPage from recipes.BulkRunner import BulkRunnerPage from recipes.ChyronPlant import ChyronPlantPage from recipes.CompareLLM import CompareLLMPage @@ -33,9 +34,10 @@ from recipes.VideoBots import VideoBotsPage from 
recipes.asr import AsrPage from recipes.embeddings_page import EmbeddingsPage +from recipes.VideoBotsStats import VideoBotsStatsPage # note: the ordering here matters! -all_home_pages_by_category = { +all_home_pages_by_category: dict[str, list[typing.Type[BasePage]]] = { "Featured": [ VideoBotsPage, DeforumSDPage, @@ -49,6 +51,7 @@ ], "LLMs, RAG, & Synthetic Data": [ BulkRunnerPage, + BulkEvalPage, DocExtractPage, CompareLLMPage, DocSearchPage, @@ -74,10 +77,15 @@ ], } -all_home_pages = [ +all_home_pages: list[typing.Type[BasePage]] = [ page for page_group in all_home_pages_by_category.values() for page in page_group ] +# hidden UI pages (that don't have api and don't show up in /explore) +all_hidden_pages = [ + VideoBotsStatsPage, +] + # exposed as API all_api_pages = all_home_pages.copy() + [ ChyronPlantPage, @@ -97,8 +105,10 @@ def normalize_slug(page_slug): page_slug_map: dict[str, typing.Type[BasePage]] = { - normalize_slug(slug): page for page in all_api_pages for slug in page.slug_versions -} + normalize_slug(slug): page + for page in (all_api_pages + all_hidden_pages) + for slug in page.slug_versions +} | {str(page.workflow.value): page for page in (all_api_pages + all_hidden_pages)} workflow_map: dict[Workflow, typing.Type[BasePage]] = { page.workflow: page for page in all_api_pages diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py index da9a25a5e..8f5e1cf7e 100644 --- a/daras_ai_v2/api_examples_widget.py +++ b/daras_ai_v2/api_examples_widget.py @@ -93,19 +93,20 @@ def api_example_generator( """ 1. Generate an api key [below๐Ÿ‘‡](#api-keys) -2. Install [curl](https://everything.curl.dev/get) & add the `GOOEY_API_KEY` to your environment variables. -Never store the api key [in your code](https://12factor.net/config). +2. Install [curl](https://everything.curl.dev/get) & add the `GOOEY_API_KEY` to your environment variables. +Never store the api key [in your code](https://12factor.net/config). 
```bash export GOOEY_API_KEY=sk-xxxx ``` -3. Run the following `curl` command in your terminal. +3. Run the following `curl` command in your terminal. If you encounter any issues, write to us at support@gooey.ai and make sure to include the full curl command and the error message. ```bash %s ``` """ - % curl_code.strip() + % curl_code.strip(), + unsafe_allow_html=True, ) with python: @@ -157,8 +158,8 @@ def api_example_generator( ) if as_async: py_code += r""" -from time import sleep - +from time import sleep + status_url = response.headers["Location"] while True: response = requests.get(status_url, headers={"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"]}) @@ -188,20 +189,21 @@ def api_example_generator( rf""" 1. Generate an api key [below๐Ÿ‘‡](#api-keys) -2. Install [requests](https://requests.readthedocs.io/en/latest/) & add the `GOOEY_API_KEY` to your environment variables. -Never store the api key [in your code](https://12factor.net/config). +2. Install [requests](https://requests.readthedocs.io/en/latest/) & add the `GOOEY_API_KEY` to your environment variables. +Never store the api key [in your code](https://12factor.net/config). ```bash $ python3 -m pip install requests $ export GOOEY_API_KEY=sk-xxxx ``` - -3. Use this sample code to call the API. + +3. Use this sample code to call the API. If you encounter any issues, write to us at support@gooey.ai and make sure to include the full code snippet and the error message. ```python %s ``` """ - % py_code + % py_code, + unsafe_allow_html=True, ) with js: @@ -276,7 +278,7 @@ def api_example_generator( if (!response.ok) { throw new Error(response.status); } - + const result = await response.json(); if (result.status === "completed") { console.log(response.status, result); @@ -302,18 +304,19 @@ def api_example_generator( r""" 1. Generate an api key [below๐Ÿ‘‡](#api-keys) -2. Install [node-fetch](https://www.npmjs.com/package/node-fetch) & add the `GOOEY_API_KEY` to your environment variables. 
-Never store the api key [in your code](https://12factor.net/config) and don't use direcly in the browser. +2. Install [node-fetch](https://www.npmjs.com/package/node-fetch) & add the `GOOEY_API_KEY` to your environment variables. +Never store the api key [in your code](https://12factor.net/config) and don't use direcly in the browser. ```bash $ npm install node-fetch $ export GOOEY_API_KEY=sk-xxxx ``` -3. Use this sample code to call the API. +3. Use this sample code to call the API. If you encounter any issues, write to us at support@gooey.ai and make sure to include the full code snippet and the error message. ```js %s ``` """ - % js_code + % js_code, + unsafe_allow_html=True, ) diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 0389a14b5..d8a922320 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -3,6 +3,7 @@ import subprocess import tempfile from enum import Enum +from time import sleep import langcodes import requests @@ -12,18 +13,18 @@ import gooey_ui as st from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri +from daras_ai_v2 import settings +from daras_ai_v2.exceptions import raise_for_status +from daras_ai_v2.functional import map_parallel from daras_ai_v2.gdrive_downloader import ( is_gdrive_url, gdrive_download, gdrive_metadata, url_to_gdrive_file_id, ) -from daras_ai_v2 import settings -from daras_ai_v2 import google_utils -from daras_ai_v2.functional import map_parallel +from daras_ai_v2.google_utils import get_google_auth_session from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.redis_cache import redis_cache_decorator -from time import sleep SHORT_FILE_CUTOFF = 5 * 1024 * 1024 # 1 MB @@ -42,9 +43,15 @@ AZURE_SUPPORTED = {"af-ZA", "am-ET", "ar-AE", "ar-BH", "ar-DZ", "ar-EG", "ar-IL", "ar-IQ", "ar-JO", "ar-KW", "ar-LB", "ar-LY", "ar-MA", "ar-OM", "ar-PS", "ar-QA", "ar-SA", "ar-SY", "ar-TN", "ar-YE", "az-AZ", "bg-BG", "bn-IN", "bs-BA", "ca-ES", "cs-CZ", "cy-GB", "da-DK", "de-AT", "de-CH", 
"de-DE", "el-GR", "en-AU", "en-CA", "en-GB", "en-GH", "en-HK", "en-IE", "en-IN", "en-KE", "en-NG", "en-NZ", "en-PH", "en-SG", "en-TZ", "en-US", "en-ZA", "es-AR", "es-BO", "es-CL", "es-CO", "es-CR", "es-CU", "es-DO", "es-EC", "es-ES", "es-GQ", "es-GT", "es-HN", "es-MX", "es-NI", "es-PA", "es-PE", "es-PR", "es-PY", "es-SV", "es-US", "es-UY", "es-VE", "et-EE", "eu-ES", "fa-IR", "fi-FI", "fil-PH", "fr-BE", "fr-CA", "fr-CH", "fr-FR", "ga-IE", "gl-ES", "gu-IN", "he-IL", "hi-IN", "hr-HR", "hu-HU", "hy-AM", "id-ID", "is-IS", "it-CH", "it-IT", "ja-JP", "jv-ID", "ka-GE", "kk-KZ", "km-KH", "kn-IN", "ko-KR", "lo-LA", "lt-LT", "lv-LV", "mk-MK", "ml-IN", "mn-MN", "mr-IN", "ms-MY", "mt-MT", "my-MM", "nb-NO", "ne-NP", "nl-BE", "nl-NL", "pa-IN", "pl-PL", "ps-AF", "pt-BR", "pt-PT", "ro-RO", "ru-RU", "si-LK", "sk-SK", "sl-SI", "so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", "zh-CN-sichuan", "zh-HK", "zh-TW", "zu-ZA"} # fmt: skip MAX_POLLS = 100 +# https://deepgram.com/product/languages for the "general" model: +# DEEPGRAM_SUPPORTED = {"nl","en","en-AU","en-US","en-GB","en-NZ","en-IN","fr","fr-CA","de","hi","hi-Latn","id","it","ja","ko","cmn-Hans-CN","cmn-Hant-TW","no","pl","pt","pt-PT","pt-BR","ru","es","es-419","sv","tr","uk"} # fmt: skip +# but we only have the Nova tier so these are our languages (https://developers.deepgram.com/docs/models-languages-overview): +DEEPGRAM_SUPPORTED = {"en", "en-US", "en-AU", "en-GB", "en-NZ", "en-IN", "es", "es-419"} # fmt: skip + class AsrModels(Enum): whisper_large_v2 = "Whisper Large v2 (openai)" + whisper_large_v3 = "Whisper Large v3 (openai)" whisper_hindi_large_v2 = "Whisper Hindi Large v2 (Bhashini)" whisper_telugu_large_v2 = "Whisper Telugu Large v2 (Bhashini)" nemo_english = "Conformer English (ai4bharat.org)" @@ -55,8 +62,14 @@ class AsrModels(Enum): azure = "Azure Speech" seamless_m4t = "Seamless M4T (Facebook 
Research)" + def supports_auto_detect(self) -> bool: + return self not in { + self.azure, + } + asr_model_ids = { + AsrModels.whisper_large_v3: "vaibhavs10/incredibly-fast-whisper:37dfc0d6a7eb43ff84e230f74a24dab84e6bb7756c9b457dbdcceca3de7a4a04", AsrModels.whisper_large_v2: "openai/whisper-large-v2", AsrModels.whisper_hindi_large_v2: "vasista22/whisper-hindi-large-v2", AsrModels.whisper_telugu_large_v2: "vasista22/whisper-telugu-large-v2", @@ -75,9 +88,10 @@ class AsrModels(Enum): } asr_supported_languages = { + AsrModels.whisper_large_v3: WHISPER_SUPPORTED, AsrModels.whisper_large_v2: WHISPER_SUPPORTED, AsrModels.usm: CHIRP_SUPPORTED, - AsrModels.deepgram: WHISPER_SUPPORTED, + AsrModels.deepgram: DEEPGRAM_SUPPORTED, AsrModels.seamless_m4t: SEAMLESS_SUPPORTED, AsrModels.azure: AZURE_SUPPORTED, } @@ -106,6 +120,8 @@ def google_translate_language_selector( ###### Google Translate (*optional*) """, key="google_translate_target", + allow_none=True, + **kwargs, ): """ Streamlit widget for selecting a language for Google Translate. @@ -115,12 +131,14 @@ def google_translate_language_selector( """ languages = google_translate_languages() options = list(languages.keys()) - options.insert(0, None) - st.selectbox( + if allow_none: + options.insert(0, None) + return st.selectbox( label=label, key=key, format_func=lambda k: languages[k] if k else "โ€”โ€”โ€”", options=options, + **kwargs, ) @@ -145,6 +163,34 @@ def google_translate_languages() -> dict[str, str]: } +@redis_cache_decorator +def google_translate_input_languages() -> dict[str, str]: + """ + Get list of supported languages for Google Translate. + :return: Dictionary of language codes and display names. 
+ """ + from google.cloud import translate + + _, project = get_google_auth_session() + parent = f"projects/{project}/locations/global" + client = translate.TranslationServiceClient() + supported_languages = client.get_supported_languages( + parent=parent, display_language_code="en" + ) + return { + lang.language_code: lang.display_name + for lang in supported_languages.languages + if lang.support_source + } + + +def get_language_in_collection(langcode: str, languages): + for lang in languages: + if langcodes.get(lang).language == langcodes.get(langcode).language: + return langcode + return None + + def asr_language_selector( selected_model: AsrModels, label="##### Spoken Language", @@ -156,7 +202,9 @@ def asr_language_selector( st.session_state[key] = forced_lang return forced_lang - options = [None, *asr_supported_languages.get(selected_model, [])] + options = list(asr_supported_languages.get(selected_model, [])) + if selected_model and selected_model.supports_auto_detect(): + options.insert(0, None) # handle non-canonical language codes old_val = st.session_state.get(key) @@ -195,6 +243,19 @@ def run_google_translate( """ from google.cloud import translate_v2 as translate + # convert to BCP-47 format (google handles consistent language codes but sometimes gets confused by a mix of iso2 and iso3 which we have) + if source_language: + source_language = langcodes.Language.get(source_language).to_tag() + source_language = get_language_in_collection( + source_language, google_translate_input_languages().keys() + ) # this will default to autodetect if language is not found as supported + target_language = langcodes.Language.get(target_language).to_tag() + target_language: str | None = get_language_in_collection( + target_language, google_translate_languages().keys() + ) + if not target_language: + raise ValueError(f"Unsupported target language: {target_language!r}") + # if the language supports transliteration, we should check if the script is Latin if source_language 
and source_language not in TRANSLITERATION_SUPPORTED: language_codes = [source_language] * len(texts) @@ -225,19 +286,20 @@ def _translate_text( ) # prevent incorrect API calls - if source_language == target_language or not text: + if not text or source_language == target_language or source_language == "und": return text if source_language == "wo-SN" or target_language == "wo-SN": return _MinT_translate_one_text(text, source_language, target_language) config = { - "source_language_code": source_language, "target_language_code": target_language, "contents": text, "mime_type": "text/plain", "transliteration_config": {"enable_transliteration": enable_transliteration}, } + if source_language != "auto": + config["source_language_code"] = source_language # glossary does not work with transliteration if glossary_url and not enable_transliteration: @@ -260,7 +322,7 @@ def _translate_text( f"https://translation.googleapis.com/v3/projects/{project}/locations/{location}:translateText", json=config, ) - res.raise_for_status() + raise_for_status(res) data = res.json() try: result = data["glossaryTranslations"][0]["translatedText"] @@ -278,7 +340,7 @@ def _MinT_translate_one_text( f"https://translate.wmcloud.org/api/translate/{source_language}/{target_language}", json={"text": text}, ) - res.raise_for_status() + raise_for_status(res) # e.g. 
{"model":"IndicTrans2_indec_en","sourcelanguage":"hi","targetlanguage":"en","translation":"hello","translationtime":0.8} tanslation = res.json() @@ -327,6 +389,19 @@ def run_asr( if selected_model == AsrModels.azure: return azure_asr(audio_url, language) + elif selected_model == AsrModels.whisper_large_v3: + import replicate + + config = { + "audio": audio_url, + "return_timestamps": output_format != AsrOutputFormat.text, + } + if language: + config["language"] = language + data = replicate.run( + asr_model_ids[AsrModels.whisper_large_v3], + input=config, + ) elif selected_model == AsrModels.deepgram: r = requests.post( "https://api.deepgram.com/v1/listen", @@ -344,7 +419,7 @@ def run_asr( "url": audio_url, }, ) - r.raise_for_status() + raise_for_status(r) data = r.json() result = data["results"]["channels"][0]["alternatives"][0] chunk = None @@ -532,19 +607,6 @@ def azure_asr(audio_url: str, language: str): }, "locale": language or "en-US", } - if not language: - payload["properties"]["languageIdentification"] = { - "candidateLocales": [ - "en-US", - "en-IN", - "hi-IN", - "te-IN", - "ta-IN", - "kn-IN", - "es-ES", - "de-DE", - ] - } r = requests.post( str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"), headers={ @@ -553,7 +615,7 @@ def azure_asr(audio_url: str, language: str): }, json=payload, ) - r.raise_for_status() + raise_for_status(r) uri = r.json()["self"] # poll for results @@ -565,7 +627,7 @@ def azure_asr(audio_url: str, language: str): }, ) if not r.ok or not r.json()["status"] == "Succeeded": - sleep(1) + sleep(5) continue r = requests.get( r.json()["links"]["files"], @@ -573,7 +635,7 @@ def azure_asr(audio_url: str, language: str): "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, }, ) - r.raise_for_status() + raise_for_status(r) transcriptions = [] for value in r.json()["values"]: if value["kind"] != "Transcription": @@ -582,8 +644,9 @@ def azure_asr(audio_url: str, language: str): value["links"]["contentUrl"], 
headers={"Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY}, ) - r.raise_for_status() - transcriptions += [r.json()["combinedRecognizedPhrases"][0]["display"]] + raise_for_status(r) + combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}] + transcriptions += [combined_phrases[0].get("display", "")] return "\n".join(transcriptions) assert False, "Max polls exceeded, Azure speech did not yield a response" @@ -626,7 +689,7 @@ def download_youtube_to_wav(youtube_url: str) -> tuple[str, int]: def audio_url_to_wav(audio_url: str) -> tuple[str, int]: r = requests.get(audio_url) - r.raise_for_status() + raise_for_status(r) wavdata, size = audio_bytes_to_wav(r.content) if not wavdata: diff --git a/daras_ai_v2/azure_doc_extract.py b/daras_ai_v2/azure_doc_extract.py index 484fc86ce..b14179dba 100644 --- a/daras_ai_v2/azure_doc_extract.py +++ b/daras_ai_v2/azure_doc_extract.py @@ -9,6 +9,7 @@ from jinja2.lexer import whitespace_re from daras_ai_v2 import settings +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.redis_cache import redis_cache_decorator from daras_ai_v2.text_splitter import default_length_function @@ -24,7 +25,21 @@ def azure_doc_extract_pages(pdf_url: str, model_id: str = "prebuilt-layout"): @redis_cache_decorator -def azure_form_recognizer(pdf_url: str, model_id: str): +def azure_form_recognizer_models() -> dict[str, str]: + r = requests.get( + str( + furl(settings.AZURE_FORM_RECOGNIZER_ENDPOINT) + / "formrecognizer/documentModels" + ), + params={"api-version": "2023-07-31"}, + headers=auth_headers, + ) + raise_for_status(r) + return {value["modelId"]: value["description"] for value in r.json()["value"]} + + +@redis_cache_decorator +def azure_form_recognizer(url: str, model_id: str): r = requests.post( str( furl(settings.AZURE_FORM_RECOGNIZER_ENDPOINT) @@ -32,13 +47,13 @@ def azure_form_recognizer(pdf_url: str, model_id: str): ), params={"api-version": "2023-07-31"}, headers=auth_headers, - json={"urlSource": pdf_url}, + 
json={"urlSource": url}, ) - r.raise_for_status() + raise_for_status(r) location = r.headers["Operation-Location"] while True: r = requests.get(location, headers=auth_headers) - r.raise_for_status() + raise_for_status(r) r_json = r.json() match r_json.get("status"): case "succeeded": diff --git a/daras_ai_v2/azure_image_moderation.py b/daras_ai_v2/azure_image_moderation.py index f1f52ebfa..30da7a561 100644 --- a/daras_ai_v2/azure_image_moderation.py +++ b/daras_ai_v2/azure_image_moderation.py @@ -4,6 +4,7 @@ import requests from daras_ai_v2 import settings +from daras_ai_v2.exceptions import raise_for_status def get_auth_headers(): @@ -21,7 +22,7 @@ def run_moderator(image_url: str, cache: bool) -> dict[str, Any]: headers=get_auth_headers(), json={"DataRepresentation": "URL", "Value": image_url}, ) - r.raise_for_status() + raise_for_status(r) return r.json() diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 8a10f06de..f05a0c71c 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1,19 +1,22 @@ import datetime import html import inspect +import math import typing import urllib import urllib.parse import uuid from copy import deepcopy +from enum import Enum +from itertools import pairwise from random import Random from time import sleep from types import SimpleNamespace -import math import requests import sentry_sdk from django.utils import timezone +from django.utils.text import slugify from fastapi import HTTPException from firebase_admin import auth from furl import furl @@ -25,10 +28,16 @@ import gooey_ui as st from app_users.models import AppUser, AppUserTransaction -from bots.models import SavedRun, Workflow -from daras_ai.image_input import truncate_text_words +from bots.models import ( + SavedRun, + PublishedRun, + PublishedRunVersion, + PublishedRunVisibility, + Workflow, +) from daras_ai_v2 import settings from daras_ai_v2.api_examples_widget import api_example_generator +from daras_ai_v2.breadcrumbs import render_breadcrumbs, 
get_title_breadcrumbs from daras_ai_v2.copy_to_clipboard_button_widget import ( copy_to_clipboard_button, ) @@ -47,19 +56,21 @@ ) from daras_ai_v2.query_params_util import ( extract_query_params, - EXAMPLE_ID_QUERY_PARAM, - RUN_ID_QUERY_PARAM, - USER_ID_QUERY_PARAM, ) from daras_ai_v2.send_email import send_reported_run_email from daras_ai_v2.tabs_widget import MenuTabs -from daras_ai_v2.user_date_widgets import render_js_dynamic_dates, js_dynamic_date +from daras_ai_v2.user_date_widgets import ( + render_js_dynamic_dates, + re_render_js_dynamic_dates, + js_dynamic_date, +) from gooey_ui import realtime_clear_subs +from gooey_ui.components.modal import Modal from gooey_ui.pubsub import realtime_pull DEFAULT_META_IMG = ( # Small - "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/optimized%20hp%20gif.gif" + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/b0f328d0-93f7-11ee-bd89-02420a0001cc/Main.jpg.png" # "https://storage.googleapis.com/dara-c1b52.appspot.com/meta_tag_default_img.jpg" # Big # "https://storage.googleapis.com/dara-c1b52.appspot.com/meta_tag_gif.gif" @@ -71,10 +82,14 @@ SUBMIT_AFTER_LOGIN_Q = "submitafterlogin" -class StateKeys: - page_title = "__title" - page_notes = "__notes" +class RecipeRunState(Enum): + idle = 1 + running = 2 + completed = 3 + failed = 4 + +class StateKeys: created_at = "created_at" updated_at = "updated_at" @@ -93,6 +108,8 @@ class BasePage: sane_defaults: dict = {} + explore_image: str = None + RequestModel: typing.Type[BaseModel] ResponseModel: typing.Type[BaseModel] @@ -120,11 +137,17 @@ def app_url( query_params = cls.clean_query_params( example_id=example_id, run_id=run_id, uid=uid ) | (query_params or {}) - f = furl(settings.APP_BASE_URL, query_params=query_params) / ( - cls.slug_versions[-1] + "/" + f = ( + furl(settings.APP_BASE_URL, query_params=query_params) + / cls.slug_versions[-1] ) + if example_id := query_params.get("example_id"): + pr = 
cls.get_published_run(published_run_id=example_id) + if pr and pr.title: + f /= slugify(pr.title) if tab_name: - f /= tab_name + "/" + f /= tab_name + f /= "/" # keep trailing slash return str(f) @classmethod @@ -148,20 +171,35 @@ def api_url(self, example_id=None, run_id=None, uid=None) -> furl: def endpoint(self) -> str: return f"/v2/{self.slug_versions[0]}/" - def render(self): + def get_tab_url(self, tab: str) -> str: + example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + return self.app_url( + example_id=example_id, + run_id=run_id, + uid=uid, + tab_name=MenuTabs.paths[tab], + ) + + def setup_render(self): with sentry_sdk.configure_scope() as scope: scope.set_extra("base_url", self.app_url()) scope.set_transaction_name( "/" + self.slug_versions[0], source=TRANSACTION_SOURCE_ROUTE ) - example_id, run_id, uid = extract_query_params(gooey_get_query_params()) - if st.session_state.get(StateKeys.run_status): - channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" - output = realtime_pull([channel])[0] - if output: - st.session_state.update(output) - if not st.session_state.get(StateKeys.run_status): + def refresh_state(self): + _, run_id, uid = extract_query_params(gooey_get_query_params()) + channel = self.realtime_channel_name(run_id, uid) + output = realtime_pull([channel])[0] + if output: + st.session_state.update(output) + + def render(self): + self.setup_render() + + if self.get_run_state(st.session_state) == RecipeRunState.running: + self.refresh_state() + else: realtime_clear_subs() self._user_disabled_check() @@ -171,86 +209,470 @@ def render(self): self.render_report_form() return - st.session_state.setdefault(StateKeys.page_title, self.title) - st.session_state.setdefault( - StateKeys.page_notes, self.preview_description(st.session_state) - ) + self._render_header() - self._render_page_title_with_breadcrumbs(example_id, run_id, uid) - st.write(st.session_state.get(StateKeys.page_notes)) + 
self._render_tab_menu(selected_tab=self.tab) + with st.nav_tab_content(): + self.render_selected_tab(self.tab) - try: - selected_tab = MenuTabs.paths_reverse[self.tab] - except KeyError: - st.error(f"## 404 - Tab {self.tab!r} Not found") - return + def _render_tab_menu(self, selected_tab: str): + assert selected_tab in MenuTabs.paths with st.nav_tabs(): - tab_names = self.get_tabs() - for name in tab_names: - url = self.app_url( - *extract_query_params(gooey_get_query_params()), - tab_name=MenuTabs.paths[name], - ) + for name in self.get_tabs(): + url = self.get_tab_url(name) with st.nav_item(url, active=name == selected_tab): st.html(name) - with st.nav_tab_content(): - self.render_selected_tab(selected_tab) - def _render_page_title_with_breadcrumbs( - self, example_id: str, run_id: str, uid: str + def _render_header(self): + current_run = self.get_current_sr() + published_run = self.get_current_published_run() + is_root_example = ( + published_run + and published_run.is_root() + and published_run.saved_run == current_run + ) + tbreadcrumbs = get_title_breadcrumbs( + self, current_run, published_run, tab=self.tab + ) + + with st.div(className="d-flex justify-content-between mt-4"): + with st.div(className="d-lg-flex d-block align-items-center"): + if not tbreadcrumbs.has_breadcrumbs() and not self.run_user: + self._render_title(tbreadcrumbs.h1_title) + + if tbreadcrumbs: + with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): + render_breadcrumbs(tbreadcrumbs) + + author = self.run_user or current_run.get_creator() + if not is_root_example: + self.render_author( + author, + show_as_link=self.is_current_user_admin(), + ) + + with st.div(className="d-flex align-items-center"): + can_user_edit_run = self.is_current_user_admin() or ( + self.request + and self.request.user + and current_run.uid == self.request.user.uid + ) + has_unpublished_changes = ( + published_run + and published_run.saved_run != current_run + and self.request + and self.request.user + 
and published_run.created_by == self.request.user + ) + + if can_user_edit_run and has_unpublished_changes: + self._render_unpublished_changes_indicator() + + with st.div(className="d-flex align-items-start right-action-icons"): + st.html( + """ + + """ + ) + + if published_run and can_user_edit_run: + self._render_published_run_buttons( + current_run=current_run, + published_run=published_run, + ) + + self._render_social_buttons(show_button_text=not can_user_edit_run) + + with st.div(): + if tbreadcrumbs.has_breadcrumbs() or self.run_user: + # only render title here if the above row was not empty + self._render_title(tbreadcrumbs.h1_title) + if published_run and published_run.notes: + st.write(published_run.notes) + elif is_root_example: + st.write(self.preview_description(current_run.to_dict())) + + def _render_title(self, title: str): + st.write(f"# {title}") + + def _render_unpublished_changes_indicator(self): + with st.div( + className="d-none d-lg-flex h-100 align-items-center text-muted ms-2" + ): + with st.tag("span", className="d-inline-block"): + st.html("Unpublished changes") + + def _render_social_buttons(self, show_button_text: bool = False): + button_text = ( + ' Copy Link' + if show_button_text + else "" + ) + + copy_to_clipboard_button( + f'{button_text}', + value=self._get_current_app_url(), + type="secondary", + className="mb-0 ms-lg-2", + ) + + def _render_published_run_buttons( + self, + *, + current_run: SavedRun, + published_run: PublishedRun, ): - if example_id or run_id: - # the title on the saved root / the hardcoded title - recipe_title = ( - self.recipe_doc_sr().to_dict().get(StateKeys.page_title) or self.title - ) + is_update_mode = ( + self.is_current_user_admin() + or published_run.created_by == self.request.user + ) - # the user saved title for the current run (if its not the same as the recipe title) - current_title = st.session_state.get(StateKeys.page_title) - if current_title == recipe_title: - current_title = "" + with 
st.div(className="d-flex justify-content-end"): + st.html( + """ + + """ + ) - # prefer the prompt as h1 title for runs, but not for examples - prompt_title = truncate_text_words( - self.preview_input(st.session_state) or "", maxlen=60 + pressed_options = is_update_mode and st.button( + '', + className="mb-0 ms-lg-2", + type="tertiary", ) - if run_id: - h1_title = prompt_title or current_title or recipe_title + options_modal = Modal("Options", key="published-run-options-modal") + if pressed_options: + options_modal.open() + if options_modal.is_open(): + with options_modal.container(style={"min-width": "min(300px, 100vw)"}): + self._render_options_modal( + current_run=current_run, + published_run=published_run, + modal=options_modal, + ) + + save_icon = '' + if is_update_mode: + save_text = "Update" else: - h1_title = current_title or prompt_title or recipe_title - - # render recipe title if it doesn't clash with the h1 title - render_item1 = recipe_title and recipe_title != h1_title - # render current title if it doesn't clash with the h1 title - render_item2 = current_title and current_title != h1_title - if render_item1 or render_item2: # avoids empty space - with st.breadcrumbs(className="mt-4"): - if render_item1: - st.breadcrumb_item( - recipe_title, - link_to=self.app_url(), - className="text-muted", - ) - if render_item2: - current_sr = self.get_sr_from_query_params( - example_id, run_id, uid - ) - st.breadcrumb_item( - current_title, - link_to=current_sr.parent.get_app_url() - if current_sr.parent_id - else None, - ) - st.write(f"# {h1_title}") + save_text = "Save" + pressed_save = st.button( + f'{save_icon} {save_text}', + className="mb-0 ms-lg-2 px-lg-4", + type="primary", + ) + publish_modal = Modal("Publish to", key="publish-modal") + if pressed_save: + publish_modal.open() + if publish_modal.is_open(): + with publish_modal.container(style={"min-width": "min(500px, 100vw)"}): + self._render_publish_modal( + current_run=current_run, + 
published_run=published_run, + modal=publish_modal, + is_update_mode=is_update_mode, + ) + + def _render_publish_modal( + self, + *, + current_run: SavedRun, + published_run: PublishedRun, + modal: Modal, + is_update_mode: bool = False, + ): + if published_run.is_root() and self.is_current_user_admin(): + with st.div(className="text-danger"): + st.write( + "###### You're about to update the root workflow as an admin. " + ) + st.html( + 'If you want to create a new example, press and "Duplicate" instead.' + ) + published_run_visibility = PublishedRunVisibility.PUBLIC else: - with st.link( - to=self.app_url(), className="text-decoration-none", target="_blank" + with st.div(className="visibility-radio"): + options = { + str(enum.value): enum.help_text() for enum in PublishedRunVisibility + } + published_run_visibility = PublishedRunVisibility( + int( + st.radio( + "", + options=options, + format_func=options.__getitem__, + value=str(published_run.visibility), + ) + ) + ) + st.radio( + "", + options=[ + 'Anyone at my org (coming soon)' + ], + disabled=True, + checked_by_default=False, + ) + + with st.div(className="mt-4"): + if is_update_mode: + title = published_run.title or self.title + else: + recipe_title = self.get_root_published_run().title or self.title + if self.request.user.display_name: + username = self.request.user.display_name + "'s" + elif self.request.user.email: + username = self.request.user.email.split("@")[0] + "'s" + else: + username = "My" + title = f"{username} {recipe_title}" + published_run_title = st.text_input( + "##### Title", + key="published_run_title", + value=title, + ) + published_run_notes = st.text_area( + "##### Notes", + key="published_run_notes", + value=( + published_run.notes + or self.preview_description(st.session_state) + or "" + ), + ) + + with st.div(className="mt-4 d-flex justify-content-center"): + pressed_save = st.button( + f' Save', + className="px-4", + type="primary", + ) + + self._render_admin_options(current_run, 
published_run) + + if not pressed_save: + return + + is_root_published_run = is_update_mode and published_run.is_root() + if not is_root_published_run: + try: + self._validate_published_run_title(published_run_title) + except TitleValidationError as e: + st.error(str(e)) + return + + if is_update_mode: + updates = dict( + saved_run=current_run, + title=published_run_title.strip(), + notes=published_run_notes.strip(), + visibility=published_run_visibility, + ) + if not self._has_published_run_changed( + published_run=published_run, **updates ): - st.write(f"# {self.get_recipe_title(st.session_state)}") + st.error("No changes to publish", icon="โš ๏ธ") + return + published_run.add_version(user=self.request.user, **updates) + else: + published_run = self.create_published_run( + published_run_id=get_random_doc_id(), + saved_run=current_run, + user=self.request.user, + title=published_run_title.strip(), + notes=published_run_notes.strip(), + visibility=published_run_visibility, + ) + force_redirect(published_run.get_app_url()) + + def _validate_published_run_title(self, title: str): + if slugify(title) in settings.DISALLOWED_TITLE_SLUGS: + raise TitleValidationError( + "This title is not allowed. Please choose a different title." + ) + elif title.strip() == self.get_recipe_title(): + raise TitleValidationError( + "Please choose a different title for your published run." 
+ ) + elif title.strip() == "": + raise TitleValidationError("Title cannot be empty.") + + def _has_published_run_changed( + self, + *, + published_run: PublishedRun, + saved_run: SavedRun, + title: str, + notes: str, + visibility: PublishedRunVisibility, + ): + return ( + published_run.title != title + or published_run.notes != notes + or published_run.visibility != visibility + or published_run.saved_run != saved_run + ) + + def _render_options_modal( + self, + *, + current_run: SavedRun, + published_run: PublishedRun, + modal: Modal, + ): + is_latest_version = published_run.saved_run == current_run + + with st.div(className="mt-4"): + duplicate_button = None + save_as_new_button = None + duplicate_icon = save_as_new_icon = '' + if is_latest_version: + duplicate_button = st.button( + f"{duplicate_icon} Duplicate", className="w-100" + ) + else: + save_as_new_button = st.button( + f"{save_as_new_icon} Save as New", className="w-100" + ) + delete_button = not published_run.is_root() and st.button( + f' Delete', + className="w-100 text-danger", + ) + + if duplicate_button: + duplicate_pr = self.duplicate_published_run( + published_run, + title=f"{published_run.title} (Copy)", + notes=published_run.notes, + visibility=PublishedRunVisibility(PublishedRunVisibility.UNLISTED), + ) + raise QueryParamsRedirectException( + query_params=dict(example_id=duplicate_pr.published_run_id), + ) + + if save_as_new_button: + new_pr = self.create_published_run( + published_run_id=get_random_doc_id(), + saved_run=current_run, + user=self.request.user, + title=f"{published_run.title} (Copy)", + notes=published_run.notes, + visibility=PublishedRunVisibility(PublishedRunVisibility.UNLISTED), + ) + raise QueryParamsRedirectException( + query_params=dict(example_id=new_pr.published_run_id) + ) + + with st.div(className="mt-4"): + st.write("#### Version History", className="mb-4") + self._render_version_history() + + confirm_delete_modal = Modal("Confirm Delete", key="confirm-delete-modal") 
+ if delete_button: + confirm_delete_modal.open() + if confirm_delete_modal.is_open(): + modal.empty() + with confirm_delete_modal.container(): + self._render_confirm_delete_modal( + published_run=published_run, + modal=confirm_delete_modal, + ) + + def _render_confirm_delete_modal( + self, + *, + published_run: PublishedRun, + modal: Modal, + ): + st.write( + "Are you sure you want to delete this published run? " + f"_({published_run.title})_" + ) + st.caption("This will also delete all the associated versions.") + with st.div(className="d-flex"): + confirm_button = st.button( + 'Confirm', + type="secondary", + className="w-100", + ) + cancel_button = st.button( + "Cancel", + type="secondary", + className="w-100", + ) + + if confirm_button: + published_run.delete() + raise QueryParamsRedirectException(query_params={}) + + if cancel_button: + modal.close() + + def _render_admin_options(self, current_run: SavedRun, published_run: PublishedRun): + if ( + not self.is_current_user_admin() + or published_run.is_root() + or published_run.saved_run != current_run + ): + return + + with st.expander("๐Ÿ› ๏ธ Admin Options"): + st.write( + f"This will hide/show this workflow from {self.app_url(tab_name=MenuTabs.paths[MenuTabs.examples])} \n" + f"(Given that you have set public visibility above)" + ) + if st.session_state.get("--toggle-approve-example"): + published_run.is_approved_example = ( + not published_run.is_approved_example + ) + published_run.save(update_fields=["is_approved_example"]) + if published_run.is_approved_example: + btn_text = "๐Ÿ™ˆ Hide from Examples" + else: + btn_text = "โœ… Approve as Example" + st.button(btn_text, key="--toggle-approve-example") + + st.write("---") + + if st.checkbox("โญ๏ธ Save as Root Workflow"): + st.write( + f"Are you Sure? 
\n" + f"This will overwrite the contents of {self.app_url()}", + className="text-danger", + ) + if st.button("๐Ÿ‘Œ Yes, Update the Root Workflow"): + root_run = self.get_root_published_run() + root_run.add_version( + user=self.request.user, + title=published_run.title, + notes=published_run.notes, + saved_run=published_run.saved_run, + visibility=PublishedRunVisibility.PUBLIC, + ) + raise QueryParamsRedirectException(dict()) + + @classmethod + def get_recipe_title(cls) -> str: + return ( + cls.get_or_create_root_published_run().title + or cls.title + or cls.workflow.label + ) - def get_recipe_title(self, state: dict) -> str: - return state.get(StateKeys.page_title) or self.title or "" + def get_explore_image(self, state: dict) -> str: + return self.explore_image or "" def _user_disabled_check(self): if self.run_user and self.run_user.is_disabled: @@ -265,6 +687,8 @@ def get_tabs(self): tabs = [MenuTabs.run, MenuTabs.examples, MenuTabs.run_as_api] if self.request.user: tabs.extend([MenuTabs.history]) + if self.request.user and not self.request.user.is_anonymous: + tabs.extend([MenuTabs.saved]) return tabs def render_selected_tab(self, selected_tab: str): @@ -281,10 +705,9 @@ def render_selected_tab(self, selected_tab: str): col1, col2 = st.columns(2) with col1: self._render_help() - with col2: - self._render_save_options() self.render_related_workflows() + render_js_dynamic_dates() case MenuTabs.examples: self._examples_tab() @@ -297,6 +720,76 @@ def render_selected_tab(self, selected_tab: str): case MenuTabs.run_as_api: self.run_as_api_tab() + case MenuTabs.saved: + self._saved_tab() + render_js_dynamic_dates() + + def _render_version_history(self): + published_run = self.get_current_published_run() + + if published_run: + versions = published_run.versions.all() + first_version = versions[0] + for version, older_version in pairwise(versions): + first_version = older_version + self._render_version_row(version, older_version) + self._render_version_row(first_version, 
None) + re_render_js_dynamic_dates() + + def _render_version_row( + self, + version: PublishedRunVersion, + older_version: PublishedRunVersion | None, + ): + st.html( + """ + + """ + ) + url = self.app_url( + example_id=version.published_run.published_run_id, + run_id=version.saved_run.run_id, + uid=version.saved_run.uid, + ) + with st.link(to=url, className="text-decoration-none"): + with st.div( + className="d-flex mb-4 disable-p-margin", + style={"min-width": "min(100vw, 500px)"}, + ): + col1 = st.div(className="me-4") + col2 = st.div() + with col1: + with st.div(className="fs-5 mt-1"): + st.html('') + with col2: + is_first_version = not older_version + with st.div(className="fs-5 d-flex align-items-center"): + js_dynamic_date( + version.created_at, + container=self._render_version_history_date, + date_options={"month": "short", "day": "numeric"}, + ) + if is_first_version: + with st.tag("span", className="badge bg-secondary px-3 ms-2"): + st.write("FIRST VERSION") + with st.div(className="text-muted"): + if older_version and older_version.title != version.title: + st.write(f"Renamed: {version.title}") + elif not older_version: + st.write(version.title) + with st.div(className="mt-1", style={"font-size": "0.85rem"}): + self.render_author( + version.changed_by, image_size="18px", responsive=False + ) + + def _render_version_history_date(self, text, **props): + with st.tag("span", **props): + st.html(text) + def render_related_workflows(self): page_clses = self.related_workflows() if not page_clses: @@ -305,21 +798,22 @@ def render_related_workflows(self): with st.link(to="/explore/"): st.html("

Related Workflows

") - def _render(page_cls): + def _render(page_cls: typing.Type[BasePage]): page = page_cls() - state = page_cls().recipe_doc_sr().to_dict() + root_run = page.get_root_published_run() + state = root_run.saved_run.to_dict() preview_image = meta_preview_url( - page_cls().preview_image(state), page_cls().fallback_preivew_image() + page.get_explore_image(state), page.fallback_preivew_image() ) with st.link(to=page.app_url()): - st.markdown( + st.html( # language=html f"""
""" ) - st.markdown(f"###### {page.title}") + st.markdown(f"###### {root_run.title or page.title}") st.caption(page.preview_description(state)) grid_layout(4, page_clses, _render) @@ -423,9 +917,54 @@ def _check_if_flagged(self): # Return and Don't render the run any further st.stop() - def get_sr_from_query_params_dict(self, query_params) -> SavedRun: + @classmethod + def get_runs_from_query_params( + cls, example_id: str, run_id: str, uid: str + ) -> tuple[SavedRun, PublishedRun | None]: + if run_id and uid: + sr = cls.run_doc_sr(run_id, uid) + pr = sr.parent_published_run() + else: + pr = cls.get_published_run(published_run_id=example_id or "") + sr = pr.saved_run + return sr, pr + + @classmethod + def get_current_published_run(cls) -> PublishedRun | None: + example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + return cls.get_pr_from_query_params(example_id, run_id, uid) + + @classmethod + def get_pr_from_query_params( + cls, example_id: str, run_id: str, uid: str + ) -> PublishedRun | None: + if run_id and uid: + sr = cls.get_sr_from_query_params(example_id, run_id, uid) + return sr.parent_published_run() + elif example_id: + return cls.get_published_run(published_run_id=example_id) + else: + return cls.get_root_published_run() + + @classmethod + def get_root_published_run(cls) -> PublishedRun: + return cls.get_published_run(published_run_id="") + + @classmethod + def get_published_run(cls, *, published_run_id: str): + return PublishedRun.objects.get( + workflow=cls.workflow, + published_run_id=published_run_id, + ) + + @classmethod + def get_current_sr(cls) -> SavedRun: + return cls.get_sr_from_query_params_dict(gooey_get_query_params()) + + @classmethod + def get_sr_from_query_params_dict(cls, query_params) -> SavedRun: example_id, run_id, uid = extract_query_params(query_params) - return self.get_sr_from_query_params(example_id, run_id, uid) + return cls.get_sr_from_query_params(example_id, run_id, uid) @classmethod def 
get_sr_from_query_params( @@ -435,41 +974,101 @@ def get_sr_from_query_params( if run_id and uid: sr = cls.run_doc_sr(run_id, uid) elif example_id: - sr = cls.example_doc_sr(example_id) + pr = cls.get_published_run(published_run_id=example_id) + assert ( + pr.saved_run is not None + ), "invalid published run: without a saved run" + sr = pr.saved_run else: sr = cls.recipe_doc_sr() return sr - except SavedRun.DoesNotExist: + except (SavedRun.DoesNotExist, PublishedRun.DoesNotExist): raise HTTPException(status_code=404) @classmethod - def recipe_doc_sr(cls) -> SavedRun: - return SavedRun.objects.get_or_create( + def get_total_runs(cls) -> int: + # TODO: fix to also handle published run case + return SavedRun.objects.filter(workflow=cls.workflow).count() + + @classmethod + def get_or_create_root_published_run(cls) -> PublishedRun: + published_run, _ = PublishedRun.objects.get_or_create( workflow=cls.workflow, - run_id__isnull=True, - uid__isnull=True, - example_id__isnull=True, - )[0] + published_run_id="", + defaults={ + "saved_run": lambda: cls.run_doc_sr( + run_id="", uid="", create=True, parent=None, parent_version=None + ), + "created_by": None, + "last_edited_by": None, + "title": cls.title, + "notes": cls().preview_description(state=cls.sane_defaults), + "visibility": PublishedRunVisibility(PublishedRunVisibility.PUBLIC), + "is_approved_example": True, + }, + ) + return published_run + + @classmethod + def recipe_doc_sr(cls, create: bool = False) -> SavedRun: + if create: + return cls.get_or_create_root_published_run().saved_run + else: + return cls.get_root_published_run().saved_run @classmethod def run_doc_sr( - cls, run_id: str, uid: str, create: bool = False, parent: SavedRun = None + cls, + run_id: str, + uid: str, + create: bool = False, + parent: SavedRun | None = None, + parent_version: PublishedRunVersion | None = None, ) -> SavedRun: config = dict(workflow=cls.workflow, uid=uid, run_id=run_id) if create: return SavedRun.objects.get_or_create( - 
**config, defaults=dict(parent=parent) + **config, + defaults=dict(parent=parent, parent_version=parent_version), )[0] else: return SavedRun.objects.get(**config) @classmethod - def example_doc_sr(cls, example_id: str, create: bool = False) -> SavedRun: - config = dict(workflow=cls.workflow, example_id=example_id) - if create: - return SavedRun.objects.get_or_create(**config)[0] - else: - return SavedRun.objects.get(**config) + def create_published_run( + cls, + *, + published_run_id: str, + saved_run: SavedRun, + user: AppUser, + title: str, + notes: str, + visibility: PublishedRunVisibility, + ): + return PublishedRun.objects.create_published_run( + workflow=cls.workflow, + published_run_id=published_run_id, + saved_run=saved_run, + user=user, + title=title, + notes=notes, + visibility=visibility, + ) + + def duplicate_published_run( + self, + published_run: PublishedRun, + *, + title: str, + notes: str, + visibility: PublishedRunVisibility, + ): + return published_run.duplicate( + user=self.request.user, + title=title, + notes=notes, + visibility=visibility, + ) def render_description(self): pass @@ -486,34 +1085,65 @@ def render_form_v2(self): def validate_form_v2(self): pass - def render_author(self): - if not self.run_user or ( - not self.run_user.photo_url and not self.run_user.display_name - ): + def render_author( + self, + user: AppUser, + *, + image_size: str = "30px", + responsive: bool = True, + show_as_link: bool = False, + text_size: str | None = None, + ): + if not user or (not user.photo_url and not user.display_name): return - html = "
" - if self.run_user.photo_url: - html += f""" - -
- """ - if self.run_user.display_name: - html += f"
{self.run_user.display_name}
" - html += "
" + responsive_image_size = ( + f"calc({image_size} * 0.67)" if responsive else image_size + ) - if self.is_current_user_admin(): - linkto = lambda: st.link( + # new class name so that different ones don't conflict + class_name = f"author-image-{image_size}" + if responsive: + class_name += "-responsive" + + if show_as_link: + linkto = st.link( to=self.app_url( tab_name=MenuTabs.paths[MenuTabs.history], - query_params={"uid": self.run_user.uid}, + query_params={"uid": user.uid}, ) ) else: - linkto = st.dummy + linkto = st.dummy() + + with linkto, st.div(className="d-flex align-items-center"): + if user.photo_url: + st.html( + f""" + + """ + ) + st.image(user.photo_url, className=class_name) - with linkto(): - st.html(html) + if user.display_name: + name_style = {"fontSize": text_size} if text_size else {} + with st.tag("span", style=name_style): + st.html(html.escape(user.display_name)) def get_credits_click_url(self): if self.request.user and self.request.user.is_anonymous: @@ -538,7 +1168,7 @@ def render_submit_button(self, key="--submit-1"): cost_note = f"({cost_note.strip()})" st.caption( f""" -Run cost = {self.get_price_roundoff(st.session_state)} credits {cost_note} +Run cost = {self.get_price_roundoff(st.session_state)} credits {cost_note} {self.additional_notes() or ""} """, unsafe_allow_html=True, @@ -555,7 +1185,7 @@ def render_submit_button(self, key="--submit-1"): try: self.validate_form_v2() except AssertionError as e: - st.error(e) + st.error(str(e)) return False else: return True @@ -630,7 +1260,7 @@ def run_v2( def _render_report_button(self): example_id, run_id, uid = extract_query_params(gooey_get_query_params()) - # only logged in users can report a run (but not explamples/default runs) + # only logged in users can report a run (but not examples/default runs) if not (self.request.user and run_id and uid): return @@ -650,20 +1280,6 @@ def _render_before_output(self): if not url: return - with st.div(className="d-flex gap-1"): - with 
st.div(className="flex-grow-1"): - st.text_input( - "recipe url", - label_visibility="collapsed", - disabled=True, - value=url.split("://")[1].rstrip("/"), - ) - copy_to_clipboard_button( - "๐Ÿ”— Copy URL", - value=url, - style="height: 3.2rem", - ) - def _get_current_app_url(self) -> str | None: example_id, run_id, uid = extract_query_params(gooey_get_query_params()) return self.app_url(example_id, run_id, uid) @@ -679,14 +1295,10 @@ def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): st.session_state["is_flagged"] = is_flagged def _render_input_col(self): - self.render_author() self.render_form_v2() with st.expander("โš™๏ธ Settings"): self.render_settings() st.write("---") - st.write("##### ๐Ÿ–Œ๏ธ Personalize") - st.text_input("Title", key=StateKeys.page_title) - st.text_area("Notes", key=StateKeys.page_notes) submitted = self.render_submit_button() with st.div(style={"textAlign": "right"}): st.caption( @@ -695,6 +1307,18 @@ def _render_input_col(self): ) return submitted + @classmethod + def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState: + if state.get(StateKeys.run_status): + return RecipeRunState.running + elif state.get(StateKeys.error_msg): + return RecipeRunState.failed + elif state.get(StateKeys.run_time): + return RecipeRunState.completed + else: + # when user is at a recipe root, and not running anything + return RecipeRunState.idle + def _render_output_col(self, submitted: bool): assert inspect.isgeneratorfunction(self.run) @@ -708,27 +1332,61 @@ def _render_output_col(self, submitted: bool): self._render_before_output() - run_status = st.session_state.get(StateKeys.run_status) - if run_status: - st.caption("Your changes are saved in the above URL. 
Save it for later!") - html_spinner(run_status) - else: - err_msg = st.session_state.get(StateKeys.error_msg) - run_time = st.session_state.get(StateKeys.run_time, 0) - - # render errors - if err_msg is not None: - st.error(err_msg) - # render run time - elif run_time: - st.success(f"Success! Run Time: `{run_time:.2f}` seconds.") + run_state = self.get_run_state(st.session_state) + match run_state: + case RecipeRunState.completed: + self._render_completed_output() + case RecipeRunState.failed: + self._render_failed_output() + case RecipeRunState.running: + self._render_running_output() + case RecipeRunState.idle: + pass # render outputs self.render_output() - if not run_status: + if run_state != "waiting": self._render_after_output() + def _render_completed_output(self): + run_time = st.session_state.get(StateKeys.run_time, 0) + st.success(f"Success! Run Time: `{run_time:.2f}` seconds.") + + def _render_failed_output(self): + err_msg = st.session_state.get(StateKeys.error_msg) + st.error(err_msg, unsafe_allow_html=True) + + def _render_running_output(self): + run_status = st.session_state.get(StateKeys.run_status) + html_spinner(run_status) + self.render_extra_waiting_output() + + def render_extra_waiting_output(self): + estimated_run_time = self.estimate_run_duration() + if not estimated_run_time: + return + if created_at := st.session_state.get("created_at"): + if isinstance(created_at, datetime.datetime): + start_time = created_at + else: + start_time = datetime.datetime.fromisoformat(created_at) + with st.countdown_timer( + end_time=start_time + datetime.timedelta(seconds=estimated_run_time), + delay_text="Sorry for the wait. 
Your run is taking longer than we expected.", + ): + if self.is_current_user_owner() and self.request.user.email: + st.write( + f"""We'll email **{self.request.user.email}** when your workflow is done.""" + ) + st.write( + f"""In the meantime, check out [๐Ÿš€ Examples]({self.get_tab_url(MenuTabs.examples)}) + for inspiration.""" + ) + + def estimate_run_duration(self) -> int | None: + pass + def on_submit(self): example_id, run_id, uid = self.create_new_run() if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits(): @@ -740,7 +1398,7 @@ def on_submit(self): else: self.call_runner_task(example_id, run_id, uid) raise QueryParamsRedirectException( - self.clean_query_params(example_id=example_id, run_id=run_id, uid=uid) + self.clean_query_params(example_id=None, run_id=run_id, uid=uid) ) def should_submit_after_login(self) -> bool: @@ -776,12 +1434,21 @@ def create_new_run(self): parent = self.get_sr_from_query_params( parent_example_id, parent_run_id, parent_uid ) + published_run = self.get_current_published_run() + try: + parent_version = published_run and published_run.versions.latest() + except PublishedRunVersion.DoesNotExist: + parent_version = None - self.run_doc_sr(run_id, uid, create=True, parent=parent).set( - self.state_to_doc(st.session_state) - ) + self.run_doc_sr( + run_id, + uid, + create=True, + parent=parent, + parent_version=parent_version, + ).set(self.state_to_doc(st.session_state)) - return parent_example_id, run_id, uid + return None, run_id, uid def call_runner_task(self, example_id, run_id, uid, is_api_call=False): from celeryapp.tasks import gui_runner @@ -792,13 +1459,16 @@ def call_runner_task(self, example_id, run_id, uid, is_api_call=False): run_id=run_id, uid=uid, state=st.session_state, - channel=f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}", + channel=self.realtime_channel_name(run_id, uid), query_params=self.clean_query_params( example_id=example_id, run_id=run_id, uid=uid ), is_api_call=is_api_call, ) + def 
realtime_channel_name(self, run_id, uid): + return f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" + def generate_credit_error_message(self, example_id, run_id, uid) -> str: account_url = furl(settings.APP_BASE_URL) / "account/" if self.request.user.is_anonymous: @@ -811,7 +1481,7 @@ def generate_credit_error_message(self, example_id, run_id, uid) -> str: Doh! Please login to run more Gooey.AI workflows.

-Youโ€™ll receive {settings.LOGIN_USER_FREE_CREDITS} Credits when you sign up via your phone #, Google, Apple or GitHub account +Youโ€™ll receive {settings.LOGIN_USER_FREE_CREDITS} Credits when you sign up via your phone #, Google, Apple or GitHub account and can purchase more for $1/100 Credits. """ else: @@ -860,52 +1530,6 @@ def _render_after_output(self): with col3: self._render_report_button() - def _render_save_options(self): - if not self.is_current_user_admin(): - return - - parent_example_id, parent_run_id, parent_uid = extract_query_params( - gooey_get_query_params() - ) - current_sr = self.get_sr_from_query_params( - parent_example_id, parent_run_id, parent_uid - ) - - with st.expander("๐Ÿ› ๏ธ Admin Options"): - sr_to_save = None - - if st.button("โญ๏ธ Save Workflow"): - sr_to_save = self.recipe_doc_sr() - - if st.button("๐Ÿ”– Create new Example"): - sr_to_save = self.example_doc_sr(get_random_doc_id(), create=True) - - if parent_example_id: - if st.button("๐Ÿ’พ Save this Example"): - sr_to_save = self.example_doc_sr(parent_example_id) - - if current_sr.example_id: - hidden = st.session_state.get(StateKeys.hidden) - if st.button("๐Ÿ‘๏ธ Make Public" if hidden else "๐Ÿ™ˆ๏ธ Hide"): - self.set_hidden( - example_id=current_sr.example_id, - doc=st.session_state, - hidden=not hidden, - ) - - if sr_to_save: - if current_sr != sr_to_save: # ensure parent != child - sr_to_save.parent = current_sr - sr_to_save.set(self.state_to_doc(st.session_state)) - ## TODO: pass the success message to the redirect - # st.success("Saved", icon="โœ…") - raise QueryParamsRedirectException( - dict(example_id=sr_to_save.example_id) - ) - - if current_sr.parent_id: - st.write(f"Parent: {current_sr.parent.get_app_url()}") - def state_to_doc(self, state: dict): ret = { field_name: deepcopy(state[field_name]) @@ -913,13 +1537,6 @@ def state_to_doc(self, state: dict): if field_name in state } - title = state.get(StateKeys.page_title) - notes = state.get(StateKeys.page_notes) - if 
title and title.strip() != self.title.strip(): - ret[StateKeys.page_title] = title - if notes and notes.strip() != self.preview_description(state).strip(): - ret[StateKeys.page_notes] = notes - return ret def fields_to_save(self) -> [str]: @@ -935,28 +1552,44 @@ def fields_to_save(self) -> [str]: ] def _examples_tab(self): - allow_delete = self.is_current_user_admin() + allow_hide = self.is_current_user_admin() - def _render(sr: SavedRun): - url = str( - furl( - self.app_url(), query_params={EXAMPLE_ID_QUERY_PARAM: sr.example_id} - ) - ) - self._render_doc_example( - allow_delete=allow_delete, - doc=sr.to_dict(), - url=url, - query_params=dict(example_id=sr.example_id), + def _render(pr: PublishedRun): + self._render_example_preview( + published_run=pr, + allow_hide=allow_hide, ) - example_runs = SavedRun.objects.filter( + example_runs = PublishedRun.objects.filter( + workflow=self.workflow, + visibility=PublishedRunVisibility.PUBLIC, + is_approved_example=True, + ).exclude(published_run_id="")[:50] + + grid_layout(3, example_runs, _render, column_props=dict(className="mb-0 pb-0")) + + def _saved_tab(self): + self.ensure_authentication() + + published_runs = PublishedRun.objects.filter( workflow=self.workflow, - hidden=False, - example_id__isnull=False, + created_by=self.request.user, )[:50] + if not published_runs: + st.write("No published runs yet") + return + + def _render(pr: PublishedRun): + self._render_published_run_preview(published_run=pr) + + grid_layout(3, published_runs, _render) - grid_layout(3, example_runs, _render) + def ensure_authentication(self): + if not self.request.user or self.request.user.is_anonymous: + redirect_url = furl( + "/login", query_params={"next": furl(self.request.url).set(origin=None)} + ) + raise RedirectException(str(redirect_url)) def _history_tab(self): assert self.request, "request must be set to render history tab" @@ -985,25 +1618,7 @@ def _history_tab(self): st.write("No history yet") return - def _render(sr: SavedRun): 
- url = str( - furl( - self.app_url(), - query_params={ - RUN_ID_QUERY_PARAM: sr.run_id, - USER_ID_QUERY_PARAM: uid, - }, - ) - ) - - self._render_doc_example( - allow_delete=False, - doc=sr.to_dict(), - url=url, - query_params=dict(run_id=sr.run_id, uid=uid), - ) - - grid_layout(3, run_history, _render) + grid_layout(3, run_history, self._render_run_preview) next_url = ( furl(self._get_current_app_url(), query_params=self.request.query_params) @@ -1019,52 +1634,115 @@ def _render(sr: SavedRun): f"""""" ) - def _render_doc_example( - self, *, allow_delete: bool, doc: dict, url: str, query_params: dict - ): - with st.link(to=url): - st.html( - # language=HTML - f"""""" - ) - copy_to_clipboard_button("๐Ÿ”— Copy URL", value=url) - if allow_delete: - self._example_delete_button(**query_params, doc=doc) + def _render_run_preview(self, saved_run: SavedRun): + published_run: PublishedRun | None = ( + saved_run.parent_version.published_run if saved_run.parent_version else None + ) + is_latest_version = published_run and published_run.saved_run == saved_run + tb = get_title_breadcrumbs(self, sr=saved_run, pr=published_run) + + with st.link(to=saved_run.get_app_url()): + with st.div(className="mb-1", style={"font-size": "0.9rem"}): + if is_latest_version: + st.html( + PublishedRunVisibility( + published_run.visibility + ).get_badge_html() + ) + + st.write(f"#### {tb.h1_title}") - updated_at = doc.get("updated_at") + updated_at = saved_run.updated_at if updated_at and isinstance(updated_at, datetime.datetime): js_dynamic_date(updated_at) - title = doc.get(StateKeys.page_title) - if title and title.strip() != self.title.strip(): - st.write("#### " + title) + if saved_run.run_status: + html_spinner(saved_run.run_status) + elif saved_run.error_msg: + st.error(saved_run.error_msg, unsafe_allow_html=True) - notes = doc.get(StateKeys.page_notes) - if ( - notes - and notes.strip() != self.preview_description(st.session_state).strip() - ): - st.write(notes) + return 
self.render_example(saved_run.to_dict()) + + def _render_published_run_preview(self, published_run: PublishedRun): + tb = get_title_breadcrumbs(self, published_run.saved_run, published_run) + + with st.link(to=published_run.get_app_url()): + with st.div(className="mb-1", style={"font-size": "0.9rem"}): + st.html( + PublishedRunVisibility(published_run.visibility).get_badge_html() + ) + + st.write(f"#### {tb.h1_title}") + + with st.div(className="d-flex align-items-center justify-content-between"): + with st.div(): + updated_at = published_run.updated_at + if updated_at and isinstance(updated_at, datetime.datetime): + js_dynamic_date(updated_at) + + if published_run.visibility == PublishedRunVisibility.PUBLIC: + run_icon = '' + run_count = format_number_with_suffix(published_run.get_run_count()) + st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) + + if published_run.notes: + st.caption(published_run.notes) + + doc = published_run.saved_run.to_dict() + self.render_example(doc) + + def _render_example_preview( + self, + *, + published_run: PublishedRun, + allow_hide: bool, + ): + tb = get_title_breadcrumbs(self, published_run.saved_run, published_run) + + with st.link(to=published_run.get_app_url()): + with st.div(className="mb-1 text-truncate", style={"height": "1.5rem"}): + if published_run.created_by and self.is_user_admin( + published_run.created_by + ): + self.render_author( + published_run.created_by, image_size="20px", text_size="0.9rem" + ) + + st.write(f"#### {tb.h1_title}") + + with st.div(className="d-flex align-items-center justify-content-between"): + with st.div(): + updated_at = published_run.updated_at + if updated_at and isinstance(updated_at, datetime.datetime): + js_dynamic_date(updated_at) + + run_icon = '' + run_count = format_number_with_suffix(published_run.get_run_count()) + st.caption(f"{run_icon} {run_count} runs", unsafe_allow_html=True) + if published_run.notes: + st.caption(published_run.notes) + + if allow_hide: + 
self._example_hide_button(published_run=published_run) + + doc = published_run.saved_run.to_dict() self.render_example(doc) - def _example_delete_button(self, example_id, doc): + def _example_hide_button(self, published_run: PublishedRun): pressed_delete = st.button( "๐Ÿ™ˆ๏ธ Hide", - key=f"delete_example_{example_id}", + key=f"delete_example_{published_run.published_run_id}", style={"color": "red"}, ) if not pressed_delete: return - self.set_hidden(example_id=example_id, doc=doc, hidden=True) - - def set_hidden(self, *, example_id, doc, hidden: bool): - sr = self.example_doc_sr(example_id) + self.set_hidden(published_run=published_run, hidden=True) + def set_hidden(self, *, published_run: PublishedRun, hidden: bool): with st.spinner("Hiding..."): - doc[StateKeys.hidden] = hidden - sr.hidden = hidden - sr.save(update_fields=["hidden", "updated_at"]) + published_run.is_approved_example = not hidden + published_run.save() st.experimental_rerun() @@ -1074,7 +1752,8 @@ def render_example(self, state: dict): def render_steps(self): raise NotImplementedError - def preview_input(self, state: dict) -> str | None: + @classmethod + def preview_input(cls, state: dict) -> str | None: return ( state.get("text_prompt") or state.get("input_prompt") @@ -1107,7 +1786,8 @@ def run_as_api_tab(self): ) st.markdown( - f'๐Ÿ“– To learn more, take a look at our complete API' + f'๐Ÿ“– To learn more, take a look at our complete API', + unsafe_allow_html=True, ) st.write("#### ๐Ÿ“ค Example Request") @@ -1206,11 +1886,15 @@ def additional_notes(self) -> str | None: def get_cost_note(self) -> str | None: pass + @classmethod + def is_user_admin(cls, user: AppUser) -> bool: + email = user.email + return email and email in settings.ADMIN_EMAILS + def is_current_user_admin(self) -> bool: if not self.request or not self.request.user: return False - email = self.request.user.email - return email and email in settings.ADMIN_EMAILS + return self.is_user_admin(self.request.user) def 
is_current_user_paying(self) -> bool: return bool(self.request and self.request.user and self.request.user.is_paying) @@ -1270,6 +1954,17 @@ def err_msg_for_exc(e): return f"{type(e).__name__}: {e}" +def force_redirect(url: str): + # note: assumes sanitized URLs + st.html( + f""" + + """ + ) + + class RedirectException(Exception): def __init__(self, url, status_code=302): self.url = url @@ -1281,3 +1976,24 @@ def __init__(self, query_params: dict, status_code=303): query_params = {k: v for k, v in query_params.items() if v is not None} url = "?" + urllib.parse.urlencode(query_params) super().__init__(url, status_code) + + +class TitleValidationError(Exception): + pass + + +def format_number_with_suffix(num: int) -> str: + """ + Formats large number with a suffix. + + Ref: https://stackoverflow.com/a/45846841 + """ + num_float = float("{:.3g}".format(num)) + magnitude = 0 + while abs(num_float) >= 1000: + magnitude += 1 + num_float /= 1000.0 + return "{}{}".format( + "{:f}".format(num_float).rstrip("0").rstrip("."), + ["", "K", "M", "B", "T"][magnitude], + ) diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py new file mode 100644 index 000000000..309d32b44 --- /dev/null +++ b/daras_ai_v2/bot_integration_widgets.py @@ -0,0 +1,182 @@ +from django.core.exceptions import ValidationError +from furl import furl + +import gooey_ui as st +from bots.models import BotIntegration, Platform +from bots.models import Workflow +from daras_ai_v2 import settings +from daras_ai_v2.asr import ( + google_translate_language_selector, +) +from daras_ai_v2.field_render import field_title_desc +from recipes.BulkRunner import url_to_runs + + +def general_integration_settings(bi: BotIntegration): + from recipes.VideoBots import VideoBotsPage + + if st.session_state.get(f"_bi_reset_{bi.id}"): + st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( + "user_language" + ).default + st.session_state[ + 
f"_bi_show_feedback_buttons_{bi.id}" + ] = BotIntegration._meta.get_field("show_feedback_buttons").default + st.session_state[f"_bi_analysis_url_{bi.id}"] = None + + bi.streaming_enabled = st.checkbox( + "**๐Ÿ“ก Streaming Enabled**", + value=bi.streaming_enabled, + key=f"_bi_streaming_enabled_{bi.id}", + ) + bi.show_feedback_buttons = st.checkbox( + "**๐Ÿ‘๐Ÿพ ๐Ÿ‘Ž๐Ÿฝ Show Feedback Buttons**", + value=bi.show_feedback_buttons, + key=f"_bi_show_feedback_buttons_{bi.id}", + ) + st.caption( + "Users can rate and provide feedback on every copilot response if enabled." + ) + + bi.user_language = ( + google_translate_language_selector( + f""" +##### {field_title_desc(VideoBotsPage.RequestModel, 'user_language')} \\ +This will also help better understand incoming audio messages by automatically choosing the best [Speech](https://gooey.ai/speech/) model. + """, + default_value=bi.user_language, + allow_none=False, + key=f"_bi_user_language_{bi.id}", + ) + or "en" + ) + st.caption( + "Please note that this language is distinct from the one provided in the workflow settings. Hence, this allows you to integrate the same bot in many languages." + ) + + analysis_url = st.text_input( + """ + ##### ๐Ÿง  Analysis Run URL + Analyze each incoming message and the copilot's response using a Gooey.AI /LLM workflow url. Leave blank to disable. + [Learn more](https://gooey.ai/docs/guides/build-your-ai-copilot/conversation-analysis). + """, + value=bi.analysis_run and bi.analysis_run.get_app_url(), + key=f"_bi_analysis_url_{bi.id}", + ) + if analysis_url: + try: + page_cls, bi.analysis_run, _ = url_to_runs(analysis_url) + assert page_cls.workflow in [ + Workflow.COMPARE_LLM, + Workflow.VIDEO_BOTS, + Workflow.GOOGLE_GPT, + Workflow.DOC_SEARCH, + ], "We only support Compare LLM, Copilot, Google GPT and Doc Search workflows for analysis." 
+ except Exception as e: + bi.analysis_run = None + st.error(repr(e)) + else: + bi.analysis_run = None + + pressed_update = st.button("Update") + pressed_reset = st.button("Reset", key=f"_bi_reset_{bi.id}", type="tertiary") + if pressed_update or pressed_reset: + try: + bi.full_clean() + bi.save() + except ValidationError as e: + st.error(str(e)) + + +def broadcast_input(bi: BotIntegration): + from bots.tasks import send_broadcast_msgs_chunked + from recipes.VideoBots import VideoBotsPage + + key = f"__broadcast_msg_{bi.id}" + api_docs_url = ( + furl( + settings.API_BASE_URL, + fragment_path=f"operation/{VideoBotsPage.slug_versions[0]}__broadcast", + ) + / "docs" + ) + text = st.text_area( + f""" + ##### Broadcast Message + Broadcast a message to all users of this integration using this bot account. \\ + You can also do this via the [API]({api_docs_url}). + """, + key=key + ":text", + placeholder="Type your message here...", + ) + audio = st.file_uploader( + "**๐ŸŽค Audio**", + key=key + ":audio", + help="Attach a video to this message.", + optional=True, + accept=["audio/*"], + ) + video = st.file_uploader( + "**๐ŸŽฅ Video**", + key=key + ":video", + help="Attach a video to this message.", + optional=True, + accept=["video/*"], + ) + documents = st.file_uploader( + "**๐Ÿ“„ Documents**", + key=key + ":documents", + help="Attach documents to this message.", + accept_multiple_files=True, + optional=True, + ) + + should_confirm_key = key + ":should_confirm" + confirmed_send_btn = key + ":confirmed_send" + if st.button("๐Ÿ“ข Send Broadcast", style=dict(height="3.2rem"), key=key + ":send"): + st.session_state[should_confirm_key] = True + if not st.session_state.get(should_confirm_key): + return + + convos = bi.conversations.all() + if st.session_state.get(confirmed_send_btn): + st.success("Started sending broadcast!") + st.session_state.pop(confirmed_send_btn) + st.session_state.pop(should_confirm_key) + send_broadcast_msgs_chunked( + text=text, + audio=audio, + 
video=video, + documents=documents, + bi=bi, + convo_qs=convos, + ) + else: + if not convos.exists(): + st.error("No users have interacted with this bot yet.", icon="โš ๏ธ") + return + st.write( + f"Are you sure? This will send a message to all {convos.count()} users that have ever interacted with this bot.\n" + ) + st.button("โœ… Yes, Send", key=confirmed_send_btn) + + +def render_bot_test_link(bi: BotIntegration): + if bi.wa_phone_number: + test_link = ( + furl("https://wa.me/", query_params={"text": "Hi"}) + / bi.wa_phone_number.as_e164 + ) + elif bi.slack_team_id: + test_link = ( + furl("https://app.slack.com/client") + / bi.slack_team_id + / bi.slack_channel_id + ) + else: + return + st.html( + f""" + ๐Ÿ“ฑ Test + """ + ) diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 6a4d6fd4e..3d106e9ae 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -3,13 +3,13 @@ import typing from urllib.parse import parse_qs +from django.db import transaction from fastapi import HTTPException, Request from furl import furl from sentry_sdk import capture_exception from app_users.models import AppUser from bots.models import ( - BotIntegration, Platform, Message, Conversation, @@ -17,11 +17,52 @@ SavedRun, ConvoState, Workflow, + MessageAttachment, ) from daras_ai_v2.asr import AsrModels, run_google_translate -from daras_ai_v2.base import BasePage +from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT +from daras_ai_v2.vector_search import doc_url_to_file_metadata +from gooey_ui.pubsub import realtime_subscribe from gooeysite.bg_db_conn import db_middleware +from recipes.VideoBots import VideoBotsPage, ReplyButton +from routers.api import submit_api_call + +PAGE_NOT_CONNECTED_ERROR = ( + "๐Ÿ’” Looks like you haven't connected this page to a gooey.ai workflow. " + "Please go to the Integrations Tab and connect this page." 
+) +RESET_KEYWORD = "reset" +RESET_MSG = "โ™ป๏ธ Sure! Let's start fresh. How can I help you?" + +DEFAULT_RESPONSE = ( + "๐Ÿค”๐Ÿค– Well that was Unexpected! I seem to be lost. Could you please try again?." +) + +INVALID_INPUT_FORMAT = ( + "โš ๏ธ Sorry! I don't understand {} messsages. Please try with text or audio." +) + +AUDIO_ASR_CONFIRMATION = """ +๐ŸŽง I heard: โ€œ{}โ€ +Working on your answerโ€ฆ +""".strip() + +ERROR_MSG = """ +`{}` + +โš ๏ธ Sorry, I ran into an error while processing your request. Please try again, or type "Reset" to start over. +""".strip() + +FEEDBACK_THUMBS_UP_MSG = "๐ŸŽ‰ What did you like about my response?" +FEEDBACK_THUMBS_DOWN_MSG = "๐Ÿค” What was the issue with the response? How could it be improved? Please send me an voice note or text me." +FEEDBACK_CONFIRMED_MSG = ( + "๐Ÿ™ Thanks! Your feedback helps us make {bot_name} better. How else can I help you?" +) + +TAPPED_SKIP_MSG = "๐ŸŒฑ Alright. What else can I help you with?" + +SLACK_MAX_SIZE = 3000 async def request_json(request: Request): @@ -43,6 +84,8 @@ class BotInterface: input_type: str language: str show_feedback_buttons: bool = False + streaming_enabled: bool = False + can_update_message: bool = False convo: Conversation recieved_msg_id: str = None input_glossary: str | None = None @@ -54,8 +97,10 @@ def send_msg( text: str | None = None, audio: str = None, video: str = None, - buttons: list = None, + buttons: list[ReplyButton] = None, + documents: list[str] = None, should_translate: bool = False, + update_msg_id: str = None, ) -> str | None: raise NotImplementedError @@ -68,13 +113,29 @@ def get_input_text(self) -> str | None: def get_input_audio(self) -> str | None: raise NotImplementedError + def get_input_images(self) -> list[str] | None: + raise NotImplementedError + + def get_input_documents(self) -> list[str] | None: + raise NotImplementedError + def nice_filename(self, mime_type: str) -> str: ext = mimetypes.guess_extension(mime_type) or "" - return 
f"{self.platform}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}" + return f"{self.platform.name}_{self.input_type}_from_{self.user_id}_to_{self.bot_id}{ext}" def _unpack_bot_integration(self): bi = self.convo.bot_integration - if bi.saved_run: + if bi.published_run: + self.page_cls = Workflow(bi.published_run.workflow).page_cls + self.query_params = self.page_cls.clean_query_params( + example_id=bi.published_run.published_run_id, + run_id="", + uid="", + ) + saved_run = bi.published_run.saved_run + self.input_glossary = saved_run.state.get("input_glossary_document") + self.output_glossary = saved_run.state.get("output_glossary_document") + elif bi.saved_run: self.page_cls = Workflow(bi.saved_run.workflow).page_cls self.query_params = self.page_cls.clean_query_params( example_id=bi.saved_run.example_id, @@ -90,46 +151,12 @@ def _unpack_bot_integration(self): self.billing_account_uid = bi.billing_account_uid self.language = bi.user_language self.show_feedback_buttons = bi.show_feedback_buttons + self.streaming_enabled = bi.streaming_enabled def get_interactive_msg_info(self) -> tuple[str, str]: raise NotImplementedError("This bot does not support interactive messages.") -PAGE_NOT_CONNECTED_ERROR = ( - "๐Ÿ’” Looks like you haven't connected this page to a gooey.ai workflow. " - "Please go to the Integrations Tab and connect this page." -) -RESET_KEYWORD = "reset" -RESET_MSG = "โ™ป๏ธ Sure! Let's start fresh. How can I help you?" - -DEFAULT_RESPONSE = ( - "๐Ÿค”๐Ÿค– Well that was Unexpected! I seem to be lost. Could you please try again?." -) - -INVALID_INPUT_FORMAT = ( - "โš ๏ธ Sorry! I don't understand {} messsages. Please try with text or audio." -) - -AUDIO_ASR_CONFIRMATION = """ -๐ŸŽง I heard: โ€œ{}โ€ -Working on your answerโ€ฆ -""".strip() - -ERROR_MSG = """ -`{0!r}` - -โš ๏ธ Sorry, I ran into an error while processing your request. Please try again, or type "Reset" to start over. 
-""".strip() - -FEEDBACK_THUMBS_UP_MSG = "๐ŸŽ‰ What did you like about my response?" -FEEDBACK_THUMBS_DOWN_MSG = "๐Ÿค” What was the issue with the response? How could it be improved? Please send me an voice note or text me." -FEEDBACK_CONFIRMED_MSG = ( - "๐Ÿ™ Thanks! Your feedback helps us make {bot_name} better. How else can I help you?" -) - -TAPPED_SKIP_MSG = "๐ŸŒฑ Alright. What else can I help you with?" - - def _echo(bot, input_text): response_text = f"You said ```{input_text}```\nhttps://www.youtube.com/" if bot.get_input_audio(): @@ -159,6 +186,8 @@ def _mock_api_output(input_text): @db_middleware def _on_msg(bot: BotInterface): speech_run = None + input_images = None + input_documents = None if not bot.page_cls: bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR) return @@ -194,6 +223,24 @@ def _on_msg(bot: BotInterface): return # send confirmation of asr bot.send_msg(text=AUDIO_ASR_CONFIRMATION.format(input_text)) + case "image": + input_images = bot.get_input_images() + if not input_images: + raise HTTPException( + status_code=400, detail="No image found in request." + ) + input_text = (bot.get_input_text() or "").strip() + case "document": + input_documents = bot.get_input_documents() + if not input_documents: + raise HTTPException( + status_code=400, detail="No documents found in request." 
+ ) + filenames = ", ".join( + furl(url.strip("/")).path.segments[-1] for url in input_documents + ) + input_text = (bot.get_input_text() or "").strip() + input_text = f"Files: {filenames}\n\n{input_text}" case "text": input_text = (bot.get_input_text() or "").strip() if not input_text: @@ -221,6 +268,8 @@ def _on_msg(bot: BotInterface): _process_and_send_msg( billing_account_user=billing_account_user, bot=bot, + input_images=input_images, + input_documents=input_documents, input_text=input_text, speech_run=speech_run, ) @@ -258,51 +307,176 @@ def _process_and_send_msg( *, billing_account_user: AppUser, bot: BotInterface, + input_images: list[str] | None, + input_documents: list[str] | None, input_text: str, speech_run: str | None, ): - try: - # # mock testing - # msgs_to_save, response_audio, response_text, response_video = _echo( - # bot, input_text - # ) - # make API call to gooey bots to get the response - response_text, response_audio, response_video, msgs_to_save = _process_msg( - page_cls=bot.page_cls, - api_user=billing_account_user, - query_params=bot.query_params, - convo=bot.convo, - input_text=input_text, - user_language=bot.language, - speech_run=speech_run, - ) - except HTTPException as e: - traceback.print_exc() - capture_exception(e) - # send error msg as repsonse - bot.send_msg(text=ERROR_MSG.format(e)) - return - # this really shouldn't happen, but just in case it does, we should have a nice message - response_text = response_text or DEFAULT_RESPONSE - # send the response to the user - print(bot.show_feedback_buttons) - msg_id = bot.send_msg( - text=response_text, - audio=response_audio, - video=response_video, - buttons=_feedback_start_buttons() if bot.show_feedback_buttons else None, + # get latest messages for context (upto 100) + saved_msgs = bot.convo.messages.all().as_llm_context() + + # # mock testing + # result = _mock_api_output(input_text) + page, result, run_id, uid = submit_api_call( + page_cls=bot.page_cls, + 
user=billing_account_user, + request_body={ + "input_prompt": input_text, + "input_images": input_images, + "input_documents": input_documents, + "messages": saved_msgs, + "user_language": bot.language, + }, + query_params=bot.query_params, ) - if not msgs_to_save: + + if bot.show_feedback_buttons: + buttons = _feedback_start_buttons() + else: + buttons = None + + update_msg_id = None # this is the message id to update during streaming + sent_msg_id = None # this is the message id to record in the db + last_idx = 0 # this is the last index of the text sent to the user + if bot.streaming_enabled: + # subscribe to the realtime channel for updates + channel = page.realtime_channel_name(run_id, uid) + with realtime_subscribe(channel) as realtime_gen: + for state in realtime_gen: + run_state = page.get_run_state(state) + run_status = state.get(StateKeys.run_status) or "" + # check for errors + if run_state == RecipeRunState.failed: + err_msg = state.get(StateKeys.error_msg) + bot.send_msg(text=ERROR_MSG.format(err_msg)) + return # abort + if run_state != RecipeRunState.running: + break # we're done running, abort + text = state.get("output_text") and state.get("output_text")[0] + if not text: + # if no text, send the run status + if bot.can_update_message: + update_msg_id = bot.send_msg( + text=run_status, update_msg_id=update_msg_id + ) + continue # no text, wait for the next update + streaming_done = not run_status.lower().startswith("streaming") + # send the response to the user + if bot.can_update_message: + update_msg_id = bot.send_msg( + text=text.strip() + "...", + update_msg_id=update_msg_id, + buttons=buttons if streaming_done else None, + ) + last_idx = len(text) + else: + next_chunk = text[last_idx:] + last_idx = len(text) + if not next_chunk: + continue # no chunk, wait for the next update + update_msg_id = bot.send_msg( + text=next_chunk, + buttons=buttons if streaming_done else None, + ) + if streaming_done and not bot.can_update_message: + # if we send 
the buttons, this is the ID we need to record in the db for lookups later when the button is pressed + sent_msg_id = update_msg_id + # don't show buttons again + buttons = None + if streaming_done: + break # we're done streaming, abort + + # wait for the celery task to finish + result.get(disable_sync_subtasks=False) + # get the final state from db + state = page.run_doc_sr(run_id, uid).to_dict() + # check for errors + err_msg = state.get(StateKeys.error_msg) + if err_msg: + bot.send_msg(text=ERROR_MSG.format(err_msg)) return - # save the message id for the sent message - if msg_id: - msgs_to_save[-1].platform_msg_id = msg_id - # save the message id for the received message - if bot.recieved_msg_id: - msgs_to_save[0].platform_msg_id = bot.recieved_msg_id - # save the messages - for msg in msgs_to_save: - msg.save() + + text = (state.get("output_text") and state.get("output_text")[0]) or "" + audio = state.get("output_audio") and state.get("output_audio")[0] + video = state.get("output_video") and state.get("output_video")[0] + documents = state.get("output_documents") or [] + # check for empty response + if not (text or audio or video or documents or buttons): + bot.send_msg(text=DEFAULT_RESPONSE) + return + # if in-place updates are enabled, update the message, otherwise send the remaining text + if not bot.can_update_message: + text = text[last_idx:] + # send the response to the user if there is any remaining + if text or audio or video or documents or buttons: + update_msg_id = bot.send_msg( + text=text, + audio=audio, + video=video, + documents=documents, + buttons=buttons, + update_msg_id=update_msg_id, + ) + + # save msgs to db + _save_msgs( + bot=bot, + input_images=input_images, + input_documents=input_documents, + input_text=input_text, + speech_run=speech_run, + platform_msg_id=sent_msg_id or update_msg_id, + response=VideoBotsPage.ResponseModel.parse_obj(state), + url=page.app_url(run_id=run_id, uid=uid), + ) + + +def _save_msgs( + bot: BotInterface, + 
input_images: list[str] | None, + input_documents: list[str] | None, + input_text: str, + speech_run: str | None, + platform_msg_id: str | None, + response: VideoBotsPage.ResponseModel, + url: str, +): + # create messages for future context + user_msg = Message( + platform_msg_id=bot.recieved_msg_id, + conversation=bot.convo, + role=CHATML_ROLE_USER, + content=response.raw_input_text, + display_content=input_text, + saved_run=SavedRun.objects.get_or_create( + workflow=Workflow.ASR, **furl(speech_run).query.params + )[0] + if speech_run + else None, + ) + attachments = [] + for f_url in (input_images or []) + (input_documents or []): + metadata = doc_url_to_file_metadata(f_url) + attachments.append( + MessageAttachment(message=user_msg, url=f_url, metadata=metadata) + ) + assistant_msg = Message( + platform_msg_id=platform_msg_id, + conversation=bot.convo, + role=CHATML_ROLE_ASSISTANT, + content=response.raw_output_text and response.raw_output_text[0], + display_content=response.output_text and response.output_text[0], + saved_run=SavedRun.objects.get_or_create( + workflow=Workflow.VIDEO_BOTS, **furl(url).query.params + )[0], + ) + # save the messages & attachments + with transaction.atomic(): + user_msg.save() + for attachment in attachments: + attachment.metadata.save() + attachment.save() + assistant_msg.save() def _handle_interactive_msg(bot: BotInterface): @@ -404,102 +578,20 @@ class ButtonIds: feedback_thumbs_down = "FEEDBACK_THUMBS_DOWN" -def _feedback_post_click_buttons(): +def _feedback_post_click_buttons() -> list[ReplyButton]: """ Buttons to show after the user has clicked on a feedback button """ return [ - { - "type": "reply", - "reply": {"id": ButtonIds.action_skip, "title": "๐Ÿ”€ Skip"}, - }, + {"id": ButtonIds.action_skip, "title": "๐Ÿ”€ Skip"}, ] -def _feedback_start_buttons(): +def _feedback_start_buttons() -> list[ReplyButton]: """ Buttons to show for collecting feedback after the bot has sent a response """ return [ - { - "type": "reply", - 
"reply": {"id": ButtonIds.feedback_thumbs_up, "title": "๐Ÿ‘๐Ÿพ"}, - }, - { - "type": "reply", - "reply": {"id": ButtonIds.feedback_thumbs_down, "title": "๐Ÿ‘Ž๐Ÿฝ"}, - }, - ] - - -def _process_msg( - *, - page_cls, - api_user: AppUser, - query_params: dict, - convo: Conversation, - input_text: str, - user_language: str, - speech_run: str | None, -) -> tuple[str, str | None, str | None, list[Message]]: - from routers.api import call_api - - # get latest messages for context (upto 100) - saved_msgs = list( - reversed( - convo.messages.order_by("-created_at").values("role", "content")[:100], - ), - ) - - # # mock testing - # result = _mock_api_output(input_text) - - # call the api with provided input - result = call_api( - page_cls=page_cls, - user=api_user, - request_body={ - "input_prompt": input_text, - "messages": saved_msgs, - "user_language": user_language, - }, - query_params=query_params, - ) - - # extract response video/audio/text - try: - response_video = result["output"]["output_video"][0] - except (KeyError, IndexError): - response_video = None - try: - response_audio = result["output"]["output_audio"][0] - except (KeyError, IndexError): - response_audio = None - raw_input_text = result["output"]["raw_input_text"] - output_text = result["output"]["output_text"][0] - raw_output_text = result["output"]["raw_output_text"][0] - response_text = result["output"]["output_text"][0] - # save new messages for future context - msgs_to_save = [ - Message( - conversation=convo, - role=CHATML_ROLE_USER, - content=raw_input_text, - display_content=input_text, - saved_run=SavedRun.objects.get_or_create( - workflow=Workflow.ASR, **furl(speech_run).query.params - )[0] - if speech_run - else None, - ), - Message( - conversation=convo, - role=CHATML_ROLE_ASSISTANT, - content=raw_output_text, - display_content=output_text, - saved_run=SavedRun.objects.get_or_create( - workflow=Workflow.VIDEO_BOTS, **furl(result.get("url", "")).query.params - )[0], - ), + {"id": 
ButtonIds.feedback_thumbs_up, "title": "๐Ÿ‘๐Ÿพ"}, + {"id": ButtonIds.feedback_thumbs_down, "title": "๐Ÿ‘Ž๐Ÿฝ"}, ] - return response_text, response_audio, response_video, msgs_to_save diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py new file mode 100644 index 000000000..96dc919e5 --- /dev/null +++ b/daras_ai_v2/breadcrumbs.py @@ -0,0 +1,121 @@ +import typing + +import gooey_ui as st +from bots.models import ( + SavedRun, + PublishedRun, +) +from daras_ai.image_input import truncate_text_words +from daras_ai_v2.tabs_widget import MenuTabs + +if typing.TYPE_CHECKING: + from daras_ai_v2.base import BasePage + + +class TitleUrl(typing.NamedTuple): + title: str + url: str + + +class TitleBreadCrumbs(typing.NamedTuple): + """ + Breadcrumbs: root_title / published_title + Title: h1_title + """ + + h1_title: str + root_title: TitleUrl | None + published_title: TitleUrl | None + + def has_breadcrumbs(self): + return bool(self.root_title or self.published_title) + + +def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs): + st.html( + """ + + """ + ) + + if not (breadcrumbs.root_title or breadcrumbs.published_title): + # avoid empty space when breadcrumbs are not rendered + return + + with st.breadcrumbs(): + if breadcrumbs.root_title: + st.breadcrumb_item( + breadcrumbs.root_title.title, + link_to=breadcrumbs.root_title.url, + className="text-muted", + ) + if breadcrumbs.published_title: + st.breadcrumb_item( + breadcrumbs.published_title.title, + link_to=breadcrumbs.published_title.url, + ) + + +def get_title_breadcrumbs( + page_cls: typing.Union["BasePage", typing.Type["BasePage"]], + sr: SavedRun, + pr: PublishedRun | None, + tab: str = MenuTabs.run, +) -> TitleBreadCrumbs: + is_root = pr and pr.saved_run == sr and pr.is_root() + is_example = not is_root and pr and pr.saved_run == sr + is_run = not is_root and not is_example + + recipe_title = page_cls.get_recipe_title() + prompt_title = truncate_text_words( + page_cls.preview_input(sr.to_dict()) or 
"", + maxlen=60, + ).replace("\n", " ") + + metadata = page_cls.workflow.get_or_create_metadata() + root_breadcrumb = TitleUrl(metadata.short_title, page_cls.app_url()) + + match tab: + case MenuTabs.examples: + return TitleBreadCrumbs( + f"Examples: {metadata.short_title}", + root_title=root_breadcrumb, + published_title=None, + ) + case _ if is_root: + return TitleBreadCrumbs(page_cls.get_recipe_title(), None, None) + case _ if is_example: + assert pr is not None + return TitleBreadCrumbs( + pr.title or prompt_title or recipe_title, + root_title=root_breadcrumb, + published_title=None, + ) + case _ if is_run: + if pr and not pr.is_root(): + published_title = TitleUrl( + pr.title or f"Fork: {pr.published_run_id}", + pr.get_app_url(), + ) + else: + published_title = None + return TitleBreadCrumbs( + prompt_title or f"Run: {recipe_title}", + root_title=root_breadcrumb, + published_title=published_title, + ) + case _: + raise AssertionError("Invalid tab or run") diff --git a/daras_ai_v2/copy_to_clipboard_button_widget.py b/daras_ai_v2/copy_to_clipboard_button_widget.py index c3efbe47a..55222edf2 100644 --- a/daras_ai_v2/copy_to_clipboard_button_widget.py +++ b/daras_ai_v2/copy_to_clipboard_button_widget.py @@ -1,3 +1,4 @@ +import typing import gooey_ui as gui # language="html" @@ -5,10 +6,10 @@ @@ -20,16 +21,18 @@ def copy_to_clipboard_button( *, value: str, style: str = "", + className: str = "", + type: typing.Literal["primary", "secondary", "tertiary", "link"] = "primary", ): return gui.html( # language="html" f""" - """, diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 4351fca26..b240b482f 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -2,7 +2,6 @@ import typing import gooey_ui as st - from daras_ai_v2 import settings from daras_ai_v2.asr import AsrModels, google_translate_language_selector from daras_ai_v2.enum_selector_widget import 
enum_selector @@ -80,7 +79,8 @@ def document_uploader( def doc_search_settings( - asr_allowed: bool = True, keyword_instructions_allowed: bool = False + asr_allowed: bool = False, + keyword_instructions_allowed: bool = False, ): from daras_ai_v2.vector_search import DocSearchRequest @@ -95,6 +95,26 @@ def doc_search_settings( allow_none=True, ) + st.text_area( + """ +###### ๐Ÿ‘โ€๐Ÿ—จ Summarization Instructions +Prompt to transform the conversation history into a vector search query. \\ +These instructions run before the workflow performs a search of the knowledge base documents and should summarize the conversation into a VectorDB query most relevant to the user's last message. In general, you shouldn't need to adjust these instructions. + """, + key="query_instructions", + height=300, + ) + if keyword_instructions_allowed: + st.text_area( + """ +###### ๐Ÿ”‘ Keyword Extraction +Prompt to extract a query for hybrid BM25 search. \\ +These instructions run after the Summarization Instructions above and can use its result via `{{ final_search_query }}`. In general, you shouldn't need to adjust these instructions. + """, + key="keyword_instructions", + height=300, + ) + dense_weight_ = DocSearchRequest.__fields__["dense_weight"] st.slider( label=f"###### {dense_weight_.field_info.title}\n{dense_weight_.field_info.description}", @@ -117,7 +137,7 @@ def doc_search_settings( label=""" ###### Max Snippet Words -After a document search, relevant snippets of your documents are returned as results. This setting adjusts the maximum number of words in each snippet. A high snippet size allows the LLM to access more information from your document results, at the cost of being verbose and potentially exhausting input tokens (which can cause a failure of the copilot to respond). Default: 300 +After a document search, relevant snippets of your documents are returned as results. This setting adjusts the maximum number of words in each snippet. 
A high snippet size allows the LLM to access more information from your document results, at the cost of being verbose and potentially exhausting input tokens (which can cause a failure of the copilot to respond). Default: 300 """, key="max_context_words", min_value=10, @@ -135,33 +155,16 @@ def doc_search_settings( max_value=50, ) - st.text_area( - """ -###### ๐Ÿ‘โ€๐Ÿ—จ Summarization Instructions -Prompt to transform the conversation history into a vector search query. -These instructions run before the workflow performs a search of the knowledge base documents and should summarize the conversation into a VectorDB query most relevant to the user's last message. In general, you shouldn't need to adjust these instructions. - """, - key="query_instructions", - height=300, - ) - if keyword_instructions_allowed: - st.text_area( - """ -###### ๐Ÿ”‘ Keyword Extraction - """, - key="keyword_instructions", - height=300, - ) - if not asr_allowed: return st.write("---") st.write( """ - ##### ๐ŸŽค Knowledge Base Speech Recognition - If your knowledge base documents contain audio or video files, we'll transcribe and optionally translate them to English, given we've found most vectorDBs and LLMs perform best in English (even if their final answers are translated into another language). - """ + ##### ๐ŸŽค Knowledge Base Speech Recognition + If your knowledge base documents contain audio or video files, we'll transcribe and optionally translate them to English, given we've found most vectorDBs and LLMs perform best in English (even if their final answers are translated into another language). 
+ """, + unsafe_allow_html=True, ) enum_selector( diff --git a/daras_ai_v2/enum_selector_widget.py b/daras_ai_v2/enum_selector_widget.py index 098287f0d..6fcca398a 100644 --- a/daras_ai_v2/enum_selector_widget.py +++ b/daras_ai_v2/enum_selector_widget.py @@ -1,4 +1,5 @@ import enum +import typing from typing import TypeVar, Type import gooey_ui as st @@ -35,7 +36,7 @@ def render(e): if inner_key not in st.session_state: st.session_state[inner_key] = e.name in selected - st.checkbox(e.value, key=inner_key) + st.checkbox(_format_func(enum_cls)(e.name), key=inner_key) if st.session_state.get(inner_key): selected.add(e.name) @@ -49,7 +50,7 @@ def render(e): else: return st.multiselect( options=[e.name for e in enums], - format_func=lambda k: enum_cls[k].value, + format_func=_format_func(enum_cls), label=label, key=key, allow_none=allow_none, @@ -69,7 +70,6 @@ def enum_selector( except AttributeError: deprecated = set() enums = [e for e in enum_cls if not e in deprecated] - label = label or enum_cls.__name__ options = [e.name for e in enums] if exclude: options = [o for o in options if o not in exclude] @@ -82,8 +82,19 @@ def enum_selector( return widget( **kwargs, options=options, - format_func=lambda k: getattr(enum_cls[k], "label", enum_cls[k].value) - if k - else "โ€”โ€”โ€”", + format_func=_format_func(enum_cls), label=label, ) + + +def _format_func(enum_cls: E) -> typing.Callable[[str], str]: + def _format(k): + if not k: + return "โ€”โ€”โ€”" + e = enum_cls[k] + try: + return e.label + except AttributeError: + return e.value + + return _format diff --git a/daras_ai_v2/exceptions.py b/daras_ai_v2/exceptions.py new file mode 100644 index 000000000..5d58145a4 --- /dev/null +++ b/daras_ai_v2/exceptions.py @@ -0,0 +1,43 @@ +from logging import getLogger + +import requests +from requests import HTTPError +from requests.exceptions import JSONDecodeError + + +logger = getLogger(__name__) + + +def raise_for_status(resp: requests.Response): + """Raises :class:`HTTPError`, if 
one occurred.""" + + http_error_msg = "" + if isinstance(resp.reason, bytes): + # We attempt to decode utf-8 first because some servers + # choose to localize their reason strings. If the string + # isn't utf-8, we fall back to iso-8859-1 for all other + # encodings. (See PR #3538) + try: + reason = resp.reason.decode("utf-8") + except UnicodeDecodeError: + reason = resp.reason.decode("iso-8859-1") + else: + reason = resp.reason + + try: + response_body = str(resp.json()) + except JSONDecodeError: + try: + response_body = resp.text + except ValueError: + response_body = resp.content + response_body = response_body[:500] # truncate to at max 500 characters + + if 400 <= resp.status_code < 500: + http_error_msg = f"{resp.status_code} Client Error: {reason} | URL: {resp.url} | Response: {response_body!r}" + + elif 500 <= resp.status_code < 600: + http_error_msg = f"{resp.status_code} Server Error: {reason} | URL: {resp.url} | Response: {response_body!r}" + + if http_error_msg: + raise HTTPError(http_error_msg, response=resp) diff --git a/daras_ai_v2/facebook_bots.py b/daras_ai_v2/facebook_bots.py index c2a785867..abf33b8a0 100644 --- a/daras_ai_v2/facebook_bots.py +++ b/daras_ai_v2/facebook_bots.py @@ -1,11 +1,13 @@ import requests +from furl import furl from bots.models import BotIntegration, Platform, Conversation from daras_ai.image_input import upload_file_from_bytes, get_mimetype_from_response from daras_ai_v2 import settings from daras_ai_v2.asr import run_google_translate, audio_bytes_to_wav +from daras_ai_v2.bots import BotInterface, ReplyButton +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.text_splitter import text_splitter -from daras_ai_v2.bots import BotInterface WA_MSG_MAX_SIZE = 1024 @@ -15,6 +17,8 @@ class WhatsappBot(BotInterface): + platform = Platform.WHATSAPP + def __init__(self, message: dict, metadata: dict): self.input_message = message self.platform = Platform.WHATSAPP @@ -24,12 +28,6 @@ def __init__(self, message: dict, 
metadata: dict): self.input_type = message["type"] - # if the message has a caption, treat it as text - caption = self._get_caption() - if caption: - self.input_type = "text" - self.input_message["text"] = {"body": caption} - bi = BotIntegration.objects.get(wa_phone_number_id=self.bot_id) self.convo = Conversation.objects.get_or_create( bot_integration=bi, @@ -41,10 +39,11 @@ def get_input_text(self) -> str | None: try: return self.input_message["text"]["body"] except KeyError: - return None - - def _get_caption(self): - return self.input_message.get(self.input_type, {}).get("caption") + pass + try: + return self.input_message[self.input_type]["caption"] + except KeyError: + pass def get_input_audio(self) -> str | None: try: @@ -54,9 +53,6 @@ def get_input_audio(self) -> str | None: media_id = self.input_message["video"]["id"] except KeyError: return None - return self._download_wa_media(media_id) - - def _download_wa_media(self, media_id: str) -> str: # download file from whatsapp data, mime_type = retrieve_wa_media_by_id(media_id) data, _ = audio_bytes_to_wav(data) @@ -68,6 +64,30 @@ def _download_wa_media(self, media_id: str) -> str: content_type=mime_type, ) + def get_input_images(self) -> list[str] | None: + try: + media_id = self.input_message["image"]["id"] + except KeyError: + return None + return [self._download_wa_media(media_id)] + + def get_input_documents(self) -> list[str] | None: + try: + media_id = self.input_message["document"]["id"] + except KeyError: + return None + return [self._download_wa_media(media_id)] + + def _download_wa_media(self, media_id: str) -> str: + # download file from whatsapp + data, mime_type = retrieve_wa_media_by_id(media_id) + # upload file to firebase + return upload_file_from_bytes( + filename=self.nice_filename(mime_type), + data=data, + content_type=mime_type, + ) + def get_interactive_msg_info(self) -> tuple[str, str]: button_id = self.input_message["interactive"]["button_reply"]["id"] context_msg_id = 
self.input_message["context"]["id"] @@ -79,152 +99,147 @@ def send_msg( text: str = None, audio: str = None, video: str = None, - buttons: list = None, + buttons: list[ReplyButton] = None, + documents: list[str] = None, should_translate: bool = False, + update_msg_id: str = None, ) -> str | None: - if should_translate and self.language and self.language != "en": + if text and should_translate and self.language and self.language != "en": text = run_google_translate( [text], self.language, glossary_url=self.output_glossary )[0] - return send_wa_msg( + return self.send_msg_to( bot_number=self.bot_id, user_number=self.user_id, - response_text=text, - response_audio=audio, - response_video=video, + text=text, + audio=audio, + video=video, + documents=documents, buttons=buttons, ) def mark_read(self): wa_mark_read(self.bot_id, self.input_message["id"]) + @classmethod + def send_msg_to( + cls, + *, + text: str = None, + audio: str = None, + video: str = None, + documents: list[str] = None, + buttons: list[ReplyButton] = None, + ## whatsapp specific + bot_number: str, + user_number: str, + ) -> str | None: + # see https://developers.facebook.com/docs/whatsapp/api/messages/media/ -def send_wa_msg( - *, - bot_number: str, - user_number: str, - response_text: str, - response_audio: str = None, - response_video: str = None, - buttons: list = None, -) -> str | None: - # split text into chunks if too long - if len(response_text) > WA_MSG_MAX_SIZE: - splits = text_splitter( - response_text, chunk_size=WA_MSG_MAX_SIZE, length_function=len - ) - # preserve last chunk for later - response_text = splits[-1].text - # send all but last chunk - send_wa_msgs_raw( - bot_number=bot_number, - user_number=user_number, - messages=[ - # simple text msg - { - "type": "text", - "text": { - "body": doc.text, - "preview_url": True, - }, - } - for doc in splits[:-1] - ], - ) - - if response_video: - if buttons: - messages = [ - # interactive text msg + video in header - { - "body": { - "text": 
response_text, - }, - "header": { - "type": "video", - "video": {"link": response_video}, - }, - }, - ] - else: - messages = [ - # simple video msg + text caption - { - "type": "video", - "video": { - "link": response_video, - "caption": response_text, - }, - }, - ] - elif response_audio: - if buttons: - # audio can't be sent as an interaction, so send text and audio separately - messages = [ - # simple audio msg - { - "type": "audio", - "audio": {"link": response_audio}, - }, - ] + # split text into chunks if too long + if text and len(text) > WA_MSG_MAX_SIZE: + splits = text_splitter( + text, chunk_size=WA_MSG_MAX_SIZE, length_function=len + ) + # preserve last chunk for later + text = splits[-1].text + # send all but last chunk send_wa_msgs_raw( bot_number=bot_number, user_number=user_number, - messages=messages, + messages=[ + # simple text msg + { + "type": "text", + "text": { + "body": doc.text, + "preview_url": True, + }, + } + for doc in splits[:-1] + ], ) + + messages = [] + if video: + if buttons: + messages = [ + # interactive text msg + video in header + _build_msg_buttons( + buttons, + { + "body": { + "text": text or "\u200b", + }, + "header": { + "type": "video", + "video": {"link": video}, + }, + }, + ), + ] + else: + messages = [ + # simple video msg + text caption + { + "type": "video", + "video": { + "link": video, + "caption": text, + }, + }, + ] + elif buttons: + # interactive text msg messages = [ - # interactive text msg - { - "body": { - "text": response_text, + _build_msg_buttons( + buttons, + { + "body": { + "text": text or "\u200b", + } }, - }, + ), ] - else: - # audio doesn't support captions, so send text and audio separately + elif text: + # simple text msg messages = [ - # simple text msg { "type": "text", "text": { - "body": response_text, + "body": text, "preview_url": True, }, }, - # simple audio msg - { - "type": "audio", - "audio": {"link": response_audio}, - }, ] - else: - # text message - if buttons: - messages = [ - # 
interactive text msg + + if audio and not video: # video already has audio + # simple audio msg + messages.append( { - "body": { - "text": response_text, - } - }, - ] - else: - messages = [ - # simple text msg + "type": "audio", + "audio": {"link": audio}, + } + ) + + if documents: + messages += [ + # simple document msg { - "type": "text", - "text": { - "body": response_text, - "preview_url": True, + "type": "document", + "document": { + "link": link, + "filename": furl(link).path.segments[-1], }, - }, + } + for link in documents ] - return send_wa_msgs_raw( - bot_number=bot_number, - user_number=user_number, - messages=messages, - buttons=buttons, - ) + + return send_wa_msgs_raw( + bot_number=bot_number, + user_number=user_number, + messages=messages, + ) def retrieve_wa_media_by_id(media_id: str) -> (bytes, str): @@ -233,48 +248,55 @@ def retrieve_wa_media_by_id(media_id: str) -> (bytes, str): f"https://graph.facebook.com/v16.0/{media_id}/", headers=WHATSAPP_AUTH_HEADER, ) - r1.raise_for_status() + raise_for_status(r1) media_info = r1.json() # download media r2 = requests.get( media_info["url"], headers=WHATSAPP_AUTH_HEADER, ) - r2.raise_for_status() + raise_for_status(r2) content = r2.content # return content and mime type return content, media_info["mime_type"] -def send_wa_msgs_raw( - *, bot_number, user_number, messages: list, buttons: list = None -) -> str | None: +def _build_msg_buttons(buttons: list[ReplyButton], msg: dict) -> dict: + return { + "type": "interactive", + "interactive": { + "type": "button", + **msg, + "action": { + "buttons": [ + { + "type": "reply", + "reply": {"id": button["id"], "title": button["title"]}, + } + for button in buttons + ], + }, + }, + } + + +def send_wa_msgs_raw(*, bot_number, user_number, messages: list) -> str | None: msg_id = None for msg in messages: - body = { - "messaging_product": "whatsapp", - "to": user_number, - "preview_url": True, - } - if buttons: - body |= { - "type": "interactive", - "interactive": { - 
"type": "button", - **msg, - "action": {"buttons": buttons}, - }, - } - else: - body |= msg + print(f"send_wa_msgs_raw: {msg=}") r = requests.post( f"https://graph.facebook.com/v16.0/{bot_number}/messages", headers=WHATSAPP_AUTH_HEADER, - json=body, + json={ + "messaging_product": "whatsapp", + "to": user_number, + "preview_url": True, + **msg, + }, ) confirmation = r.json() print("send_wa_msgs_raw:", r.status_code, confirmation) - r.raise_for_status() + raise_for_status(r) try: msg_id = confirmation["messages"][0]["id"] except (KeyError, IndexError): @@ -338,13 +360,16 @@ def send_msg( text: str = None, audio: str = None, video: str = None, - buttons: list = None, + buttons: list[ReplyButton] = None, + documents: list[str] = None, should_translate: bool = False, + update_msg_id: str = None, ) -> str | None: - if should_translate and self.language and self.language != "en": + if text and should_translate and self.language and self.language != "en": text = run_google_translate( [text], self.language, glossary_url=self.output_glossary )[0] + text = text or "\u200b" # handle empty text with zero-width space return send_fb_msg( access_token=self._access_token, bot_id=self.bot_id, @@ -379,7 +404,7 @@ def get_input_audio(self) -> str | None: return None # downlad file from facebook r = requests.get(url) - r.raise_for_status() + raise_for_status(r) # ensure file is audio/video mime_type = get_mimetype_from_response(r) assert ( diff --git a/daras_ai_v2/field_render.py b/daras_ai_v2/field_render.py new file mode 100644 index 000000000..1b79c1673 --- /dev/null +++ b/daras_ai_v2/field_render.py @@ -0,0 +1,13 @@ +import typing + +from pydantic import BaseModel + + +def field_title_desc(model: typing.Type[BaseModel], name: str) -> str: + field = model.__fields__[name] + return "\n".join( + filter( + None, + [field.field_info.title, field.field_info.description or ""], + ) + ) diff --git a/daras_ai_v2/functional.py b/daras_ai_v2/functional.py index 625e9c026..f8415ec87 100644 
--- a/daras_ai_v2/functional.py +++ b/daras_ai_v2/functional.py @@ -8,8 +8,8 @@ def flatapply_parallel( - fn: typing.Callable[[T], list[R]], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., list[R]], + *iterables, max_workers: int = None, message: str = "", ) -> typing.Generator[str, None, list[R]]: @@ -20,8 +20,8 @@ def flatapply_parallel( def apply_parallel( - fn: typing.Callable[[T], R], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., R], + *iterables, max_workers: int = None, message: str = "", ) -> typing.Generator[str, None, list[R]]: @@ -42,8 +42,8 @@ def apply_parallel( def fetch_parallel( - fn: typing.Callable[[T], R], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., R], + *iterables, max_workers: int = None, ) -> typing.Generator[R, None, None]: assert iterables, "fetch_parallel() requires at least one iterable" @@ -57,16 +57,16 @@ def fetch_parallel( def flatmap_parallel( - fn: typing.Callable[[T], list[R]], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., list[R]], + *iterables, max_workers: int = None, ) -> list[R]: return flatten(map_parallel(fn, *iterables, max_workers=max_workers)) def map_parallel( - fn: typing.Callable[[T], R], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., R], + *iterables, max_workers: int = None, ) -> list[R]: assert iterables, "map_parallel() requires at least one iterable" diff --git a/daras_ai_v2/functions.py b/daras_ai_v2/functions.py new file mode 100644 index 000000000..2c6c03348 --- /dev/null +++ b/daras_ai_v2/functions.py @@ -0,0 +1,70 @@ +import json +import tempfile +import typing +from enum import Enum + +from daras_ai.image_input import upload_file_from_bytes +from daras_ai_v2.settings import templates + + +def json_to_pdf(filename: str, data: str) -> str: + html = templates.get_template("form_output.html").render(data=json.loads(data)) + pdf_bytes = html_to_pdf(html) + if not filename.endswith(".pdf"): + filename += ".pdf" + return 
upload_file_from_bytes(filename, pdf_bytes, "application/pdf") + + +def html_to_pdf(html: str) -> bytes: + from playwright.sync_api import sync_playwright + + with sync_playwright() as p: + browser = p.chromium.launch() + page = browser.new_page() + page.set_content(html) + with tempfile.NamedTemporaryFile(suffix=".pdf") as outfile: + page.pdf(path=outfile.name, format="A4") + ret = outfile.read() + browser.close() + + return ret + + +class LLMTools(Enum): + json_to_pdf = ( + json_to_pdf, + "Save JSON as PDF", + { + "type": "function", + "function": { + "name": json_to_pdf.__name__, + "description": "Save JSON data to PDF", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "A short but descriptive filename for the PDF", + }, + "data": { + "type": "string", + "description": "The JSON data to write to the PDF", + }, + }, + "required": ["filename", "data"], + }, + }, + }, + ) + # send_reply_buttons = (print, "Send back reply buttons to the user.", {}) + + def __new__(cls, fn: typing.Callable, label: str, spec: dict): + obj = object.__new__(cls) + obj._value_ = fn.__name__ + obj.fn = fn + obj.label = label + obj.spec = spec + return obj + + # def __init__(self, *args, **kwargs): + # self._value_ = self.name diff --git a/daras_ai_v2/glossary.py b/daras_ai_v2/glossary.py index 98e135b5d..2e56352da 100644 --- a/daras_ai_v2/glossary.py +++ b/daras_ai_v2/glossary.py @@ -3,16 +3,46 @@ from daras_ai_v2.doc_search_settings_widgets import document_uploader +def validate_glossary_document(document: str): + """ + Throws AssertionError for the most common errors in a glossary document. + I.e. 
the glossary must have at least 2 columns, top row must be language codes or "description" or "pos" + """ + import langcodes + from daras_ai_v2.vector_search import ( + download_content_bytes, + bytes_to_df, + doc_url_to_metadata, + ) + + metadata = doc_url_to_metadata(document) + f_bytes, ext = download_content_bytes(f_url=document, mime_type=metadata.name) + df = bytes_to_df(f_name=metadata.name, f_bytes=f_bytes, ext=ext) + + if len(df.columns) < 2: + raise AssertionError( + f"Invalid glossary: must have at least 2 columns, but has {len(df.columns)}." + ) + for col in df.columns: + if col not in ["description", "pos"]: + try: + langcodes.Language.get(col).language + except langcodes.LanguageTagError: + raise AssertionError( + f'Invalid glossary: column header "{col}" is not a valid language code.' + ) + + def glossary_input( label: str = "##### Glossary", key: str = "glossary_document", -): +) -> str: return document_uploader( label=label, key=key, accept=[".csv", ".xlsx", ".xls", ".gsheet", ".ods", ".tsv"], accept_multiple_files=False, - ) + ) # type: ignore def create_glossary( diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index d4c47a139..fa638da89 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -8,6 +8,7 @@ from daras_ai.image_input import storage_blob_for from daras_ai_v2 import settings +from daras_ai_v2.exceptions import raise_for_status class GpuEndpoints: @@ -32,7 +33,7 @@ def call_gpu_server(*, endpoint: str, input_data: dict) -> typing.Any: f"{endpoint}/predictions", json={"input": input_data}, ) - r.raise_for_status() + raise_for_status(r) return r.json()["output"] @@ -97,7 +98,7 @@ def call_gooey_gpu( str(endpoint), json={"pipeline": pipeline, "inputs": inputs}, ) - r.raise_for_status() + raise_for_status(r) return [blob.public_url for blob in blobs] diff --git a/daras_ai_v2/grid_layout_widget.py b/daras_ai_v2/grid_layout_widget.py index 30e5a3ca5..4e41a38f5 100644 --- a/daras_ai_v2/grid_layout_widget.py 
+++ b/daras_ai_v2/grid_layout_widget.py @@ -3,10 +3,21 @@ import gooey_ui as st -def grid_layout(column_spec, iterable: typing.Iterable, render, separator=True): +def grid_layout( + column_spec, + iterable: typing.Iterable, + render, + separator=True, + column_props: dict[str, typing.Any] | None = None, +): + # make a copy so it can be modified + column_props = dict(column_props or {}) + extra_classes = column_props.pop("className", "mb-4 pb-2") + for item, col in zip(iterable, infinte_cols(column_spec)): if separator: - col.node.props["className"] += " border-bottom mb-4 pb-2" + col.node.props["className"] += f" border-bottom " + extra_classes + col.node.props.update(column_props) with col: render(item) diff --git a/daras_ai_v2/image_segmentation.py b/daras_ai_v2/image_segmentation.py index 30a2055bc..099832979 100644 --- a/daras_ai_v2/image_segmentation.py +++ b/daras_ai_v2/image_segmentation.py @@ -2,6 +2,7 @@ import requests +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.gpu_server import ( call_celery_task_outfile, ) @@ -21,7 +22,7 @@ def u2net(input_image: str) -> bytes: filename="u2net.png", )[0] r = requests.get(url) - r.raise_for_status() + raise_for_status(r) return r.content @@ -34,5 +35,5 @@ def dis(input_image: str) -> bytes: filename="dis.png", )[0] r = requests.get(url) - r.raise_for_status() + raise_for_status(r) return r.content diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py index fea120051..e86c18873 100644 --- a/daras_ai_v2/img_model_settings_widgets.py +++ b/daras_ai_v2/img_model_settings_widgets.py @@ -181,7 +181,7 @@ def controlnet_weight_setting( ) -def num_outputs_setting(selected_model: str = None): +def num_outputs_setting(selected_models: str | list[str] = None): col1, col2 = st.columns(2, gap="medium") with col1: st.slider( @@ -200,12 +200,41 @@ def num_outputs_setting(selected_model: str = None): """ ) with col2: - quality_setting(selected_model) + 
quality_setting(selected_models) -def quality_setting(selected_model=None): - if selected_model in [InpaintingModels.dall_e.name]: +def quality_setting(selected_models=None): + if not isinstance(selected_models, list): + selected_models = [selected_models] + if any( + [ + selected_model in [InpaintingModels.dall_e.name] + for selected_model in selected_models + ] + ): return + if any( + [ + selected_model in [Text2ImgModels.dall_e_3.name] + for selected_model in selected_models + ] + ): + st.selectbox( + """##### Dalle 3 Quality""", + options=[ + "standard", + "hd", + ], + key="dall_e_3_quality", + ) + st.selectbox( + """##### Dalle 3 Style""", + options=[ + "natural", + "vivid", + ], + key="dall_e_3_style", + ) st.slider( label=""" ##### Quality @@ -223,7 +252,10 @@ def quality_setting(selected_model=None): ) -RESOLUTIONS = { +RESOLUTIONS: dict[int, dict[str, str]] = { + 256: { + "256, 256": "square", + }, 512: { "512, 512": "square", "576, 448": "A4", @@ -247,6 +279,7 @@ def quality_setting(selected_model=None): "1536, 512": "smartphone", "1792, 512": "cinema", "2048, 512": "panorama", + "1792, 1024": "wide", }, } LANDSCAPE = "Landscape" @@ -283,10 +316,18 @@ def output_resolution_setting(): ) if not isinstance(selected_models, list): selected_models = [selected_models] + + allowed_shapes = None if "jack_qiao" in selected_models or "sd_1_4" in selected_models: pixel_options = [512] elif selected_models == ["deepfloyd_if"]: pixel_options = [1024] + elif selected_models == ["dall_e"]: + pixel_options = [256, 512, 1024] + allowed_shapes = ["square"] + elif selected_models == ["dall_e_3"]: + pixel_options = [1024] + allowed_shapes = ["square", "wide"] else: pixel_options = [512, 768] @@ -298,11 +339,16 @@ def output_resolution_setting(): options=pixel_options, ) with col2: + res_options = [ + res + for res, shape in RESOLUTIONS[pixels or pixel_options[0]].items() + if not allowed_shapes or shape in allowed_shapes + ] res = st.selectbox( "##### Resolution", 
key="__res", format_func=lambda r: f"{r.split(', ')[0]} x {r.split(', ')[1]} ({RESOLUTIONS[pixels][r]})", - options=list(RESOLUTIONS[pixels].keys()), + options=res_options, ) res = tuple(map(int, res.split(", "))) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 251b3f924..d49ac61ba 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -1,6 +1,8 @@ import hashlib import io +import json import re +import typing from enum import Enum from functools import partial @@ -16,13 +18,21 @@ from django.conf import settings from jinja2.lexer import whitespace_re from loguru import logger +from openai import Stream +from openai.types.chat import ( + ChatCompletionContentPartParam, + ChatCompletionChunk, +) -from daras_ai_v2.google_utils import get_google_auth_session +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.functional import map_parallel -from daras_ai_v2.redis_cache import ( - get_redis_cache, +from daras_ai_v2.functions import LLMTools +from daras_ai_v2.google_utils import get_google_auth_session +from daras_ai_v2.redis_cache import get_redis_cache +from daras_ai_v2.text_splitter import ( + default_length_function, + default_separators, ) -from daras_ai_v2.text_splitter import default_length_function DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible." 
@@ -33,6 +43,9 @@ CHATML_ROLE_ASSISTANT = "assistant" CHATML_ROLE_USER = "user" +# nice for showing streaming progress +SUPERSCRIPT = str.maketrans("0123456789", "โฐยนยฒยณโดโตโถโทโธโน") + class LLMApis(Enum): vertex_ai = "Vertex AI" @@ -41,6 +54,7 @@ class LLMApis(Enum): class LargeLanguageModels(Enum): + gpt_4_vision = "GPT-4 Vision (openai)" gpt_4_turbo = "GPT-4 Turbo (openai)" gpt_4 = "GPT-4 (openai)" gpt_4_32k = "GPT-4 32K (openai)" @@ -49,8 +63,8 @@ class LargeLanguageModels(Enum): llama2_70b_chat = "Llama 2 (Meta AI)" - palm2_chat = "PaLM 2 Text (Google)" - palm2_text = "PaLM 2 Chat (Google)" + palm2_chat = "PaLM 2 Chat (Google)" + palm2_text = "PaLM 2 Text (Google)" text_davinci_003 = "GPT-3.5 Davinci-3 (openai)" text_davinci_002 = "GPT-3.5 Davinci-2 (openai)" @@ -64,6 +78,11 @@ class LargeLanguageModels(Enum): def _deprecated(cls): return {cls.code_davinci_002} + def is_vision_model(self) -> bool: + return self in { + self.gpt_4_vision, + } + def is_chat_model(self) -> bool: return self not in { self.palm2_text, @@ -79,6 +98,7 @@ def is_chat_model(self) -> bool: AZURE_OPENAI_MODEL_PREFIX = "openai-" llm_model_names = { + LargeLanguageModels.gpt_4_vision: "gpt-4-vision-preview", LargeLanguageModels.gpt_4_turbo: ( "openai-gpt-4-turbo-prod-ca-1", "gpt-4-1106-preview", @@ -108,6 +128,7 @@ def is_chat_model(self) -> bool: } llm_api = { + LargeLanguageModels.gpt_4_vision: LLMApis.openai, LargeLanguageModels.gpt_4_turbo: LLMApis.openai, LargeLanguageModels.gpt_4: LLMApis.openai, LargeLanguageModels.gpt_4_32k: LLMApis.openai, @@ -127,6 +148,8 @@ def is_chat_model(self) -> bool: EMBEDDING_MODEL_MAX_TOKENS = 8191 model_max_tokens = { + # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo + LargeLanguageModels.gpt_4_vision: 128_000, # https://help.openai.com/en/articles/8555510-gpt-4-turbo LargeLanguageModels.gpt_4_turbo: 128_000, # https://platform.openai.com/docs/models/gpt-4 @@ -150,6 +173,7 @@ def is_chat_model(self) -> bool: } llm_price = { + 
LargeLanguageModels.gpt_4_vision: 6, LargeLanguageModels.gpt_4_turbo: 5, LargeLanguageModels.gpt_4: 10, LargeLanguageModels.gpt_4_32k: 20, @@ -274,23 +298,48 @@ def _run_openai_embedding( class ConversationEntry(typing_extensions.TypedDict): - role: str + role: typing.Literal["user", "system", "assistant"] + content: str | list[ChatCompletionContentPartParam] display_name: typing_extensions.NotRequired[str] - content: str + + +def get_entry_images(entry: ConversationEntry) -> list[str]: + contents = entry.get("content") or "" + if isinstance(contents, str): + return [] + return list( + filter(None, (part.get("image_url", {}).get("url") for part in contents)), + ) + + +def get_entry_text(entry: ConversationEntry) -> str: + contents = entry.get("content") or "" + if isinstance(contents, str): + return contents + return "\n".join( + filter(None, (part.get("text") for part in contents)), + ) def run_language_model( *, model: str, - prompt: str | None = None, - messages: list[ConversationEntry] | None = None, - max_tokens: int = 512, # Default value version 1.0 - quality: float = 1.0, # Default value version 1.0 - num_outputs: int = 1, # Default value version 1.0 - temperature: float = 0.7, # Default value version 1.0 - stop: list[str] | None = None, + prompt: str = None, + messages: list[ConversationEntry] = None, + max_tokens: int = 512, + quality: float = 1.0, + num_outputs: int = 1, + temperature: float = 0.7, + stop: list[str] = None, avoid_repetition: bool = False, -) -> list[str]: + tools: list[LLMTools] = None, + stream: bool = False, + response_format_type: typing.Literal["text", "json_object"] = None, +) -> ( + list[str] + | tuple[list[str], list[list[dict]]] + | typing.Generator[list[dict], None, None] +): assert bool(prompt) != bool( messages ), "Pleave provide exactly one of { prompt, messages }" @@ -298,15 +347,19 @@ def run_language_model( model: LargeLanguageModels = LargeLanguageModels[str(model)] api = llm_api[model] model_name = llm_model_names[model] 
+ is_chatml = False if model.is_chat_model(): - if messages: - is_chatml = False - else: + if not messages: # if input is chatml, convert it into json messages is_chatml, messages = parse_chatml(prompt) # type: ignore messages = messages or [] logger.info(f"{model_name=}, {len(messages)=}, {max_tokens=}, {temperature=}") - result = _run_chat_model( + if not model.is_vision_model(): + messages = [ + format_chat_entry(role=entry["role"], content=get_entry_text(entry)) + for entry in messages + ] + entries = _run_chat_model( api=api, model=model_name, messages=messages, # type: ignore @@ -315,17 +368,20 @@ def run_language_model( temperature=temperature, stop=stop, avoid_repetition=avoid_repetition, + tools=tools, + response_format_type=response_format_type, + # we can't stream with tools or json yet + stream=stream and not (tools or response_format_type), ) - return [ - # return messages back as either chatml or json messages - format_chatml_message(entry) - if is_chatml - else (entry.get("content") or "").strip() - for entry in result - ] + if stream: + return _stream_llm_outputs(entries, response_format_type) + else: + return _parse_entries(entries, is_chatml, response_format_type, tools) else: + if tools: + raise ValueError("Only OpenAI chat models support Tools") logger.info(f"{model_name=}, {len(prompt)=}, {max_tokens=}, {temperature=}") - result = _run_text_model( + msgs = _run_text_model( api=api, model=model_name, prompt=prompt, @@ -336,7 +392,54 @@ def run_language_model( avoid_repetition=avoid_repetition, quality=quality, ) - return [msg.strip() for msg in result] + ret = [msg.strip() for msg in msgs] + if stream: + ret = [ + [ + format_chat_entry(role=CHATML_ROLE_ASSISTANT, content=msg) + for msg in ret + ] + ] + return ret + + +def _stream_llm_outputs( + result: list | typing.Generator[list[ConversationEntry], None, None], + response_format_type: typing.Literal["text", "json_object"] | None, +): + if isinstance(result, list): # compatibility with 
non-streaming apis + result = [result] + for entries in result: + if response_format_type == "json_object": + for i, entry in enumerate(entries): + entries[i] = json.loads(entry["content"]) + for i, entry in enumerate(entries): + entries[i]["content"] = entry.get("content") or "" + yield entries + + +def _parse_entries( + entries: list[dict], + is_chatml: bool, + response_format_type: typing.Literal["text", "json_object"] | None, + tools: list[dict] | None, +): + if response_format_type == "json_object": + ret = [json.loads(entry["content"]) for entry in entries] + else: + ret = [ + # return messages back as either chatml or json messages + ( + format_chatml_message(entry) + if is_chatml + else (entry.get("content") or "").strip() + ) + for entry in entries + ] + if tools: + return ret, [(entry.get("tool_calls") or []) for entry in entries] + else: + return ret def _run_text_model( @@ -385,7 +488,10 @@ def _run_chat_model( model: str | tuple, stop: list[str] | None, avoid_repetition: bool, -) -> list[ConversationEntry]: + tools: list[LLMTools] | None, + response_format_type: typing.Literal["text", "json_object"] | None, + stream: bool = False, +) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]: match api: case LLMApis.openai: return _run_openai_chat( @@ -396,8 +502,13 @@ def _run_chat_model( num_outputs=num_outputs, stop=stop, temperature=temperature, + tools=tools, + response_format_type=response_format_type, + stream=stream, ) case LLMApis.vertex_ai: + if tools: + raise ValueError("Only OpenAI chat models support Tools") return _run_palm_chat( model_id=model, messages=messages, @@ -406,6 +517,8 @@ def _run_chat_model( temperature=temperature, ) case LLMApis.together: + if tools: + raise ValueError("Only OpenAI chat models support Tools") return _run_together_chat( model=model, messages=messages, @@ -428,7 +541,12 @@ def _run_openai_chat( temperature: float, stop: list[str] | None, avoid_repetition: bool, -) -> 
list[ConversationEntry]: + tools: list[LLMTools] | None, + response_format_type: typing.Literal["text", "json_object"] | None, + stream: bool = False, +) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]: + from openai._types import NOT_GIVEN + if avoid_repetition: frequency_penalty = 0.1 presence_penalty = 0.25 @@ -444,16 +562,85 @@ def _run_openai_chat( model=model_str, messages=messages, max_tokens=max_tokens, - stop=stop, + stop=stop or NOT_GIVEN, n=num_outputs, temperature=temperature, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + tools=[tool.spec for tool in tools] if tools else NOT_GIVEN, + response_format={"type": response_format_type} + if response_format_type + else NOT_GIVEN, + stream=stream, ) for model_str in model ], ) - return [choice.message.dict() for choice in r.choices] + if stream: + return _stream_openai_chunked(r) + else: + return [choice.message.dict() for choice in r.choices] + + +def _stream_openai_chunked( + r: Stream[ChatCompletionChunk], + start_chunk_size: int = 50, + stop_chunk_size: int = 400, + step_chunk_size: int = 150, +) -> typing.Generator[list[ConversationEntry], None, None]: + ret = [] + chunk_size = start_chunk_size + + for completion_chunk in r: + changed = False + for choice in completion_chunk.choices: + delta = choice.delta + try: + # get the entry for this choice + entry = ret[choice.index] + except IndexError: + # initialize the entry + entry = delta.dict() | {"content": "", "chunk": ""} + ret.append(entry) + # this is to mark the end of streaming + entry["finish_reason"] = choice.finish_reason + + # append the delta to the current chunk + if not delta.content: + continue + entry["chunk"] += delta.content + # if the chunk is too small, we need to wait for more data + chunk = entry["chunk"] + if len(chunk) < chunk_size: + continue + + # iterate through the separators and find the best one that matches + for sep in default_separators[:-1]: + # find the last 
occurrence of the separator + match = None + for match in re.finditer(sep, chunk): + pass + if not match: + continue # no match, try the next separator or wait for more data + # append text before the separator to the content + part = chunk[: match.end()] + if len(part) < chunk_size: + continue # not enough text, try the next separator or wait for more data + entry["content"] += part + # set text after the separator as the next chunk + entry["chunk"] = chunk[match.end() :] + # increase the chunk size, but don't go over the max + chunk_size = min(chunk_size + step_chunk_size, stop_chunk_size) + # we found a separator, so we can stop looking and yield the partial result + changed = True + break + if changed: + yield ret + + # add the leftover chunks + for entry in ret: + entry["content"] += entry["chunk"] + yield ret @retry_if(openai_should_retry) @@ -536,7 +723,7 @@ def _run_together_chat( ) ret = [] for r in results: - r.raise_for_status() + raise_for_status(r) data = r.json() output = data["output"] error = output.get("error") @@ -597,7 +784,7 @@ def _run_palm_chat( }, }, ) - r.raise_for_status() + raise_for_status(r) return [ { @@ -642,13 +829,13 @@ def _run_palm_text( }, }, ) - res.raise_for_status() + raise_for_status(res) return [prediction["content"] for prediction in res.json()["predictions"]] def format_chatml_message(entry: ConversationEntry) -> str: msg = CHATML_START_TOKEN + entry.get("role", "") - content = entry.get("content").strip() + content = get_entry_text(entry).strip() if content: msg += "\n" + content + CHATML_END_TOKEN return msg @@ -740,3 +927,15 @@ def build_llama_prompt(messages: list[ConversationEntry]): ret += f"{B_INST} {messages[-1].get('content').strip()} {E_INST}" return ret + + +def format_chat_entry( + *, role: str, content: str, images: list[str] = None +) -> ConversationEntry: + if images: + content = [ + {"type": "image_url", "image_url": {"url": url}} for url in images + ] + [ + {"type": "text", "text": content}, + ] + return 
{"role": role, "content": content} diff --git a/daras_ai_v2/language_model_settings_widgets.py b/daras_ai_v2/language_model_settings_widgets.py index 4083ece31..e5ab27a59 100644 --- a/daras_ai_v2/language_model_settings_widgets.py +++ b/daras_ai_v2/language_model_settings_widgets.py @@ -1,10 +1,14 @@ import gooey_ui as st +from daras_ai_v2.azure_doc_extract import azure_form_recognizer_models from daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.field_render import field_title_desc from daras_ai_v2.language_model import LargeLanguageModels -def language_model_settings(show_selector=True): +def language_model_settings(show_selector=True, show_document_model=False): + from recipes.VideoBots import VideoBotsPage + st.write("##### ๐Ÿ”  Language Model Settings") if show_selector: @@ -14,6 +18,16 @@ def language_model_settings(show_selector=True): key="selected_model", use_selectbox=True, ) + if show_document_model: + doc_model_descriptions = azure_form_recognizer_models() + st.selectbox( + f"###### {field_title_desc(VideoBotsPage.RequestModel, 'document_model')}", + key="document_model", + options=[None, *doc_model_descriptions], + format_func=lambda x: f"{doc_model_descriptions[x]} ({x})" + if x + else "โ€”โ€”โ€”", + ) st.checkbox("Avoid Repetition", key="avoid_repetition") diff --git a/daras_ai_v2/manage_api_keys_widget.py b/daras_ai_v2/manage_api_keys_widget.py index 1c7dc71e8..8bfc9d616 100644 --- a/daras_ai_v2/manage_api_keys_widget.py +++ b/daras_ai_v2/manage_api_keys_widget.py @@ -1,9 +1,6 @@ import datetime - import gooey_ui as st -from firebase_admin import auth - from app_users.models import AppUser from daras_ai_v2 import db from daras_ai_v2.copy_to_clipboard_button_widget import ( @@ -19,12 +16,12 @@ def manage_api_keys(user: AppUser): st.write( """ -Your secret API keys are listed below. +Your secret API keys are listed below. Please note that we do not display your secret API keys again after you generate them. 
-Do not share your API key with others, or expose it in the browser or other client-side code. +Do not share your API key with others, or expose it in the browser or other client-side code. -In order to protect the security of your account, +In order to protect the security of your account, Gooey.AI may also automatically rotate any API key that we've found has leaked publicly. """ ) @@ -74,10 +71,10 @@ def _generate_new_key_doc() -> dict: st.success( f""" -
API key generated
+##### API key generated -Please save this secret key somewhere safe and accessible. -For security reasons, **you won't be able to view it again** through your account. +Please save this secret key somewhere safe and accessible. +For security reasons, **you won't be able to view it again** through your account. If you lose this secret key, you'll need to generate a new one. """ ) diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index b799f5bfd..c7ccc2e18 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -1,9 +1,14 @@ -from firebase_admin import auth +from django.utils.text import slugify +from furl import furl -from app_users.models import AppUser -from daras_ai.image_input import truncate_text_words +from bots.models import PublishedRun, SavedRun, WorkflowMetadata +from daras_ai_v2 import settings from daras_ai_v2.base import BasePage +from daras_ai_v2.breadcrumbs import get_title_breadcrumbs from daras_ai_v2.meta_preview_url import meta_preview_url +from daras_ai_v2.tabs_widget import MenuTabs + +sep = " โ€ข " def build_meta_tags( @@ -15,27 +20,43 @@ def build_meta_tags( uid: str, example_id: str, ) -> list[dict]: + sr, pr = page.get_runs_from_query_params(example_id, run_id, uid) + metadata = page.workflow.get_or_create_metadata() + title = meta_title_for_page( page=page, - state=state, - run_id=run_id, - uid=uid, - example_id=example_id, + metadata=metadata, + sr=sr, + pr=pr, + tab=page.tab, ) description = meta_description_for_page( + metadata=metadata, + pr=pr, + ) + image = meta_image_for_page( + page=page, + state=state, + sr=sr, + metadata=metadata, + pr=pr, + ) + canonical_url = canonical_url_for_page( page=page, state=state, - run_id=run_id, - uid=uid, - example_id=example_id, + sr=sr, + metadata=metadata, + pr=pr, ) - image = meta_preview_url(page.preview_image(state), page.fallback_preivew_image()) + robots = robots_tag_for_page(page=page, sr=sr, pr=pr) return raw_build_meta_tags( url=url, title=title, 
description=description, image=image, + canonical_url=canonical_url, + robots=robots, ) @@ -45,6 +66,8 @@ def raw_build_meta_tags( title: str, description: str | None = None, image: str | None = None, + canonical_url: str | None = None, + robots: str | None = None, ) -> list[dict[str, str]]: ret = [ dict(title=title), @@ -56,76 +79,185 @@ def raw_build_meta_tags( dict(property="twitter:url", content=url), dict(property="twitter:title", content=title), ] + if description: ret += [ dict(name="description", content=description), dict(property="og:description", content=description), dict(property="twitter:description", content=description), ] + if image: ret += [ dict(name="image", content=image), dict(property="og:image", content=image), dict(property="twitter:image", content=image), ] + + if canonical_url: + ret += [dict(tagName="link", rel="canonical", href=canonical_url)] + + if robots: + ret += [dict(name="robots", content=robots)] + return ret def meta_title_for_page( *, page: BasePage, - state: dict, - run_id: str, - uid: str, - example_id: str, + metadata: WorkflowMetadata, + sr: SavedRun, + pr: PublishedRun | None, + tab: str, ) -> str: - parts = [] + match tab: + case MenuTabs.examples: + ret = f"Examples: {metadata.meta_title}" + case _ if pr and pr.saved_run == sr and pr.is_root(): + # for root page + ret = metadata.meta_title + case _: + # non-root runs and examples + parts = [] + + tbreadcrumbs = get_title_breadcrumbs(page, sr, pr) + parts.append(tbreadcrumbs.h1_title) + + # use the short title for non-root examples + part = metadata.short_title + if tbreadcrumbs.published_title: + part = f"{pr.title} {part}" + # add the creator's name + user = sr.get_creator() + if user and user.display_name: + part += f" by {user.display_name}" + parts.append(part) + + ret = sep.join(parts) - prompt = truncate_text_words(page.preview_input(state) or "", maxlen=100) - title = state.get("__title") or page.title - end_suffix = f"{title} on Gooey.AI" - - if run_id and uid: 
- parts.append(prompt) - try: - user = AppUser.objects.get_or_create_from_uid(uid)[0] - except auth.UserNotFoundError: - user = None - if user and user.display_name: - parts.append(user_name_possesive(user.display_name) + " " + end_suffix) - else: - parts.append(end_suffix) - elif example_id: - # DO NOT SHOW PROMPT FOR EXAMPLES - parts.append(end_suffix) + return f"{ret} {sep} Gooey.AI" + + +def meta_description_for_page( + *, + metadata: WorkflowMetadata, + pr: PublishedRun | None, +) -> str: + if pr and not pr.is_root(): + description = pr.notes or metadata.meta_description else: - parts.append(page.title) - parts.append("AI API, workflow & prompt shared on Gooey.AI") + description = metadata.meta_description + + if not (pr and pr.is_root()) or not description: + # for all non-root examples, or when there is no other description + description += sep + "AI API, workflow & prompt shared on Gooey.AI." - return " โ€ข ".join(p for p in parts if p) + return description -def user_name_possesive(name: str) -> str: - if name.endswith("s"): - return name + "'" +def meta_image_for_page( + *, + page: BasePage, + state: dict, + metadata: WorkflowMetadata, + sr: SavedRun, + pr: PublishedRun | None, +) -> str | None: + if pr and pr.saved_run == sr and pr.is_root(): + file_url = metadata.meta_image or page.preview_image(state) else: - return name + "'s" + file_url = page.preview_image(state) + return meta_preview_url( + file_url=file_url, + fallback_img=page.fallback_preivew_image(), + ) -def meta_description_for_page( + +def canonical_url_for_page( *, page: BasePage, state: dict, - run_id: str, - uid: str, - example_id: str, + metadata: WorkflowMetadata, + sr: SavedRun, + pr: PublishedRun | None, ) -> str: - description = state.get("__notes") or page.preview_description(state) - # updated_at = state.get("updated_at") - # if updated_at: - # description = updated_at.strftime("%d-%b-%Y") + " โ€” " + description + """ + Assumes that `page.tab` is a valid tab defined in MenuTabs + 
""" - if (run_id and uid) or example_id or not description: - description += " AI API, workflow & prompt shared on Gooey.AI." + latest_slug = page.slug_versions[-1] # for recipe + recipe_url = furl(str(settings.APP_BASE_URL)) / latest_slug - return description + if pr and pr.saved_run == sr and pr.is_root(): + query_params = {} + pr_slug = "" + elif pr and pr.saved_run == sr: + query_params = {"example_id": pr.published_run_id} + pr_slug = (pr.title and slugify(pr.title)) or "" + else: + query_params = {"run_id": sr.run_id, "uid": sr.uid} + pr_slug = "" + + tab_path = MenuTabs.paths[page.tab] + match page.tab: + case MenuTabs.examples: + # no query params / run_slug in this case + return str(recipe_url / tab_path / "/") + case MenuTabs.history, MenuTabs.saved: + # no run slug in this case + return str(furl(recipe_url, query_params=query_params) / tab_path / "/") + case _: + # all other cases + return str( + furl(recipe_url, query_params=query_params) / pr_slug / tab_path / "/" + ) + + +def robots_tag_for_page( + *, + page: BasePage, + sr: SavedRun, + pr: PublishedRun | None, +) -> str: + is_root = pr and pr.saved_run == sr and pr.is_root() + is_example = pr and pr.saved_run == sr and not pr.is_root() + + match page.tab: + case MenuTabs.run if is_root or is_example: + no_follow, no_index = False, False + case MenuTabs.run: # ordinary run (not example) + no_follow, no_index = False, True + case MenuTabs.examples: + no_follow, no_index = False, False + case MenuTabs.run_as_api: + no_follow, no_index = False, True + case MenuTabs.integrations: + no_follow, no_index = True, True + case MenuTabs.history: + no_follow, no_index = True, True + case MenuTabs.saved: + no_follow, no_index = True, True + case _: + raise ValueError(f"Unknown tab: {page.tab}") + + parts = [] + if no_follow: + parts.append("nofollow") + if no_index: + parts.append("noindex") + return ",".join(parts) + + +def get_is_indexable_for_page( + *, + page: BasePage, + sr: SavedRun, + pr: PublishedRun | 
None, +) -> bool: + if pr and pr.saved_run == sr and pr.is_root(): + # index all tabs on root + return True + + return bool(pr and pr.saved_run == sr and page.tab == MenuTabs.run) diff --git a/daras_ai_v2/meta_preview_url.py b/daras_ai_v2/meta_preview_url.py index 3d02e0c62..4307e4f61 100644 --- a/daras_ai_v2/meta_preview_url.py +++ b/daras_ai_v2/meta_preview_url.py @@ -1,12 +1,17 @@ import mimetypes import os -from time import time +import typing -import requests from furl import furl -def meta_preview_url(file_url: str | None, fallback_img: str | None) -> str | None: +def meta_preview_url( + file_url: str | None, + fallback_img: str = None, + size: typing.Literal[ + "400x400", "1170x1560", "40x40", "72x72", "80x80", "96x96" + ] = "400x400", +) -> str | None: if not file_url: return fallback_img @@ -22,7 +27,6 @@ def meta_preview_url(file_url: str | None, fallback_img: str | None) -> str | No file_url = fallback_img elif content_type in ["image/png", "image/jpeg", "image/tiff", "image/webp"]: # sizes: 400x400,1170x1560,40x40,72x72,80x80,96x96 - size = "400x400" f.path.segments = dir_segments + ["thumbs", f"{base}_{size}{ext}"] new_url = str(f) diff --git a/daras_ai_v2/query_generator.py b/daras_ai_v2/query_generator.py index 2769d273c..64b935b09 100644 --- a/daras_ai_v2/query_generator.py +++ b/daras_ai_v2/query_generator.py @@ -1,4 +1,6 @@ -import jinja2 +import typing + +from pydantic import BaseModel from daras_ai_v2.language_model import ( run_language_model, @@ -7,13 +9,16 @@ ) from daras_ai_v2.prompt_vars import render_prompt_vars +Model = typing.TypeVar("Model", bound=BaseModel) + def generate_final_search_query( *, - request, - response=None, + request: Model, + response: Model = None, instructions: str, context: dict = None, + response_format_type: typing.Literal["text", "json_object"] = None, ): if context is None: context = request.dict() @@ -31,4 +36,5 @@ def generate_final_search_query( quality=request.quality, 
temperature=request.sampling_temperature, avoid_repetition=request.avoid_repetition, + response_format_type=response_format_type, )[0] diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 4e55211d9..541e41b97 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -23,8 +23,8 @@ def safety_checker_text(text_input: str): # run in a thread to avoid messing up threadlocals result, sr = ( CompareLLMPage() - .example_doc_sr(settings.SAFTY_CHECKER_EXAMPLE_ID) - .submit_api_call( + .get_published_run(published_run_id=settings.SAFTY_CHECKER_EXAMPLE_ID) + .saved_run.submit_api_call( current_user=billing_account, request_body=dict(variables=dict(input=text_input)), ) diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py index 1baf64b47..e95fb2ab2 100644 --- a/daras_ai_v2/search_ref.py +++ b/daras_ai_v2/search_ref.py @@ -57,150 +57,180 @@ def render_text_with_refs(text: str, references: list[SearchReference]): return html -def apply_response_template( +def apply_response_formattings_prefix( output_text: list[str], references: list[SearchReference], citation_style: CitationStyles | None = CitationStyles.number, -): +) -> list[dict[int, SearchReference]]: + all_refs_list = [{}] * len(output_text) for i, text in enumerate(output_text): - formatted = "" - all_refs = {} - - for snippet, ref_map in parse_refs(text, references): - match citation_style: - case CitationStyles.number | CitationStyles.number_plaintext: - cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys()) - case CitationStyles.title: - cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) - case CitationStyles.url: - cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) - case CitationStyles.symbol | CitationStyles.symbol_plaintext: - cites = " ".join( - f"[{generate_footnote_symbol(ref_num - 1)}]" - for ref_num in ref_map.keys() - ) + all_refs_list[i], output_text[i] = format_citations( + text, references, citation_style 
+ ) + return all_refs_list - case CitationStyles.markdown: - cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) - case CitationStyles.html: - cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) - case CitationStyles.slack_mrkdwn: - cites = " ".join( - ref_to_slack_mrkdwn(ref) for ref in ref_map.values() - ) - case CitationStyles.plaintext: - cites = " ".join( - f'[{ref["title"]} {ref["url"]}]' - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_markdown: - cites = " ".join( - markdown_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_html: - cites = " ".join( - html_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_slack_mrkdwn: - cites = " ".join( - slack_mrkdwn_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) +def apply_response_formattings_suffix( + all_refs_list: list[dict[int, SearchReference]], + output_text: list[str], + citation_style: CitationStyles | None = CitationStyles.number, +): + for i, text in enumerate(output_text): + output_text[i] = format_jinja_response_template( + all_refs_list[i], + format_footnotes(all_refs_list[i], text, citation_style), + ) - case CitationStyles.symbol_markdown: - cites = " ".join( - markdown_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.symbol_html: - cites = " ".join( - html_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.symbol_slack_mrkdwn: - cites = " ".join( - slack_mrkdwn_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case None: - cites = "" - case _: - raise ValueError(f"Unknown citation style: {citation_style}") - formatted += snippet + " " + cites + " " - all_refs.update(ref_map) +def format_citations( 
+ text: str, + references: list[SearchReference], + citation_style: CitationStyles | None = CitationStyles.number, +) -> tuple[dict[int, SearchReference], str]: + all_refs = {} + formatted = "" + for snippet, ref_map in parse_refs(text, references): match citation_style: + case CitationStyles.number | CitationStyles.number_plaintext: + cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys()) + case CitationStyles.title: + cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) + case CitationStyles.url: + cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) + case CitationStyles.symbol | CitationStyles.symbol_plaintext: + cites = " ".join( + f"[{generate_footnote_symbol(ref_num - 1)}]" + for ref_num in ref_map.keys() + ) + + case CitationStyles.markdown: + cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) + case CitationStyles.html: + cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) + case CitationStyles.slack_mrkdwn: + cites = " ".join(ref_to_slack_mrkdwn(ref) for ref in ref_map.values()) + case CitationStyles.plaintext: + cites = " ".join( + f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items() + ) + case CitationStyles.number_markdown: - formatted += "\n\n" - formatted += "\n".join( - f"[{ref_num}] {ref_to_markdown(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + markdown_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.number_html: - formatted += "

" - formatted += "
".join( - f"[{ref_num}] {ref_to_html(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + html_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.number_slack_mrkdwn: - formatted += "\n\n" - formatted += "\n".join( - f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}" - for ref_num, ref in sorted(all_refs.items()) - ) - case CitationStyles.number_plaintext: - formatted += "\n\n" - formatted += "\n".join( - f'{ref_num}. {ref["title"]} {ref["url"]}' - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + slack_mrkdwn_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_markdown: - formatted += "\n\n" - formatted += "\n".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + markdown_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_html: - formatted += "

" - formatted += "
".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + html_link(f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_slack_mrkdwn: - formatted += "\n\n" - formatted += "\n".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}" - for ref_num, ref in sorted(all_refs.items()) - ) - case CitationStyles.symbol_plaintext: - formatted += "\n\n" - formatted += "\n".join( - f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}' - for ref_num, ref in sorted(all_refs.items()) - ) - - for ref_num, ref in all_refs.items(): - try: - template = ref["response_template"] - except KeyError: - pass - else: - formatted = jinja2.Template(template).render( - **ref, - output_text=formatted, - ref_num=ref_num, + cites = " ".join( + slack_mrkdwn_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() ) - output_text[i] = formatted + case None: + cites = "" + case _: + raise ValueError(f"Unknown citation style: {citation_style}") + formatted += " ".join(filter(None, [snippet, cites])) + all_refs.update(ref_map) + return all_refs, formatted + + +def format_footnotes( + all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles +) -> str: + if not all_refs: + return formatted + match citation_style: + case CitationStyles.number_markdown: + formatted += "\n\n" + formatted += "\n".join( + f"[{ref_num}] {ref_to_markdown(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_html: + formatted += "

" + formatted += "
".join( + f"[{ref_num}] {ref_to_html(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_slack_mrkdwn: + formatted += "\n\n" + formatted += "\n".join( + f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_plaintext: + formatted += "\n\n" + formatted += "\n".join( + f'{ref_num}. {ref["title"]} {ref["url"]}' + for ref_num, ref in sorted(all_refs.items()) + ) + + case CitationStyles.symbol_markdown: + formatted += "\n\n" + formatted += "\n".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_html: + formatted += "

" + formatted += "
".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_slack_mrkdwn: + formatted += "\n\n" + formatted += "\n".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_plaintext: + formatted += "\n\n" + formatted += "\n".join( + f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}' + for ref_num, ref in sorted(all_refs.items()) + ) + return formatted + + +def format_jinja_response_template( + all_refs: dict[int, SearchReference], formatted: str +) -> str: + for ref_num, ref in all_refs.items(): + try: + template = ref["response_template"] + except KeyError: + pass + else: + formatted = jinja2.Template(template).render( + **ref, + output_text=formatted, + ref_num=ref_num, + ) + return formatted search_ref_pat = re.compile(r"\[" r"[\d\s\.\,\[\]\$\{\}]+" r"\]") diff --git a/daras_ai_v2/serp_search.py b/daras_ai_v2/serp_search.py index 3042ccc64..d0eebd98e 100644 --- a/daras_ai_v2/serp_search.py +++ b/daras_ai_v2/serp_search.py @@ -3,6 +3,7 @@ import requests from daras_ai_v2 import settings +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.serp_search_locations import SerpSearchType, SerpSearchLocation @@ -72,6 +73,6 @@ def call_serp_api( ), headers={"X-API-KEY": settings.SERPER_API_KEY}, ) - r.raise_for_status() + raise_for_status(r) data = r.json() return data diff --git a/daras_ai_v2/serp_search_locations.py b/daras_ai_v2/serp_search_locations.py index b56184ab3..2631ed8f5 100644 --- a/daras_ai_v2/serp_search_locations.py +++ b/daras_ai_v2/serp_search_locations.py @@ -3,6 +3,7 @@ from pydantic import Field import gooey_ui as st +from daras_ai_v2.field_render import field_title_desc def serp_search_settings(): @@ -26,7 +27,7 @@ def serp_search_settings(): def serp_search_type_selectbox(key="serp_search_type"): st.selectbox( - f"###### 
{GoogleSearchMixin.__fields__[key].field_info.title}\n{GoogleSearchMixin.__fields__[key].field_info.description or ''}", + f"###### {field_title_desc(GoogleSearchMixin, key)}", options=SerpSearchType, format_func=lambda x: x.label, key=key, @@ -35,7 +36,7 @@ def serp_search_type_selectbox(key="serp_search_type"): def serp_search_location_selectbox(key="serp_search_location"): st.selectbox( - f"###### {GoogleSearchMixin.__fields__[key].field_info.title}\n{GoogleSearchMixin.__fields__[key].field_info.description or ''}", + f"###### {field_title_desc(GoogleSearchMixin, key)}", options=SerpSearchLocation, format_func=lambda x: f"{x.label} ({x.value})", key=key, diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index a13565115..74f38ed5f 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -244,6 +244,17 @@ SUPPORT_EMAIL = "Gooey.AI Support " SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 20) +DISALLOWED_TITLE_SLUGS = config("DISALLOWED_TITLE_SLUGS", cast=Csv(), default="") + [ + # tab names + "api", + "examples", + "history", + "saved", + "integrations", + # other + "docs", +] + SAFTY_CHECKER_EXAMPLE_ID = "3rcxqx0r" SAFTY_CHECKER_BILLING_EMAIL = "support+mods@gooey.ai" @@ -302,3 +313,8 @@ DEEPGRAM_API_KEY = config("DEEPGRAM_API_KEY", "") ELEVEN_LABS_API_KEY = config("ELEVEN_LABS_API_KEY", "") + +# Paypal +PAYPAL_CLIENT_ID = config("PAYPAL_CLIENT_ID", "") +PAYPAL_SECRET = config("PAYPAL_SECRET", "") +PAYPAL_BASE = config("PAYPAL_BASE", "") diff --git a/daras_ai_v2/slack_bot.py b/daras_ai_v2/slack_bot.py index bf49ca874..10c88f7d7 100644 --- a/daras_ai_v2/slack_bot.py +++ b/daras_ai_v2/slack_bot.py @@ -1,7 +1,6 @@ import re import typing from string import Template -from typing import TypedDict import requests from django.db import transaction @@ -12,9 +11,11 @@ from bots.models import BotIntegration, Platform, Conversation from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2.asr import 
run_google_translate, audio_bytes_to_wav -from daras_ai_v2.bots import BotInterface +from daras_ai_v2.bots import BotInterface, SLACK_MAX_SIZE +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.functional import fetch_parallel from daras_ai_v2.text_splitter import text_splitter +from recipes.VideoBots import ReplyButton SLACK_CONFIRMATION_MSG = """ Hi there! ๐Ÿ‘‹ @@ -26,11 +27,10 @@ I have been configured for $user_language and will respond to you in that language. """.strip() -SLACK_MAX_SIZE = 3000 - class SlackBot(BotInterface): platform = Platform.SLACK + can_update_message = True _read_rcpt_ts: str | None = None @@ -109,7 +109,7 @@ def get_input_audio(self) -> str | None: ), f"Unsupported mime type {mime_type} for {url}" # download file from slack r = requests.get(url, headers={"Authorization": f"Bearer {self._access_token}"}) - r.raise_for_status() + raise_for_status(r) # convert to wav data, _ = audio_bytes_to_wav(r.content) mime_type = "audio/wav" @@ -131,15 +131,16 @@ def send_msg( text: str | None = None, audio: str | None = None, video: str | None = None, - buttons: list | None = None, + buttons: list[ReplyButton] = None, + documents: list[str] = None, should_translate: bool = False, + update_msg_id: str | None = None, ) -> str | None: - if not text: - return None - if should_translate and self.language and self.language != "en": + if text and should_translate and self.language and self.language != "en": text = run_google_translate( [text], self.language, glossary_url=self.output_glossary )[0] + text = text or "\u200b" # handle empty text with zero-width space if self._read_rcpt_ts and self._read_rcpt_ts != self._msg_ts: delete_msg( @@ -149,28 +150,66 @@ def send_msg( ) self._read_rcpt_ts = None + if not self.can_update_message: + update_msg_id = None + self._msg_ts, num_splits = self.send_msg_to( + text=text, + audio=audio, + video=video, + buttons=buttons, + channel=self.bot_id, + 
channel_is_personal=self.convo.slack_channel_is_personal, + username=self.convo.bot_integration.name, + token=self._access_token, + thread_ts=self._msg_ts, + update_msg_ts=update_msg_id, + ) + if num_splits > 1: + self.can_update_message = False + return self._msg_ts + + @classmethod + def send_msg_to( + cls, + *, + text: str | None = None, + audio: str = None, + video: str = None, + buttons: list[ReplyButton] = None, + documents: list[str] = None, + ## whatsapp specific + channel: str, + channel_is_personal: bool, + username: str, + token: str, + thread_ts: str = None, + update_msg_ts: str = None, + ) -> tuple[str | None, int]: splits = text_splitter(text, chunk_size=SLACK_MAX_SIZE, length_function=len) for doc in splits[:-1]: - self._msg_ts = chat_post_message( + thread_ts = chat_post_message( text=doc.text, - channel=self.bot_id, - channel_is_personal=self.convo.slack_channel_is_personal, - thread_ts=self._msg_ts, - username=self.convo.bot_integration.name, - token=self._access_token, + channel=channel, + channel_is_personal=channel_is_personal, + thread_ts=thread_ts, + update_msg_ts=update_msg_ts, + username=username, + token=token, ) - self._msg_ts = chat_post_message( + update_msg_ts = None + thread_ts = chat_post_message( text=splits[-1].text, audio=audio, video=video, - channel=self.bot_id, - channel_is_personal=self.convo.slack_channel_is_personal, - thread_ts=self._msg_ts, - username=self.convo.bot_integration.name, - token=self._access_token, buttons=buttons or [], + channel=channel, + channel_is_personal=channel_is_personal, + thread_ts=thread_ts, + update_msg_ts=update_msg_ts, + username=username, + token=token, ) - return self._msg_ts + return thread_ts, len(splits) def mark_read(self): text = self.convo.bot_integration.slack_read_receipt_msg.strip() @@ -494,39 +533,64 @@ def chat_post_message( channel: str, thread_ts: str, token: str, + update_msg_ts: str = None, channel_is_personal: bool = False, - audio: str | None = None, - video: str | None = 
None, + audio: str = None, + video: str = None, username: str = "Video Bot", - buttons: list | None = None, + buttons: list[ReplyButton] = None, ) -> str | None: if buttons is None: buttons = [] if channel_is_personal: # don't thread in personal channels thread_ts = None - res = requests.post( - "https://slack.com/api/chat.postMessage", - json={ - "channel": channel, - "thread_ts": thread_ts, - "text": text, - "username": username, - "icon_emoji": ":robot_face:", - "blocks": [ - { - "type": "section", - "text": {"type": "mrkdwn", "text": text}, - }, - ] - + create_file_block("Audio", token, audio) - + create_file_block("Video", token, video) - + create_button_block(buttons), - }, - headers={ - "Authorization": f"Bearer {token}", - }, - ) + if update_msg_ts: + res = requests.post( + "https://slack.com/api/chat.update", + json={ + "channel": channel, + "ts": update_msg_ts, + "text": text, + "username": username, + "icon_emoji": ":robot_face:", + "blocks": [ + { + "type": "section", + "text": {"type": "mrkdwn", "text": text}, + }, + ] + + create_file_block("Audio", token, audio) + + create_file_block("Video", token, video) + + create_button_block(buttons), + }, + headers={ + "Authorization": f"Bearer {token}", + }, + ) + else: + res = requests.post( + "https://slack.com/api/chat.postMessage", + json={ + "channel": channel, + "thread_ts": thread_ts, + "text": text, + "username": username, + "icon_emoji": ":robot_face:", + "blocks": [ + { + "type": "section", + "text": {"type": "mrkdwn", "text": text}, + }, + ] + + create_file_block("Audio", token, audio) + + create_file_block("Video", token, video) + + create_button_block(buttons), + }, + headers={ + "Authorization": f"Bearer {token}", + }, + ) data = parse_slack_response(res) return data.get("ts") @@ -559,7 +623,7 @@ def create_file_block( ] -def create_button_block(buttons: list[dict]) -> list[dict]: +def create_button_block(buttons: list[ReplyButton]) -> list[dict]: if not buttons: return [] return [ @@ -568,9 
+632,9 @@ def create_button_block(buttons: list[dict]) -> list[dict]: "elements": [ { "type": "button", - "text": {"type": "plain_text", "text": button["reply"]["title"]}, - "value": button["reply"]["id"], - "action_id": "button_" + button["reply"]["id"], + "text": {"type": "plain_text", "text": button["title"]}, + "value": button["id"], + "action_id": "button_" + button["id"], } for button in buttons ], @@ -593,7 +657,7 @@ def send_confirmation_msg(bot: BotIntegration): str(bot.slack_channel_hook_url), json={"text": text}, ) - res.raise_for_status() + raise_for_status(res) def invite_bot_account_to_channel(channel: str, bot_user_id: str, token: str): @@ -612,7 +676,7 @@ def invite_bot_account_to_channel(channel: str, bot_user_id: str, token: str): def parse_slack_response(res: Response): - res.raise_for_status() + raise_for_status(res) data = res.json() print(f'> {res.request.url.split("/")[-1]}: {data}') if data.get("ok"): diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index f325de57e..22c7adc8e 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -13,6 +13,7 @@ resize_img_fit, get_downscale_factor, ) +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.extract_face import rgb_img_to_rgba from daras_ai_v2.gpu_server import ( b64_img_decode, @@ -42,20 +43,22 @@ def _deprecated(cls): class Text2ImgModels(Enum): # sd_1_4 = "SD v1.4 (RunwayML)" # Host this too? 
- sd_2 = "Stable Diffusion v2.1 (stability.ai)" - sd_1_5 = "Stable Diffusion v1.5 (RunwayML)" dream_shaper = "DreamShaper (Lykon)" - openjourney = "Open Journey (PromptHero)" - openjourney_2 = "Open Journey v2 beta (PromptHero)" - analog_diffusion = "Analog Diffusion (wavymulder)" - protogen_5_3 = "Protogen v5.3 (darkstorm2150)" dreamlike_2 = "Dreamlike Photoreal 2.0 (dreamlike.art)" + sd_2 = "Stable Diffusion v2.1 (stability.ai)" + sd_1_5 = "Stable Diffusion v1.5 (RunwayML)" + dall_e = "DALLยทE 2 (OpenAI)" dall_e_3 = "DALLยทE 3 (OpenAI)" + openjourney_2 = "Open Journey v2 beta (PromptHero) ๐Ÿข" + openjourney = "Open Journey (PromptHero) ๐Ÿข" + analog_diffusion = "Analog Diffusion (wavymulder) ๐Ÿข" + protogen_5_3 = "Protogen v5.3 (darkstorm2150) ๐Ÿข" + jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" - deepfloyd_if = "DeepFloyd IF [Deprecated] (stability.ai)" rodent_diffusion_1_5 = "Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)" + deepfloyd_if = "DeepFloyd IF [Deprecated] (stability.ai)" @classmethod def _deprecated(cls): @@ -81,12 +84,14 @@ class Img2ImgModels(Enum): dreamlike_2 = "Dreamlike Photoreal 2.0 (dreamlike.art)" sd_2 = "Stable Diffusion v2.1 (stability.ai)" sd_1_5 = "Stable Diffusion v1.5 (RunwayML)" + dall_e = "Dall-E (OpenAI)" + instruct_pix2pix = "โœจ InstructPix2Pix (Tim Brooks)" - openjourney_2 = "Open Journey v2 beta (PromptHero)" - openjourney = "Open Journey (PromptHero)" - analog_diffusion = "Analog Diffusion (wavymulder)" - protogen_5_3 = "Protogen v5.3 (darkstorm2150)" + openjourney_2 = "Open Journey v2 beta (PromptHero) ๐Ÿข" + openjourney = "Open Journey (PromptHero) ๐Ÿข" + analog_diffusion = "Analog Diffusion (wavymulder) ๐Ÿข" + protogen_5_3 = "Protogen v5.3 (darkstorm2150) ๐Ÿข" jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" rodent_diffusion_1_5 = "Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)" @@ -266,25 +271,32 @@ def text2img( guidance_scale: float = None, negative_prompt: str = None, scheduler: str = 
None, + dall_e_3_quality: str | None = None, + dall_e_3_style: str | None = None, ): - _resolution_check(width, height, max_size=(1024, 1024)) + if selected_model != Text2ImgModels.dall_e_3.name: + _resolution_check(width, height, max_size=(1024, 1024)) match selected_model: case Text2ImgModels.dall_e_3.name: from openai import OpenAI client = OpenAI() + width, height = _get_dall_e_3_img_size(width, height) response = client.images.generate( model=text2img_model_ids[Text2ImgModels[selected_model]], - n=num_outputs, + n=1, # num_outputs, not supported yet prompt=prompt, response_format="b64_json", + quality=dall_e_3_quality, + style=dall_e_3_style, + size=f"{width}x{height}", ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] case Text2ImgModels.dall_e.name: from openai import OpenAI - edge = _get_dalle_img_size(width, height) + edge = _get_dall_e_img_size(width, height) client = OpenAI() response = client.images.generate( n=num_outputs, @@ -320,7 +332,7 @@ def text2img( ] -def _get_dalle_img_size(width: int, height: int) -> int: +def _get_dall_e_img_size(width: int, height: int) -> int: edge = max(width, height) if edge < 512: edge = 256 @@ -331,6 +343,15 @@ def _get_dalle_img_size(width: int, height: int) -> int: return edge +def _get_dall_e_3_img_size(width: int, height: int) -> tuple[int, int]: + if height == width: + return 1024, 1024 + elif width < height: + return 1024, 1792 + else: + return 1792, 1024 + + def img2img( *, selected_model: str, @@ -354,7 +375,7 @@ def img2img( case Img2ImgModels.dall_e.name: from openai import OpenAI - edge = _get_dalle_img_size(width, height) + edge = _get_dall_e_img_size(width, height) image = resize_img_pad(init_image_bytes, (edge, edge)) client = OpenAI() @@ -402,7 +423,7 @@ def controlnet( scheduler: str = None, prompt: str, num_outputs: int = 1, - init_image: str, + init_images: list[str] | str, num_inference_steps: int = 50, negative_prompt: str = None, guidance_scale: float = 7.5, @@ -411,6 +432,8 @@ 
def controlnet( ): if isinstance(selected_controlnet_model, str): selected_controlnet_model = [selected_controlnet_model] + if isinstance(init_images, str): + init_images = [init_images] * len(selected_controlnet_model) prompt = add_prompt_prefix(prompt, selected_model) return call_sd_multi( "diffusion.controlnet", @@ -432,7 +455,7 @@ def controlnet( "num_images_per_prompt": num_outputs, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, - "image": [init_image] * len(selected_controlnet_model), + "image": init_images, "controlnet_conditioning_scale": controlnet_conditioning_scale, # "strength": prompt_strength, }, @@ -474,7 +497,7 @@ def inpainting( case InpaintingModels.dall_e.name: from openai import OpenAI - edge = _get_dalle_img_size(width, height) + edge = _get_dall_e_img_size(width, height) edit_image_bytes = resize_img_pad(edit_image_bytes, (edge, edge)) mask_bytes = resize_img_pad(mask_bytes, (edge, edge)) image = rgb_img_to_rgba(edit_image_bytes, mask_bytes) @@ -513,7 +536,7 @@ def inpainting( out_imgs = [] for url in out_imgs_urls: r = requests.get(url) - r.raise_for_status() + raise_for_status(r) out_imgs.append(r.content) case _: diff --git a/daras_ai_v2/tabs_widget.py b/daras_ai_v2/tabs_widget.py index 235dd4039..f6511e58e 100644 --- a/daras_ai_v2/tabs_widget.py +++ b/daras_ai_v2/tabs_widget.py @@ -10,6 +10,7 @@ class MenuTabs: run_as_api = "๐Ÿš€ API" history = "๐Ÿ“– History" integrations = "๐Ÿ”Œ Integrations" + saved = "๐Ÿ“ Saved" paths = { run: "", @@ -17,6 +18,7 @@ class MenuTabs: run_as_api: "api", history: "history", integrations: "integrations", + saved: "saved", } paths_reverse = {v: k for k, v in paths.items()} diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index 0526c441f..499b502d2 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -5,6 +5,7 @@ import gooey_ui as st from 
daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.redis_cache import redis_cache_decorator SESSION_ELEVENLABS_API_KEY = "__user__elevenlabs_api_key" @@ -79,6 +80,7 @@ class TextToSpeechProviders(Enum): "eleven_multilingual_v2": "Multilingual V2 - High quality speech in 29 languages", "eleven_turbo_v2": "English V2 - Very low latency text-to-speech", "eleven_monolingual_v1": "English V1 - Low latency text-to-speech", + "eleven_multilingual_v1": "Multilingual V1", } ELEVEN_LABS_SUPPORTED_LANGS = [ @@ -289,8 +291,7 @@ def text_to_speech_settings(page): ): st.caption( """ - Note: Please purchase Gooey.AI credits to use ElevenLabs voices - here.
+ Note: Please purchase Gooey.AI credits to use ElevenLabs voices [here](/account). Alternatively, you can use your own ElevenLabs API key by selecting the checkbox above. """ ) @@ -347,6 +348,26 @@ def text_to_speech_settings(page): key="elevenlabs_similarity_boost", ) + if st.session_state.get("elevenlabs_model") == "eleven_multilingual_v2": + col1, col2 = st.columns(2) + with col1: + st.slider( + """ + ###### Style Exaggeration + """, + min_value=0, + max_value=1.0, + step=0.05, + key="elevenlabs_style", + value=0.0, + ) + with col2: + st.checkbox( + "Speaker Boost", + key="elevenlabs_speaker_boost", + value=True, + ) + with st.expander( "Eleven Labs Supported Languages", style={"fontSize": "0.9rem", "textDecoration": "underline"}, @@ -403,7 +424,7 @@ def fetch_elevenlabs_voices(api_key: str) -> dict[str, str]: "https://api.elevenlabs.io/v1/voices", headers={"Accept": "application/json", "xi-api-key": api_key}, ) - r.raise_for_status() + raise_for_status(r) print(r.json()["voices"]) sorted_voices = sorted( r.json()["voices"], diff --git a/daras_ai_v2/user_date_widgets.py b/daras_ai_v2/user_date_widgets.py index 91d6c001e..a4c510b56 100644 --- a/daras_ai_v2/user_date_widgets.py +++ b/daras_ai_v2/user_date_widgets.py @@ -1,35 +1,58 @@ import datetime +import json +from typing import Any, Callable import gooey_ui as gui -def js_dynamic_date(dt: datetime.datetime): +def js_dynamic_date( + dt: datetime.datetime, + *, + container: Callable = gui.caption, + date_options: dict[str, Any] | None = None, + time_options: dict[str, Any] | None = None, +): timestamp_ms = dt.timestamp() * 1000 - gui.caption("Loading...", **{"data-id-dynamic-date": str(timestamp_ms)}) + attrs = {"data-id-dynamic-date": str(timestamp_ms)} + if date_options: + attrs["data-id-date-options"] = json.dumps(date_options) + if time_options: + attrs["data-id-time-options"] = json.dumps(time_options) + container("Loading...", **attrs) def render_js_dynamic_dates(): + default_date_options = { + "weekday": 
"short", + "day": "numeric", + "month": "short", + } + default_time_options = { + "hour": "numeric", + "hour12": True, + "minute": "numeric", + } gui.html( # language=HTML """ + """ + % { + "date_options_json": json.dumps(default_date_options), + "time_options_json": json.dumps(default_time_options), + }, + ) + + +def re_render_js_dynamic_dates(): + gui.html( + # language=HTML + """ + """, ) diff --git a/daras_ai_v2/vcard.py b/daras_ai_v2/vcard.py index 602391822..88af0856b 100644 --- a/daras_ai_v2/vcard.py +++ b/daras_ai_v2/vcard.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from daras_ai.image_input import resize_img_scale +from daras_ai_v2.exceptions import raise_for_status # see - https://datatracker.ietf.org/doc/html/rfc6350#section-3.2 CRLF = "\r\n" @@ -115,7 +116,7 @@ def to_vcf_str(self, compress_and_base64: bool = True) -> str: def vard_img(prop: str, img: str, compress_and_base64: bool, fmt: str = "PNG") -> str: if compress_and_base64: r = requests.get(img) - r.raise_for_status() + raise_for_status(r) downscaled = resize_img_scale(r.content, (400, 400)) img = base64.b64encode(downscaled).decode() return prop + ";" + vard_line(f"ENCODING=BASE64;TYPE={fmt}", img) diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index d5cf5549d..b2cc187d0 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -12,6 +12,7 @@ import requests from furl import furl from googleapiclient.errors import HttpError +from loguru import logger from pydantic import BaseModel, Field from rank_bm25 import BM25Okapi @@ -30,8 +31,9 @@ from daras_ai_v2.doc_search_settings_widgets import ( is_user_uploaded_url, ) +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS -from daras_ai_v2.functional import flatmap_parallel +from daras_ai_v2.functional import flatmap_parallel, map_parallel from daras_ai_v2.gdrive_downloader import ( gdrive_download, is_gdrive_url, @@ -52,7 +54,7 @@ class 
DocSearchRequest(BaseModel): search_query: str - keyword_query: str | None + keyword_query: str | list[str] | None documents: list[str] | None @@ -86,23 +88,26 @@ def get_top_k_references( Returns: the top k documents """ - yield "Getting embeddings..." + yield "Fetching latest knowledge docs..." input_docs = request.documents or [] + doc_metas = map_parallel(doc_url_to_metadata, input_docs) + + yield "Creating knowledge embeddings..." embeds: list[tuple[SearchReference, np.ndarray]] = flatmap_parallel( - lambda f_url: doc_url_to_embeds( + lambda f_url, doc_meta: get_embeds_for_doc( f_url=f_url, + doc_meta=doc_meta, max_context_words=request.max_context_words, scroll_jump=request.scroll_jump, selected_asr_model=request.selected_asr_model, google_translate_target=request.google_translate_target, ), input_docs, - max_workers=10, + doc_metas, ) - dense_query_embeds = openai_embedding_create([request.search_query])[0] - yield "Searching documents..." + yield "Searching knowledge base..." dense_weight = request.dense_weight if dense_weight is None: # for backwards compatibility @@ -128,15 +133,27 @@ def get_top_k_references( dense_ranks = np.zeros(len(embeds)) if sparse_weight: + yield "Considering results..." 
# get sparse scores - tokenized_corpus = [ - bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"]) - for ref, _ in embeds - ] - bm25 = BM25Okapi(tokenized_corpus, k1=2, b=0.3) - sparse_query_tokenized = bm25_tokenizer( - request.keyword_query or request.search_query + bm25_corpus = flatmap_parallel( + lambda f_url, doc_meta: get_bm25_embeds_for_doc( + f_url=f_url, + doc_meta=doc_meta, + max_context_words=request.max_context_words, + scroll_jump=request.scroll_jump, + selected_asr_model=request.selected_asr_model, + google_translate_target=request.google_translate_target, + ), + input_docs, + doc_metas, ) + bm25 = BM25Okapi(bm25_corpus, k1=2, b=0.3) + if request.keyword_query and isinstance(request.keyword_query, list): + sparse_query_tokenized = [item.lower() for item in request.keyword_query] + else: + sparse_query_tokenized = bm25_tokenizer( + request.keyword_query or request.search_query + ) if sparse_query_tokenized: sparse_scores = np.array(bm25.get_scores(sparse_query_tokenized)) # sparse_scores *= custom_weights @@ -146,6 +163,13 @@ def get_top_k_references( else: sparse_ranks = np.zeros(len(embeds)) + # just in case sparse and dense ranks are different lengths, truncate to the shorter one + if len(sparse_ranks) != len(dense_ranks): + logger.warning( + f"sparse and dense ranks are different lengths, truncating... 
{len(sparse_ranks)=} {len(dense_ranks)=} {len(embeds)=}" + ) + sparse_ranks = sparse_ranks[: len(dense_ranks)] + dense_ranks = dense_ranks[: len(sparse_ranks)] # RRF formula: 1 / (k + rank) k = 60 rrf_scores = ( @@ -155,11 +179,7 @@ def get_top_k_references( # Final ranking max_references = min(request.max_references, len(rrf_scores)) top_k = np.argpartition(rrf_scores, -max_references)[-max_references:] - final_ranks = sorted( - top_k, - key=lambda idx: rrf_scores[idx], - reverse=True, - ) + final_ranks = sorted(top_k, key=rrf_scores.__getitem__, reverse=True) references = [embeds[idx][0] | {"score": rrf_scores[idx]} for idx in final_ranks] @@ -207,38 +227,6 @@ def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str: ) -def doc_url_to_embeds( - *, - f_url: str, - max_context_words: int, - scroll_jump: int, - selected_asr_model: str = None, - google_translate_target: str = None, -) -> list[tuple[SearchReference, np.ndarray]]: - """ - Get document embeddings for a given document url. 
- - Args: - f_url: document url - max_context_words: max number of words to include in each chunk - scroll_jump: number of words to scroll by - google_translate_target: target language for google translate - selected_asr_model: selected ASR model (used for audio files) - - Returns: - list of (SearchReference, embeddings vector) tuples - """ - doc_meta = doc_url_to_metadata(f_url) - return get_embeds_for_doc( - f_url=f_url, - doc_meta=doc_meta, - max_context_words=max_context_words, - scroll_jump=scroll_jump, - selected_asr_model=selected_asr_model, - google_translate_target=google_translate_target, - ) - - class DocMetadata(typing.NamedTuple): name: str etag: str | None @@ -287,7 +275,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, ) - r.raise_for_status() + raise_for_status(r) except requests.RequestException as e: print(f"ignore error while downloading {f_url}: {e}") name = None @@ -317,6 +305,35 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: ) +@redis_cache_decorator +def get_bm25_embeds_for_doc( + *, + f_url: str, + doc_meta: DocMetadata, + max_context_words: int, + scroll_jump: int, + google_translate_target: str = None, + selected_asr_model: str = None, +): + pages = doc_url_to_text_pages( + f_url=f_url, + doc_meta=doc_meta, + selected_asr_model=selected_asr_model, + google_translate_target=google_translate_target, + ) + refs = pages_to_split_refs( + pages=pages, + f_url=f_url, + doc_meta=doc_meta, + max_context_words=max_context_words, + scroll_jump=scroll_jump, + ) + tokenized_corpus = [ + bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"]) for ref in refs + ] + return tokenized_corpus + + @redis_cache_decorator def get_embeds_for_doc( *, @@ -341,18 +358,44 @@ def get_embeds_for_doc( Returns: list of (metadata, embeddings) tuples """ - import pandas as pd - pages = doc_url_to_text_pages( f_url=f_url, 
doc_meta=doc_meta, selected_asr_model=selected_asr_model, google_translate_target=google_translate_target, ) + refs = pages_to_split_refs( + pages=pages, + f_url=f_url, + doc_meta=doc_meta, + max_context_words=max_context_words, + scroll_jump=scroll_jump, + ) + texts = [m["title"] + " | " + m["snippet"] for m in refs] + # get doc embeds in batches + batch_size = 16 # azure openai limits + embeds = flatmap_parallel( + openai_embedding_create, + [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)], + max_workers=2, + ) + return list(zip(refs, embeds)) + + +def pages_to_split_refs( + *, + pages, + f_url: str, + doc_meta: DocMetadata, + max_context_words: int, + scroll_jump: int, +) -> list[SearchReference]: + import pandas as pd + chunk_size = int(max_context_words * 2) chunk_overlap = int(max_context_words * 2 / scroll_jump) if isinstance(pages, pd.DataFrame): - metas = [] + refs = [] # treat each row as a separate document for idx, row in pages.iterrows(): row = dict(row) @@ -368,7 +411,7 @@ def get_embeds_for_doc( ) else: continue - metas += [ + refs += [ { "title": doc_meta.name, "url": f_url, @@ -381,7 +424,7 @@ def get_embeds_for_doc( ] else: # split the text into chunks - metas = [ + refs = [ { "title": ( doc_meta.name + (f", page {doc.end + 1}" if len(pages) > 1 else "") @@ -399,15 +442,7 @@ def get_embeds_for_doc( pages, chunk_size=chunk_size, chunk_overlap=chunk_overlap ) ] - # get doc embeds in batches - batch_size = 16 # azure openai limits - texts = [m["title"] + " | " + m["snippet"] for m in metas] - embeds = flatmap_parallel( - openai_embedding_create, - [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)], - max_workers=5, - ) - return list(zip(metas, embeds)) + return refs sections_re = re.compile(r"(\s*[\r\n\f\v]|^)(\w+)\=", re.MULTILINE) @@ -489,19 +524,23 @@ def download_content_bytes(*, f_url: str, mime_type: str) -> tuple[bytes, str]: headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, 
timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, ) - r.raise_for_status() + raise_for_status(r) except requests.RequestException as e: print(f"ignore error while downloading {f_url}: {e}") return b"", "" f_bytes = r.content # if it's a known encoding, standardize to utf-8 - if r.encoding: + encoding = r.apparent_encoding or r.encoding + if encoding: try: - codec = codecs.lookup(r.encoding) + codec = codecs.lookup(encoding) except LookupError: pass else: - f_bytes = codec.decode(f_bytes)[0].encode() + try: + f_bytes = codec.decode(f_bytes)[0].encode() + except UnicodeDecodeError: + pass ext = guess_ext_from_response(r) return f_bytes, ext diff --git a/explore.py b/explore.py index 05cddac15..f47f6fc23 100644 --- a/explore.py +++ b/explore.py @@ -1,8 +1,11 @@ +import typing + import gooey_ui as gui from daras_ai.image_input import truncate_text_words from daras_ai_v2.all_pages import all_home_pages_by_category +from daras_ai_v2.base import BasePage from daras_ai_v2.grid_layout_widget import grid_layout - +from daras_ai_v2.meta_preview_url import meta_preview_url META_TITLE = "Explore AI workflows" META_DESCRIPTION = "Find, fork and run your fieldโ€™s favorite AI recipes on Gooey.AI" @@ -12,26 +15,66 @@ def render(): - def _render(page_cls): + def _render_non_featured(page_cls): page = page_cls() - state = page.recipe_doc_sr().to_dict() + state = page.recipe_doc_sr(create=True).to_dict() + # total_runs = page.get_total_runs() - with gui.link(to=page.app_url()): - gui.markdown(f"### {page.get_recipe_title(state)}") + col1, col2 = gui.columns([1, 2]) + with col1: + render_image(page, state) + + with col2: + # render_description(page, state, total_runs) + render_description(page, state) + + def _render_as_featured(page_cls: typing.Type[BasePage]): + page = page_cls() + state = page.recipe_doc_sr(create=True).to_dict() + # total_runs = page.get_total_runs() + render_image(page, state) + # render_description(page, state, total_runs) + render_description(page, state) + + def 
render_image(page: BasePage, state: dict): + gui.image( + meta_preview_url(page.get_explore_image(state), page.preview_image(state)), + href=page.app_url(), + style={"border-radius": 5}, + ) + def render_description(page, state): + with gui.link(to=page.app_url()): + gui.markdown(f"#### {page.get_recipe_title()}") preview = page.preview_description(state) if preview: - gui.write(truncate_text_words(preview, 150)) + with gui.tag("p", style={"margin-bottom": "25px"}): + gui.html( + truncate_text_words(preview, 150), + ) else: page.render_description() - - page.render_example(state) + # with gui.tag( + # "p", + # style={ + # "font-size": "14px", + # "float": "right", + # }, + # className="text-muted", + # ): + # gui.html( + # f' {total_runs} runs' + # ) heading(title=TITLE, description=DESCRIPTION) for category, pages in all_home_pages_by_category.items(): gui.write("---") - section_heading(category) - grid_layout(3, pages, _render, separator=False) + if category != "Featured": + section_heading(category) + if category == "Images" or category == "Featured": + grid_layout(3, pages, _render_as_featured, separator=False) + else: + grid_layout(2, pages, _render_non_featured, separator=False) def heading( diff --git a/gooey_ui/components.py b/gooey_ui/components/__init__.py similarity index 76% rename from gooey_ui/components.py rename to gooey_ui/components/__init__.py index dd892869e..408b950fb 100644 --- a/gooey_ui/components.py +++ b/gooey_ui/components/__init__.py @@ -1,7 +1,9 @@ import base64 +import html as html_lib import math import textwrap import typing +from datetime import datetime, timezone import numpy as np @@ -28,10 +30,20 @@ def dummy(*args, **kwargs): spinner = dummy set_page_config = dummy form = dummy -plotly_chart = dummy dataframe = dummy +def countdown_timer( + end_time: datetime, + delay_text: str, +) -> state.NestingCtx: + return _node( + "countdown-timer", + endTime=end_time.astimezone(timezone.utc).isoformat(), + delayText=delay_text, + ) + + 
def nav_tabs(): return _node("nav-tabs") @@ -71,9 +83,11 @@ def write(*objs: typing.Any, unsafe_allow_html=False, **props): ) -def markdown(body: str, *, unsafe_allow_html=False, **props): +def markdown(body: str | None, *, unsafe_allow_html=False, **props): if body is None: return _node("markdown", body="", **props) + if not unsafe_allow_html: + body = html_lib.escape(body) props["className"] = ( props.get("className", "") + " gui-html-container gui-md-container" ) @@ -202,6 +216,7 @@ def image( src: str | np.ndarray, caption: str = None, alt: str = None, + href: str = None, **props, ): if isinstance(src, np.ndarray): @@ -222,6 +237,7 @@ def image( src=src, caption=dedent(caption), alt=alt or caption, + href=href, **props, ), ).mount() @@ -273,13 +289,14 @@ def text_area( **props, ) -> str: style = props.setdefault("style", {}) - if key: - assert not value, "only one of value or key can be provided" - else: + # if key: + # assert not value, "only one of value or key can be provided" + # else: + if not key: key = md5_values( "textarea", label, height, help, value, placeholder, label_visibility ) - value = str(state.session_state.setdefault(key, value)) + value = str(state.session_state.setdefault(key, value) or "") if label_visibility != "visible": label = None if disabled: @@ -405,12 +422,20 @@ def button( key: str = None, help: str = None, *, - type: typing.Literal["primary", "secondary"] = "secondary", + type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", disabled: bool = False, **props, ) -> bool: + """ + Example: + st.button("Primary", key="test0", type="primary") + st.button("Secondary", key="test1") + st.button("Tertiary", key="test3", type="tertiary") + st.button("Link Button", key="test3", type="link") + """ if not key: key = md5_values("button", label, help, type, props) + className = f"btn-{type} " + props.pop("className", "") state.RenderTreeNode( name="gui-button", props=dict( @@ -420,6 +445,7 @@ def button( 
label=dedent(label), help=help, disabled=disabled, + className=className, **props, ), ).mount() @@ -453,6 +479,7 @@ def file_uploader( disabled: bool = False, label_visibility: LabelVisibility = "visible", upload_meta: dict = None, + optional: bool = False, ): if label_visibility != "visible": label = None @@ -466,6 +493,13 @@ def file_uploader( help, label_visibility, ) + if optional: + if not checkbox( + label, value=bool(state.session_state.get(key)), disabled=disabled + ): + state.session_state.pop(key, None) + return None + label = None value = state.session_state.get(key) if not value: if accept_multiple_files: @@ -501,57 +535,64 @@ def json(value: typing.Any, expanded: bool = False, depth: int = 1): ).mount() -def data_table(file_url: str): - return _node("data-table", fileUrl=file_url) +def data_table(file_url_or_cells: str | list): + if isinstance(file_url_or_cells, str): + file_url = file_url_or_cells + return _node("data-table", fileUrl=file_url) + else: + cells = file_url_or_cells + return _node("data-table-raw", cells=cells) def table(df: "pd.DataFrame"): - state.RenderTreeNode( - name="table", - children=[ - state.RenderTreeNode( - name="thead", - children=[ - state.RenderTreeNode( - name="tr", - children=[ - state.RenderTreeNode( - name="th", - children=[ - state.RenderTreeNode( - name="markdown", - props=dict(body=dedent(col)), - ), - ], - ) - for col in df.columns - ], - ), - ], - ), - state.RenderTreeNode( - name="tbody", - children=[ - state.RenderTreeNode( - name="tr", - children=[ - state.RenderTreeNode( - name="td", - children=[ - state.RenderTreeNode( - name="markdown", - props=dict(body=dedent(str(value))), - ), - ], - ) - for value in row - ], - ) - for row in df.itertuples(index=False) - ], - ), - ], - ).mount() + with tag("table", className="table table-striped table-sm"): + with tag("thead"): + with tag("tr"): + for col in df.columns: + with tag("th", scope="col"): + html(dedent(col)) + with tag("tbody"): + for row in 
df.itertuples(index=False): + with tag("tr"): + for value in row: + with tag("td"): + html(dedent(str(value))) + + +def horizontal_radio( + label: str, + options: typing.Sequence[T], + format_func: typing.Callable[[T], typing.Any] = _default_format, + key: str = None, + help: str = None, + *, + disabled: bool = False, + checked_by_default: bool = True, + label_visibility: LabelVisibility = "visible", +) -> T | None: + if not options: + return None + options = list(options) + if not key: + key = md5_values("horizontal_radio", label, options, help, label_visibility) + value = state.session_state.get(key) + if (key not in state.session_state or value not in options) and checked_by_default: + value = options[0] + state.session_state.setdefault(key, value) + if label_visibility != "visible": + label = None + markdown(label) + for option in options: + if button( + format_func(option), + key=f"tab-{key}-{option}", + type="primary", + className="replicate-nav " + ("active" if value == option else ""), + disabled=disabled, + ): + state.session_state[key] = value = option + state.experimental_rerun() + return value def radio( @@ -559,9 +600,11 @@ def radio( options: typing.Sequence[T], format_func: typing.Callable[[T], typing.Any] = _default_format, key: str = None, + value: T = None, help: str = None, *, disabled: bool = False, + checked_by_default: bool = True, label_visibility: LabelVisibility = "visible", ) -> T | None: if not options: @@ -569,10 +612,10 @@ def radio( options = list(options) if not key: key = md5_values("radio", label, options, help, label_visibility) - value = state.session_state.get(key) - if key not in state.session_state or value not in options: + value = state.session_state.setdefault(key, value) + if value not in options and checked_by_default: value = options[0] - state.session_state.setdefault(key, value) + state.session_state[key] = value if label_visibility != "visible": label = None markdown(label) @@ -619,6 +662,38 @@ def text_input( return 
value or "" +def date_input( + label: str, + value: str | None = None, + key: str = None, + help: str = None, + *, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + **props, +) -> datetime | None: + value = _input_widget( + input_type="date", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + label_visibility=label_visibility, + style=dict( + border="1px solid hsl(0, 0%, 80%)", + padding="0.375rem 0.75rem", + borderRadius="0.25rem", + margin="0 0.5rem 0 0.5rem", + ), + **props, + ) + try: + return datetime.strptime(value, "%Y-%m-%d") if value else None + except ValueError: + return None + + def password_input( label: str, value: str = "", @@ -772,7 +847,7 @@ def breadcrumbs(divider: str = "/", **props) -> state.NestingCtx: def breadcrumb_item(inner_html: str, link_to: str | None = None, **props): - className = "breadcrumb-item lead " + props.pop("className", "") + className = "breadcrumb-item " + props.pop("className", "") with tag("li", className=className, **props): if link_to: with tag("a", href=link_to): @@ -781,6 +856,21 @@ def breadcrumb_item(inner_html: str, link_to: str | None = None, **props): html(inner_html) +def plotly_chart(figure_or_data, **kwargs): + data = ( + figure_or_data.to_plotly_json() + if hasattr(figure_or_data, "to_plotly_json") + else figure_or_data + ) + state.RenderTreeNode( + name="plotly-chart", + props=dict( + chart=data, + args=kwargs, + ), + ).mount() + + def dedent(text: str | None) -> str | None: if not text: return text @@ -795,3 +885,29 @@ def js(src: str, **kwargs): args=kwargs, ), ).mount() + + +def change_url(url: str, request): + """Change the url of the page, without reloading the page. 
Only for urls on the current domain due to browser security policies.""" + # this is useful to store certain state inputs in the url to allow for sharing/returning to a state + old_url = furl(request.url).remove(origin=True).tostr() + url = furl(url).remove(origin=True).tostr() + if old_url == url: + return + # the request is likely processing which means it will overwrite the url we set once it is done + # so we set up a timer to keep setting the url until the request is done at which point we stop + js( + f""" + setTimeout(() => window.history.replaceState(null, '', '{url}')); + function change_url() {{ + if (window.location.href.replace(window.location.origin, "") == '{old_url}') {{ + clearInterval(window._change_url_timer); + }} + window.history.replaceState(null, '', '{url}'); + }} + clearInterval(window._change_url_timer); + if (window.location.href.replace(window.location.origin, "") != '{url}') {{ + window._change_url_timer = setInterval(change_url, 100); + }} + """, + ) diff --git a/gooey_ui/components/modal.py b/gooey_ui/components/modal.py new file mode 100644 index 000000000..f8d46dcd2 --- /dev/null +++ b/gooey_ui/components/modal.py @@ -0,0 +1,96 @@ +from contextlib import contextmanager + +import gooey_ui as st +from gooey_ui import experimental_rerun as rerun + + +class Modal: + def __init__(self, title, key, padding=20, max_width=744): + """ + :param title: title of the Modal shown in the h1 + :param key: unique key identifying this modal instance + :param padding: padding of the content within the modal + :param max_width: maximum width this modal should use + """ + self.title = title + self.padding = padding + self.max_width = str(max_width) + "px" + self.key = key + + self._container = None + + def is_open(self): + return st.session_state.get(f"{self.key}-opened", False) + + def open(self): + st.session_state[f"{self.key}-opened"] = True + rerun() + + def close(self, rerun_condition=True): + st.session_state[f"{self.key}-opened"] = False + if 
rerun_condition: + rerun() + + def empty(self): + if self._container: + self._container.empty() + + @contextmanager + def container(self, **props): + st.html( + f""" + + """ + ) + + with st.div(className="blur-background"): + with st.div(className="modal-parent"): + container_class = "modal-container " + props.pop("className", "") + self._container = st.div(className=container_class, **props) + + with self._container: + with st.div(className="d-flex justify-content-between align-items-center"): + if self.title: + st.markdown(f"### {self.title}") + else: + st.div() + + close_ = st.button( + "✖", + type="tertiary", + key=f"{self.key}-close", + style={"padding": "0.375rem 0.75rem"}, + ) + if close_: + self.close() + yield self._container diff --git a/gooey_ui/pubsub.py b/gooey_ui/pubsub.py index ca3dfcb1a..ef2ba1539 100644 --- a/gooey_ui/pubsub.py +++ b/gooey_ui/pubsub.py @@ -2,6 +2,7 @@ import json import threading import typing +from contextlib import contextmanager from time import time import redis @@ -42,7 +43,39 @@ def realtime_push(channel: str, value: typing.Any = "ping"): msg = json.dumps(jsonable_encoder(value)) r.set(channel, msg) r.publish(channel, json.dumps(time())) - logger.info(f"publish {channel=}") + if isinstance(value, dict): + run_status = value.get("__run_status") + logger.info(f"publish {channel=} {run_status=}") + else: + logger.info(f"publish {channel=}") + + +@contextmanager +def realtime_subscribe(channel: str) -> typing.Generator: + channel = f"gooey-gui/state/{channel}" + pubsub = r.pubsub() + pubsub.subscribe(channel) + logger.info(f"subscribe {channel=}") + try: + yield _realtime_sub_gen(channel, pubsub) + finally: + logger.info(f"unsubscribe {channel=}") + pubsub.unsubscribe(channel) + pubsub.close() + + +def _realtime_sub_gen(channel: str, pubsub: redis.client.PubSub) -> typing.Generator: + while True: + message = pubsub.get_message(timeout=10) + if not (message and message["type"] == "message"): + continue + value = 
json.loads(r.get(channel)) + if isinstance(value, dict): + run_status = value.get("__run_status") + logger.info(f"realtime_subscribe: {channel=} {run_status=}") + else: + logger.info(f"realtime_subscribe: {channel=}") + yield value # def use_state( diff --git a/pages/UsageDashboard.py b/pages/UsageDashboard.py index 5d023e887..b4faf6a5d 100644 --- a/pages/UsageDashboard.py +++ b/pages/UsageDashboard.py @@ -171,14 +171,28 @@ def main(): """ ) - total_runs = ( - counts_df.sum(numeric_only=True) - .rename("Total Runs") - .to_frame() - .reset_index(names=["label"]) - .sort_values("Total Runs", ascending=False) - .reset_index(drop=True) - ) + if st.checkbox("Show Uniques"): + calc = "Unique Users" + total_runs = ( + counts_df.drop(columns=["display_name", "email"]) + .astype(bool) + .sum(numeric_only=True) + .rename(calc) + .to_frame() + .reset_index(names=["label"]) + .sort_values(calc, ascending=False) + .reset_index(drop=True) + ) + else: + calc = "Total Runs" + total_runs = ( + counts_df.sum(numeric_only=True) + .rename(calc) + .to_frame() + .reset_index(names=["label"]) + .sort_values(calc, ascending=False) + .reset_index(drop=True) + ) col1, col2 = st.columns(2) @@ -189,7 +203,7 @@ def main(): st.plotly_chart( px.pie( total_runs.iloc[2:], - values="Total Runs", + values=calc, names="label", ), use_container_width=True, diff --git a/poetry.lock b/poetry.lock index 5e3fd1d70..f8632b90a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -13,13 +13,13 @@ files = [ [[package]] name = "aifail" -version = "0.1.0" +version = "0.2.0" description = "" optional = false python-versions = ">=3.10,<4.0" files = [ - {file = "aifail-0.1.0-py3-none-any.whl", hash = "sha256:d51fa56b1b8531a298a38cf91a3979f7d427afc88f5af635f86865b8f721bd23"}, - {file = "aifail-0.1.0.tar.gz", hash = "sha256:40f95fd45c07f9a7f0478d9702e3ea0d811f8582da7f6b841618de2a52803177"}, + {file = "aifail-0.2.0-py3-none-any.whl", hash = "sha256:83f3a842dbe523ee10a4d53ee00f06e794176122b93b57752566fb98d60db603"}, + {file 
= "aifail-0.2.0.tar.gz", hash = "sha256:d3e19e16740577181922055883a00e894d11413d12803e8d443094846040d6df"}, ] [package.dependencies] @@ -1898,72 +1898,73 @@ requests = "*" [[package]] name = "greenlet" -version = "3.0.1" +version = "3.0.3" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" files = [ - {file = "greenlet-3.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f89e21afe925fcfa655965ca8ea10f24773a1791400989ff32f467badfe4a064"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28e89e232c7593d33cac35425b58950789962011cc274aa43ef8865f2e11f46d"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8ba29306c5de7717b5761b9ea74f9c72b9e2b834e24aa984da99cbfc70157fd"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19bbdf1cce0346ef7341705d71e2ecf6f41a35c311137f29b8a2dc2341374565"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:599daf06ea59bfedbec564b1692b0166a0045f32b6f0933b0dd4df59a854caf2"}, - {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b641161c302efbb860ae6b081f406839a8b7d5573f20a455539823802c655f63"}, - {file = "greenlet-3.0.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d57e20ba591727da0c230ab2c3f200ac9d6d333860d85348816e1dca4cc4792e"}, - {file = "greenlet-3.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5805e71e5b570d490938d55552f5a9e10f477c19400c38bf1d5190d760691846"}, - {file = "greenlet-3.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:52e93b28db27ae7d208748f45d2db8a7b6a380e0d703f099c949d0f0d80b70e9"}, - {file = "greenlet-3.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f7bfb769f7efa0eefcd039dd19d843a4fbfbac52f1878b1da2ed5793ec9b1a65"}, - {file = 
"greenlet-3.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e6c7db42638dc45cf2e13c73be16bf83179f7859b07cfc139518941320be96"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1757936efea16e3f03db20efd0cd50a1c86b06734f9f7338a90c4ba85ec2ad5a"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:19075157a10055759066854a973b3d1325d964d498a805bb68a1f9af4aaef8ec"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9d21aaa84557d64209af04ff48e0ad5e28c5cca67ce43444e939579d085da72"}, - {file = "greenlet-3.0.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2847e5d7beedb8d614186962c3d774d40d3374d580d2cbdab7f184580a39d234"}, - {file = "greenlet-3.0.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:97e7ac860d64e2dcba5c5944cfc8fa9ea185cd84061c623536154d5a89237884"}, - {file = "greenlet-3.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b2c02d2ad98116e914d4f3155ffc905fd0c025d901ead3f6ed07385e19122c94"}, - {file = "greenlet-3.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:22f79120a24aeeae2b4471c711dcf4f8c736a2bb2fabad2a67ac9a55ea72523c"}, - {file = "greenlet-3.0.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:100f78a29707ca1525ea47388cec8a049405147719f47ebf3895e7509c6446aa"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60d5772e8195f4e9ebf74046a9121bbb90090f6550f81d8956a05387ba139353"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:daa7197b43c707462f06d2c693ffdbb5991cbb8b80b5b984007de431493a319c"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea6b8aa9e08eea388c5f7a276fabb1d4b6b9d6e4ceb12cc477c3d352001768a9"}, - {file = 
"greenlet-3.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d11ebbd679e927593978aa44c10fc2092bc454b7d13fdc958d3e9d508aba7d0"}, - {file = "greenlet-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dbd4c177afb8a8d9ba348d925b0b67246147af806f0b104af4d24f144d461cd5"}, - {file = "greenlet-3.0.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20107edf7c2c3644c67c12205dc60b1bb11d26b2610b276f97d666110d1b511d"}, - {file = "greenlet-3.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8bef097455dea90ffe855286926ae02d8faa335ed8e4067326257cb571fc1445"}, - {file = "greenlet-3.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:b2d3337dcfaa99698aa2377c81c9ca72fcd89c07e7eb62ece3f23a3fe89b2ce4"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80ac992f25d10aaebe1ee15df45ca0d7571d0f70b645c08ec68733fb7a020206"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:337322096d92808f76ad26061a8f5fccb22b0809bea39212cd6c406f6a7060d2"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9934adbd0f6e476f0ecff3c94626529f344f57b38c9a541f87098710b18af0a"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc4d815b794fd8868c4d67602692c21bf5293a75e4b607bb92a11e821e2b859a"}, - {file = "greenlet-3.0.1-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41bdeeb552d814bcd7fb52172b304898a35818107cc8778b5101423c9017b3de"}, - {file = "greenlet-3.0.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6e6061bf1e9565c29002e3c601cf68569c450be7fc3f7336671af7ddb4657166"}, - {file = "greenlet-3.0.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fa24255ae3c0ab67e613556375a4341af04a084bd58764731972bcbc8baeba36"}, - {file = "greenlet-3.0.1-cp37-cp37m-win32.whl", hash = 
"sha256:b489c36d1327868d207002391f662a1d163bdc8daf10ab2e5f6e41b9b96de3b1"}, - {file = "greenlet-3.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f33f3258aae89da191c6ebaa3bc517c6c4cbc9b9f689e5d8452f7aedbb913fa8"}, - {file = "greenlet-3.0.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d2905ce1df400360463c772b55d8e2518d0e488a87cdea13dd2c71dcb2a1fa16"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a02d259510b3630f330c86557331a3b0e0c79dac3d166e449a39363beaae174"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55d62807f1c5a1682075c62436702aaba941daa316e9161e4b6ccebbbf38bda3"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3fcc780ae8edbb1d050d920ab44790201f027d59fdbd21362340a85c79066a74"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4eddd98afc726f8aee1948858aed9e6feeb1758889dfd869072d4465973f6bfd"}, - {file = "greenlet-3.0.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eabe7090db68c981fca689299c2d116400b553f4b713266b130cfc9e2aa9c5a9"}, - {file = "greenlet-3.0.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f2f6d303f3dee132b322a14cd8765287b8f86cdc10d2cb6a6fae234ea488888e"}, - {file = "greenlet-3.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d923ff276f1c1f9680d32832f8d6c040fe9306cbfb5d161b0911e9634be9ef0a"}, - {file = "greenlet-3.0.1-cp38-cp38-win32.whl", hash = "sha256:0b6f9f8ca7093fd4433472fd99b5650f8a26dcd8ba410e14094c1e44cd3ceddd"}, - {file = "greenlet-3.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:990066bff27c4fcf3b69382b86f4c99b3652bab2a7e685d968cd4d0cfc6f67c6"}, - {file = "greenlet-3.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ce85c43ae54845272f6f9cd8320d034d7a946e9773c693b27d620edec825e376"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:89ee2e967bd7ff85d84a2de09df10e021c9b38c7d91dead95b406ed6350c6997"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87c8ceb0cf8a5a51b8008b643844b7f4a8264a2c13fcbcd8a8316161725383fe"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d6a8c9d4f8692917a3dc7eb25a6fb337bff86909febe2f793ec1928cd97bedfc"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fbc5b8f3dfe24784cee8ce0be3da2d8a79e46a276593db6868382d9c50d97b1"}, - {file = "greenlet-3.0.1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85d2b77e7c9382f004b41d9c72c85537fac834fb141b0296942d52bf03fe4a3d"}, - {file = "greenlet-3.0.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:696d8e7d82398e810f2b3622b24e87906763b6ebfd90e361e88eb85b0e554dc8"}, - {file = "greenlet-3.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:329c5a2e5a0ee942f2992c5e3ff40be03e75f745f48847f118a3cfece7a28546"}, - {file = "greenlet-3.0.1-cp39-cp39-win32.whl", hash = "sha256:cf868e08690cb89360eebc73ba4be7fb461cfbc6168dd88e2fbbe6f31812cd57"}, - {file = "greenlet-3.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:ac4a39d1abae48184d420aa8e5e63efd1b75c8444dd95daa3e03f6c6310e9619"}, - {file = "greenlet-3.0.1.tar.gz", hash = "sha256:816bd9488a94cba78d93e1abb58000e8266fa9cc2aa9ccdd6eb0696acb24005b"}, + {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d353cadd6083fdb056bb46ed07e4340b0869c305c8ca54ef9da3421acbdf6881"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dca1e2f3ca00b84a396bc1bce13dd21f680f035314d2379c4160c98153b2059b"}, + {file = 
"greenlet-3.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ed7fb269f15dc662787f4119ec300ad0702fa1b19d2135a37c2c4de6fadfd4a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd4f49ae60e10adbc94b45c0b5e6a179acc1736cf7a90160b404076ee283cf83"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:73a411ef564e0e097dbe7e866bb2dda0f027e072b04da387282b02c308807405"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7f362975f2d179f9e26928c5b517524e89dd48530a0202570d55ad6ca5d8a56f"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:649dde7de1a5eceb258f9cb00bdf50e978c9db1b996964cd80703614c86495eb"}, + {file = "greenlet-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:68834da854554926fbedd38c76e60c4a2e3198c6fbed520b106a8986445caaf9"}, + {file = "greenlet-3.0.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1b5667cced97081bf57b8fa1d6bfca67814b0afd38208d52538316e9422fc61"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52f59dd9c96ad2fc0d5724107444f76eb20aaccb675bf825df6435acb7703559"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afaff6cf5200befd5cec055b07d1c0a5a06c040fe5ad148abcd11ba6ab9b114e"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2797aa5aedac23af156bbb5a6aa2cd3427ada2972c828244eb7d1b9255846379"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7f009caad047246ed379e1c4dbcb8b020f0a390667ea74d2387be2998f58a22"}, + {file = 
"greenlet-3.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c5e1536de2aad7bf62e27baf79225d0d64360d4168cf2e6becb91baf1ed074f3"}, + {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:894393ce10ceac937e56ec00bb71c4c2f8209ad516e96033e4b3b1de270e200d"}, + {file = "greenlet-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:1ea188d4f49089fc6fb283845ab18a2518d279c7cd9da1065d7a84e991748728"}, + {file = "greenlet-3.0.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:70fb482fdf2c707765ab5f0b6655e9cfcf3780d8d87355a063547b41177599be"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4d1ac74f5c0c0524e4a24335350edad7e5f03b9532da7ea4d3c54d527784f2e"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149e94a2dd82d19838fe4b2259f1b6b9957d5ba1b25640d2380bea9c5df37676"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15d79dd26056573940fcb8c7413d84118086f2ec1a8acdfa854631084393efcc"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b7db1ebff4ba09aaaeae6aa491daeb226c8150fc20e836ad00041bcb11230"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcd2469d6a2cf298f198f0487e0a5b1a47a42ca0fa4dfd1b6862c999f018ebbf"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f672519db1796ca0d8753f9e78ec02355e862d0998193038c7073045899f305"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2516a9957eed41dd8f1ec0c604f1cdc86758b587d964668b5b196a9db5bfcde6"}, + {file = "greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2"}, + {file = "greenlet-3.0.3-cp37-cp37m-macosx_11_0_universal2.whl", hash = 
"sha256:5b51e85cb5ceda94e79d019ed36b35386e8c37d22f07d6a751cb659b180d5274"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf3cb43b7cf2ba96d614252ce1684c1bccee6b2183a01328c98d36fcd7d5cb0"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99bf650dc5d69546e076f413a87481ee1d2d09aaaaaca058c9251b6d8c14783f"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dd6e660effd852586b6a8478a1d244b8dc90ab5b1321751d2ea15deb49ed414"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3391d1e16e2a5a1507d83e4a8b100f4ee626e8eca43cf2cadb543de69827c4c"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1f145462f1fa6e4a4ae3c0f782e580ce44d57c8f2c7aae1b6fa88c0b2efdb41"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1a7191e42732df52cb5f39d3527217e7ab73cae2cb3694d241e18f53d84ea9a7"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0448abc479fab28b00cb472d278828b3ccca164531daab4e970a0458786055d6"}, + {file = "greenlet-3.0.3-cp37-cp37m-win32.whl", hash = "sha256:b542be2440edc2d48547b5923c408cbe0fc94afb9f18741faa6ae970dbcb9b6d"}, + {file = "greenlet-3.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:01bc7ea167cf943b4c802068e178bbf70ae2e8c080467070d01bfa02f337ee67"}, + {file = "greenlet-3.0.3-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:1996cb9306c8595335bb157d133daf5cf9f693ef413e7673cb07e3e5871379ca"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc0f794e6ad661e321caa8d2f0a55ce01213c74722587256fb6566049a8b04"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9db1c18f0eaad2f804728c67d6c610778456e3e1cc4ab4bbd5eeb8e6053c6fc"}, + {file = 
"greenlet-3.0.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7170375bcc99f1a2fbd9c306f5be8764eaf3ac6b5cb968862cad4c7057756506"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b66c9c1e7ccabad3a7d037b2bcb740122a7b17a53734b7d72a344ce39882a1b"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:098d86f528c855ead3479afe84b49242e174ed262456c342d70fc7f972bc13c4"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:81bb9c6d52e8321f09c3d165b2a78c680506d9af285bfccbad9fb7ad5a5da3e5"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da"}, + {file = "greenlet-3.0.3-cp38-cp38-win32.whl", hash = "sha256:d46677c85c5ba00a9cb6f7a00b2bfa6f812192d2c9f7d9c4f6a55b60216712f3"}, + {file = "greenlet-3.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:419b386f84949bf0e7c73e6032e3457b82a787c1ab4a0e43732898a761cc9dbf"}, + {file = "greenlet-3.0.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:da70d4d51c8b306bb7a031d5cff6cc25ad253affe89b70352af5f1cb68e74b53"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086152f8fbc5955df88382e8a75984e2bb1c892ad2e3c80a2508954e52295257"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d73a9fe764d77f87f8ec26a0c85144d6a951a6c438dfe50487df5595c6373eac"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7dcbe92cc99f08c8dd11f930de4d99ef756c3591a5377d1d9cd7dd5e896da71"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1551a8195c0d4a68fac7a4325efac0d541b48def35feb49d803674ac32582f61"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:64d7675ad83578e3fc149b617a444fab8efdafc9385471f868eb5ff83e446b8b"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b37eef18ea55f2ffd8f00ff8fe7c8d3818abd3e25fb73fae2ca3b672e333a7a6"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:77457465d89b8263bca14759d7c1684df840b6811b2499838cc5b040a8b5b113"}, + {file = "greenlet-3.0.3-cp39-cp39-win32.whl", hash = "sha256:57e8974f23e47dac22b83436bdcf23080ade568ce77df33159e019d161ce1d1e"}, + {file = "greenlet-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c5ee858cfe08f34712f548c3c363e807e7186f03ad7a5039ebadb29e8c6be067"}, + {file = "greenlet-3.0.3.tar.gz", hash = "sha256:43374442353259554ce33599da8b692d5aa96f8976d567d4badf263371fbe491"}, ] [package.extras] -docs = ["Sphinx"] +docs = ["Sphinx", "furo"] test = ["objgraph", "psutil"] [[package]] @@ -3807,6 +3808,26 @@ files = [ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] +[[package]] +name = "playwright" +version = "1.41.1" +description = "A high-level API to automate web browsers" +optional = false +python-versions = ">=3.8" +files = [ + {file = "playwright-1.41.1-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:b456f25db38e4d93afc3c671e1093f3995afb374f14cee284152a30f84cfff02"}, + {file = "playwright-1.41.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53ff152506dbd8527aa815e92757be72f5df60810e8000e9419d29fd4445f53c"}, + {file = "playwright-1.41.1-py3-none-macosx_11_0_universal2.whl", hash = "sha256:70c432887b8b5e896fa804fb90ca2c8baf05b13a3590fb8bce8b3c3efba2842d"}, + {file = "playwright-1.41.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:f227a8d616fd3a02d45d68546ee69947dce4a058df134a9e7dc6167c543de3cd"}, + {file = "playwright-1.41.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:475130f879b4ba38b9db7232a043dd5bc3a8bd1a84567fbea7e21a02ee2fcb13"}, + {file = "playwright-1.41.1-py3-none-win32.whl", hash = "sha256:ef769414ea0ceb76085c67812ab6bc0cc6fac0adfc45aaa09d54ee161d7f637b"}, + {file = "playwright-1.41.1-py3-none-win_amd64.whl", hash = "sha256:316e1ba0854a712e9288b3fe49509438e648d43bade77bf724899de8c24848de"}, +] + +[package.dependencies] +greenlet = "3.0.3" +pyee = "11.0.1" + [[package]] name = "plotly" version = "5.18.0" @@ -4367,6 +4388,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-base-url" +version = "2.1.0" +description = "pytest plugin for URL based testing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_base_url-2.1.0-py3-none-any.whl", hash = "sha256:3ad15611778764d451927b2a53240c1a7a591b521ea44cebfe45849d2d2812e6"}, + {file = "pytest_base_url-2.1.0.tar.gz", hash = "sha256:02748589a54f9e63fcbe62301d6b0496da0d10231b753e950c63e03aee745d45"}, +] + +[package.dependencies] +pytest = ">=7.0.0" +requests = ">=2.9" + +[package.extras] +test = ["black (>=22.1.0)", "flake8 (>=4.0.1)", "pre-commit (>=2.17.0)", "pytest-localserver (>=0.7.1)", "tox (>=3.24.5)"] + [[package]] name = "pytest-django" version = "4.6.0" @@ -4385,6 +4424,23 @@ pytest = ">=7.0.0" docs = ["sphinx", "sphinx-rtd-theme"] testing = ["Django", "django-configurations (>=2.0)"] +[[package]] +name = "pytest-playwright" +version = "0.4.4" +description = "A pytest wrapper with fixtures for Playwright to automate web browsers" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-playwright-0.4.4.tar.gz", hash = "sha256:5488db4cc49028491c5130af0a2bb6b1d0b222a202217f6d14491d4c9aa67ff9"}, + {file = "pytest_playwright-0.4.4-py3-none-any.whl", hash = 
"sha256:df306f3a60a8631a3cfde1b95a2ed5a89203a3408dfa1154de049ca7de87c90b"}, +] + +[package.dependencies] +playwright = ">=1.18" +pytest = ">=6.2.4,<9.0.0" +pytest-base-url = ">=1.0.0,<3.0.0" +python-slugify = ">=6.0.0,<9.0.0" + [[package]] name = "pytest-subtests" version = "0.11.0" @@ -4472,6 +4528,23 @@ files = [ [package.dependencies] six = ">=1.4.0" +[[package]] +name = "python-slugify" +version = "8.0.3" +description = "A Python slugify application that also handles Unicode" +optional = false +python-versions = ">=3.7" +files = [ + {file = "python-slugify-8.0.3.tar.gz", hash = "sha256:e04cba5f1c562502a1175c84a8bc23890c54cdaf23fccaaf0bf78511508cabed"}, + {file = "python_slugify-8.0.3-py2.py3-none-any.whl", hash = "sha256:c71189c161e8c671f1b141034d9a56308a8a5978cd13d40446c879569212fdd1"}, +] + +[package.dependencies] +text-unidecode = ">=1.3" + +[package.extras] +unidecode = ["Unidecode (>=1.1.1)"] + [[package]] name = "pytz" version = "2023.3.post1" @@ -5545,6 +5618,17 @@ files = [ [package.extras] doc = ["reno", "sphinx", "tornado (>=4.5)"] +[[package]] +name = "text-unidecode" +version = "1.3" +description = "The most basic Text::Unidecode port" +optional = false +python-versions = "*" +files = [ + {file = "text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93"}, + {file = "text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8"}, +] + [[package]] name = "tiktoken" version = "0.3.3" @@ -6487,4 +6571,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "9442702290e07c5cbe9c5cd7d610438f67fcf259da8861b8ce08af0ec700583f" +content-hash = "e773b499e69efc03fded0d07749a512a0f06de76ec72819a98bd8c41b5f13446" diff --git a/pyproject.toml b/pyproject.toml index ebed124ba..d90fa3ee9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 
+81,8 @@ ua-parser = "^0.18.0" user-agents = "^2.2.0" openpyxl = "^3.1.2" loguru = "^0.7.2" -aifail = "^0.1.0" +aifail = "0.2.0" +pytest-playwright = "^0.4.3" google-cloud-secret-manager = "^2.16.4" [tool.poetry.group.dev.dependencies] diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py new file mode 100644 index 000000000..0672f662b --- /dev/null +++ b/recipes/BulkEval.py @@ -0,0 +1,375 @@ +import itertools +import typing +from itertools import zip_longest + +import typing_extensions +from pydantic import BaseModel, Field + +import gooey_ui as st +from bots.models import Workflow +from daras_ai.image_input import upload_file_from_bytes +from daras_ai_v2.base import BasePage +from daras_ai_v2.doc_search_settings_widgets import document_uploader +from daras_ai_v2.field_render import field_title_desc +from daras_ai_v2.functional import map_parallel +from daras_ai_v2.language_model import ( + run_language_model, + LargeLanguageModels, + llm_price, +) +from daras_ai_v2.language_model_settings_widgets import language_model_settings +from daras_ai_v2.prompt_vars import render_prompt_vars +from recipes.BulkRunner import read_df_any, list_view_editor, del_button +from recipes.DocSearch import render_documents + +NROWS_CACHE_KEY = "__nrows" + +AggFunctionsList = [ + "mean", + "median", + "min", + "max", + "sum", + "cumsum", + "prod", + "cumprod", + "std", + "var", + "first", + "last", + "count", + "cumcount", + "nunique", + "rank", +] + + +class LLMSettingsMixin(BaseModel): + selected_model: typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + avoid_repetition: bool | None + num_outputs: int | None + quality: float | None + max_tokens: int | None + sampling_temperature: float | None + + +class EvalPrompt(typing.TypedDict): + name: str + prompt: str + + +class AggFunction(typing_extensions.TypedDict): + column: typing_extensions.NotRequired[str] + function: typing.Literal[tuple(AggFunctionsList)] + + +class AggFunctionResult(typing.TypedDict): + column: str 
+ function: typing.Literal[tuple(AggFunctionsList)] + count: int + value: float + + +def _render_results(results: list[AggFunctionResult]): + import plotly.graph_objects as go + from plotly.colors import sample_colorscale + + for k, g in itertools.groupby(results, key=lambda d: d["function"]): + st.write("---\n###### **Aggregate**: " + k.capitalize()) + + g = list(g) + + columns = [d["column"] for d in g] + values = [round(d["value"], 2) for d in g] + + norm_values = [ + (v - min(values)) / ((max(values) - min(values)) or 1) for v in values + ] + colors = sample_colorscale("RdYlGn", norm_values, colortype="tuple") + colors = [f"rgba{(r * 255, g * 255, b * 255, 0.5)}" for r, g, b in colors] + + st.data_table( + [ + ["Metric", k.capitalize(), "Count"], + ] + + [ + [ + columns[i], + dict( + kind="number", + readonly=True, + displayData=str(values[i]), + data=values[i], + themeOverride=dict(bgCell=colors[i]), + ), + g[i].get("count", 1), + ] + for i in range(len(g)) + ] + ) + + fig = go.Figure( + data=[ + go.Bar( + name=k, + x=columns, + y=values, + marker=dict(color=colors), + text=values, + texttemplate="%{text}", + insidetextanchor="middle", + insidetextfont=dict(size=24), + ), + ], + layout=dict( + margin=dict(l=0, r=0, t=24, b=0), + ), + ) + st.plotly_chart(fig) + + +class BulkEvalPage(BasePage): + title = "Evaluator" + workflow = Workflow.BULK_EVAL + slug_versions = ["bulk-eval", "eval"] + + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aad314f0-9a97-11ee-8318-02420a0001c7/W.I.9.png.png" + + def preview_image(self, state: dict) -> str | None: + return "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/9631fb74-9a97-11ee-971f-02420a0001c4/evaluator.png.png" + + def render_description(self): + st.write( + """ +Summarize and score every row of any CSV, google sheet or excel with GPT4 (or any LLM you choose). Then average every score in any column to generate automated evaluations. 
+ """ + ) + + def related_workflows(self) -> list: + from recipes.BulkRunner import BulkRunnerPage + from recipes.VideoBots import VideoBotsPage + from recipes.asr import AsrPage + from recipes.DocSearch import DocSearchPage + + return [BulkRunnerPage, VideoBotsPage, AsrPage, DocSearchPage] + + class RequestModel(LLMSettingsMixin, BaseModel): + documents: list[str] = Field( + title="Input Data Spreadsheet", + description=""" +Upload or link to a CSV or google sheet that contains your sample input data. +For example, for Copilot, this would sample questions or for Art QR Code, would would be pairs of image descriptions and URLs. +Remember to includes header names in your CSV too. + """, + ) + + eval_prompts: list[EvalPrompt] | None = Field( + title="Evaluation Prompts", + description=""" +Specify custom LLM prompts to calculate metrics that evaluate each row of the input data. The output should be a JSON object mapping the metric names to values. +_The `columns` dictionary can be used to reference the spreadsheet columns._ + """, + ) + + agg_functions: list[AggFunction] | None = Field( + title="Aggregations", + description=""" +Aggregate using one or more operations. Uses [pandas](https://pandas.pydata.org/pandas-docs/stable/reference/groupby.html#dataframegroupby-computations-descriptive-stats). 
+ """, + ) + + class ResponseModel(BaseModel): + output_documents: list[str] + aggregations: list[list[AggFunctionResult]] | None + + def render_form_v2(self): + files = document_uploader( + f"##### {field_title_desc(self.RequestModel, 'documents')}", + accept=(".csv", ".xlsx", ".xls", ".json", ".tsv", ".xml"), + ) + st.session_state[NROWS_CACHE_KEY] = get_nrows(files) + if not files: + return + + st.write( + """ +##### Input Data Preview +Here's what you uploaded: + """ + ) + for file in files: + st.data_table(file) + st.write("---") + + def render_inputs(key: str, del_key: str, d: EvalPrompt): + col1, col2 = st.columns([8, 1], responsive=False) + with col1: + d["name"] = st.text_input( + label="", + label_visibility="collapsed", + placeholder="Metric Name", + key=key + ":name", + value=d.get("name"), + ).strip() + d["prompt"] = st.text_area( + label="", + label_visibility="collapsed", + placeholder="Prompt", + key=key + ":prompt", + value=d.get("prompt"), + height=500, + ).strip() + with col2: + del_button(del_key) + + st.write("##### " + field_title_desc(self.RequestModel, "eval_prompts")) + list_view_editor( + add_btn_label="โž• Add a Prompt", + key="eval_prompts", + render_inputs=render_inputs, + ) + + def render_agg_inputs(key: str, del_key: str, d: AggFunction): + col1, col3 = st.columns([8, 1], responsive=False) + with col1: + # d["column"] = st.text_input( + # "", + # label_visibility="collapsed", + # placeholder="Column Name", + # key=key + ":column", + # value=d.get("column"), + # ).strip() + # with col2: + with st.div(className="pt-1"): + d["function"] = st.selectbox( + "", + label_visibility="collapsed", + key=key + ":func", + options=AggFunctionsList, + default_value=d.get("function"), + ) + with col3: + del_button(del_key) + + st.html("
") + st.write("##### " + field_title_desc(self.RequestModel, "agg_functions")) + list_view_editor( + add_btn_label="โž• Add an Aggregation", + key="agg_functions", + render_inputs=render_agg_inputs, + ) + + def render_settings(self): + language_model_settings() + + def render_example(self, state: dict): + render_documents(state) + + def render_output(self): + files = st.session_state.get("output_documents", []) + aggregations = st.session_state.get("aggregations", []) + + for file, results in zip_longest(files, aggregations): + st.write(file) + st.data_table(file) + + if not results: + continue + + _render_results(results) + + def run_v2( + self, + request: "BulkEvalPage.RequestModel", + response: "BulkEvalPage.ResponseModel", + ) -> typing.Iterator[str | None]: + import pandas as pd + + response.output_documents = [] + response.aggregations = [] + + for doc_ix, doc in enumerate(request.documents): + df = read_df_any(doc) + in_recs = df.to_dict(orient="records") + out_recs = [] + + out_df = None + f = upload_file_from_bytes( + filename=f"bulk-eval-{doc_ix}-0.csv", + data=df.to_csv(index=False).encode(), + content_type="text/csv", + ) + response.output_documents.append(f) + response.aggregations.append([]) + + for df_ix in range(len(in_recs)): + rec_ix = len(out_recs) + out_recs.append(in_recs[df_ix]) + + for ep_ix, ep in enumerate(request.eval_prompts): + progress = round( + (doc_ix + df_ix + ep_ix) + / (len(request.documents) + len(df) + len(request.eval_prompts)) + * 100 + ) + yield f"{progress}%" + prompt = render_prompt_vars( + ep["prompt"], + st.session_state | {"columns": out_recs[rec_ix]}, + ) + ret = run_language_model( + model=LargeLanguageModels.gpt_4_turbo.name, + prompt=prompt, + response_format_type="json_object", + )[0] + assert isinstance(ret, dict) + for metric_name, metric_value in ret.items(): + col = f"{ep['name']} - {metric_name}" + out_recs[rec_ix][col] = metric_value + + out_df = pd.DataFrame.from_records(out_recs) + f = 
upload_file_from_bytes( + filename=f"evaluator-{doc_ix}-{df_ix}.csv", + data=out_df.to_csv(index=False).encode(), + content_type="text/csv", + ) + response.output_documents[doc_ix] = f + + if out_df is None: + continue + for agg in request.agg_functions: + if agg.get("column"): + cols = [agg["column"]] + else: + cols = out_df.select_dtypes(include=["float", "int"]).columns + for col in cols: + col_values = out_df[col].dropna() + agg_value = col_values.agg(agg["function"]) + response.aggregations[doc_ix].append( + { + "column": col, + "function": agg["function"], + "count": len(col_values), + "value": agg_value, + } + ) + + def fields_to_save(self) -> [str]: + return super().fields_to_save() + [NROWS_CACHE_KEY] + + def get_raw_price(self, state: dict) -> float: + try: + price = llm_price[LargeLanguageModels[state["selected_model"]]] + except KeyError: + price = 1 + nprompts = len(state.get("eval_prompts") or {}) or 1 + nrows = ( + state.get(NROWS_CACHE_KEY) or get_nrows(state.get("documents") or []) or 1 + ) + return price * nprompts * nrows + + +@st.cache_in_session_state +def get_nrows(files: list[str]) -> int: + dfs = map_parallel(read_df_any, files) + return sum((len(df) for df in dfs), 0) diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index d36f780a3..15710c410 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -1,16 +1,18 @@ import datetime import io import typing +import uuid -from fastapi import HTTPException from furl import furl from pydantic import BaseModel, Field import gooey_ui as st -from bots.models import Workflow +from bots.models import Workflow, PublishedRun, PublishedRunVisibility, SavedRun from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2.base import BasePage +from daras_ai_v2.breadcrumbs import get_title_breadcrumbs from daras_ai_v2.doc_search_settings_widgets import document_uploader +from daras_ai_v2.field_render import field_title_desc from daras_ai_v2.functional import map_parallel from 
daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.vector_search import ( @@ -19,26 +21,30 @@ ) from recipes.DocSearch import render_documents +DEFAULT_BULK_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d80fd4d8-93fa-11ee-bc13-02420a0001cc/Bulk%20Runner.jpg.png" + class BulkRunnerPage(BasePage): - title = "Bulk Runner & Evaluator" + title = "Bulk Runner" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/87f35df4-88d7-11ee-aac9-02420a00016b/Bulk%20Runner.png.png" workflow = Workflow.BULK_RUNNER slug_versions = ["bulk-runner", "bulk"] + price = 1 class RequestModel(BaseModel): documents: list[str] = Field( title="Input Data Spreadsheet", description=""" -Upload or link to a CSV or google sheet that contains your sample input data. -For example, for Copilot, this would sample questions or for Art QR Code, would would be pairs of image descriptions and URLs. +Upload or link to a CSV or google sheet that contains your sample input data. +For example, for Copilot, this would sample questions or for Art QR Code, would would be pairs of image descriptions and URLs. Remember to includes header names in your CSV too. """, ) run_urls: list[str] = Field( - title="Gooey Workflow URL(s)", + title="Gooey Workflows", description=""" -Paste in one or more Gooey.AI workflow links (on separate lines). -You can add multiple URLs runs from the same recipe (e.g. two versions of your copilot) and we'll run the inputs over both of them. +Provide one or more Gooey.AI workflow runs. +You can add multiple runs from the same recipe (e.g. two versions of your copilot) and we'll run the inputs over both of them. """, ) @@ -55,26 +61,37 @@ class RequestModel(BaseModel): """, ) + eval_urls: list[str] | None = Field( + title="Evaluation Workflows", + description=""" +_(optional)_ Add one or more Gooey.AI Evaluator Workflows to evaluate the results of your runs. 
+ """, + ) + class ResponseModel(BaseModel): output_documents: list[str] + eval_runs: list[str] | None = Field( + title="Evaluation Run URLs", + description=""" +List of URLs to the evaluation runs that you requested. + """, + ) + + def preview_image(self, state: dict) -> str | None: + return DEFAULT_BULK_META_IMG + def render_form_v2(self): - from daras_ai_v2.all_pages import page_slug_map, normalize_slug - - run_urls = st.session_state.get("run_urls", "") - st.session_state.setdefault("__run_urls", "\n".join(run_urls)) - run_urls = ( - st.text_area( - f"##### {self.RequestModel.__fields__['run_urls'].field_info.title}\n{self.RequestModel.__fields__['run_urls'].field_info.description or ''}", - key="__run_urls", - ) - .strip() - .splitlines() + st.write(f"##### {field_title_desc(self.RequestModel, 'run_urls')}") + run_urls = list_view_editor( + add_btn_label="โž• Add a Workflow", + key="run_urls", + render_inputs=render_run_url_inputs, + flatten_dict_key="url", ) - st.session_state["run_urls"] = run_urls files = document_uploader( - f"##### {self.RequestModel.__fields__['documents'].field_info.title}\n{self.RequestModel.__fields__['documents'].field_info.description or ''}", + f"---\n##### {field_title_desc(self.RequestModel, 'documents')}", accept=(".csv", ".xlsx", ".xls", ".json", ".tsv", ".xml"), ) @@ -83,19 +100,9 @@ def render_form_v2(self): output_fields = {} for url in run_urls: - f = furl(url) - slug = f.path.segments[0] - try: - page_cls = page_slug_map[normalize_slug(slug)] - except KeyError as e: - st.error(repr(e)) - continue - - example_id, run_id, uid = extract_query_params(f.query.params) try: - sr = page_cls.get_sr_from_query_params(example_id, run_id, uid) - except HTTPException as e: - st.error(repr(e)) + page_cls, sr, _ = url_to_runs(url) + except: continue schema = page_cls.RequestModel.schema(ref_template="{model}") @@ -115,12 +122,12 @@ def render_form_v2(self): except KeyError: try: keys = {k: k for k in sr.state[field][0].keys()} - except 
(KeyError, IndexError, AttributeError): + except (KeyError, IndexError, AttributeError, TypeError): pass elif field_props.get("type") == "object": try: keys = {k: k for k in sr.state[field].keys()} - except (KeyError, AttributeError): + except (KeyError, AttributeError, TypeError): pass if keys: for k, ktitle in keys.items(): @@ -136,8 +143,7 @@ def render_form_v2(self): st.write( """ -##### Input Data Preview -Here's what you uploaded: +###### **Preview**: Here's what you uploaded """ ) for file in files: @@ -146,14 +152,15 @@ def render_form_v2(self): if not (required_input_fields or optional_input_fields): return - st.write( - """ ---- + with st.div(className="pt-3"): + st.write( + """ +###### **Columns** Please select which CSV column corresponds to your workflow's input fields. -For the outputs, please fill in what the column name should be that corresponds to each output too. +For the outputs, select the fields that should be included in the output CSV. To understand what each field represents, check out our [API docs](https://api.gooey.ai/docs). 
- """ - ) + """, + ) visible_col1, visible_col2 = st.columns(2) with st.expander("๐Ÿคฒ Show All Columns"): @@ -188,18 +195,18 @@ def render_form_v2(self): with hidden_col2: st.write("##### Outputs") - # only show the first output field by default, and hide others - try: - first_out_field = next( - field for field in output_fields if "output" in field - ) - except StopIteration: - first_out_field = next(iter(output_fields)) + visible_out_fields = {} + # only show the first output & run url field by default, and hide others + if output_fields: + try: + first_out_field = next( + field for field in output_fields if "output" in field + ) + except StopIteration: + first_out_field = next(iter(output_fields)) + visible_out_fields[first_out_field] = output_fields[first_out_field] + visible_out_fields["run_url"] = "Run URL" - visible_out_fields = { - first_out_field: output_fields[first_out_field], - "run_url": "Run URL", - } hidden_out_fields = { "price": "Price", "run_time": "Run Time", @@ -223,14 +230,37 @@ def render_form_v2(self): if col: output_columns_new[field] = title + st.write("---") + st.write(f"##### {field_title_desc(self.RequestModel, 'eval_urls')}") + list_view_editor( + add_btn_label="โž• Add an Eval", + key="eval_urls", + render_inputs=render_eval_url_inputs, + flatten_dict_key="url", + ) + def render_example(self, state: dict): render_documents(state) def render_output(self): - files = st.session_state.get("output_documents", []) - for file in files: - st.write(file) - st.data_table(file) + eval_runs = st.session_state.get("eval_runs") + + if eval_runs: + _backup = st.session_state + for url in eval_runs: + try: + page_cls, sr, _ = url_to_runs(url) + except SavedRun.DoesNotExist: + continue + st.set_session_state(sr.state) + page_cls().render_output() + st.write("---") + st.set_session_state(_backup) + else: + files = st.session_state.get("output_documents", []) + for file in files: + st.write(file) + st.data_table(file) def run_v2( self, @@ -242,7 +272,7 
@@ def run_v2( response.output_documents = [] for doc_ix, doc in enumerate(request.documents): - df = _read_df(doc) + df = read_df_any(doc) in_recs = df.to_dict(orient="records") out_recs = [] @@ -258,7 +288,7 @@ def run_v2( rec_ix = len(out_recs) out_recs.extend(in_recs[df_ix : df_ix + arr_len]) - for url_ix, f, request_body, page_cls in build_requests_for_df( + for url_ix, request_body, page_cls, sr, pr in build_requests_for_df( df, request, df_ix, arr_len ): progress = round( @@ -268,9 +298,6 @@ def run_v2( ) yield f"{progress}%" - example_id, run_id, uid = extract_query_params(f.query.params) - sr = page_cls.get_sr_from_query_params(example_id, run_id, uid) - result, sr = sr.submit_api_call( current_user=self.request.user, request_body=request_body ) @@ -288,7 +315,10 @@ def run_v2( for field, col in request.output_columns.items(): if len(request.run_urls) > 1: - col = f"({url_ix + 1}) {col}" + if pr and pr.title: + col = f"({pr.title}) {col}" + else: + col = f"({url_ix + 1}) {col}" out_val = state.get(field) if isinstance(out_val, list): for arr_ix, item in enumerate(out_val): @@ -323,43 +353,247 @@ def run_v2( ) response.output_documents[doc_ix] = f + if not request.eval_urls: + return + + response.eval_runs = [] + for url in request.eval_urls: + page_cls, sr, pr = url_to_runs(url) + yield f"Running {get_title_breadcrumbs(page_cls, sr, pr).h1_title}..." + request_body = page_cls.RequestModel( + documents=response.output_documents + ).dict(exclude_unset=True) + result, sr = sr.submit_api_call( + current_user=self.request.user, request_body=request_body + ) + result.get(disable_sync_subtasks=False) + sr.refresh_from_db() + response.eval_runs.append(sr.get_app_url()) + def preview_description(self, state: dict) -> str: return """ -Which AI model actually works best for your needs? -Upload your own data and evaluate any Gooey.AI workflow, LLM or AI model against any other. 
-Great for large data sets, AI model evaluation, task automation, parallel processing and automated testing. -To get started, paste in a Gooey.AI workflow, upload a CSV of your test data (with header names!), check the mapping of headers to workflow inputs and tap Submit. -More tips in the Details below. +Which AI model actually works best for your needs? +Upload your own data and evaluate any Gooey.AI workflow, LLM or AI model against any other. +Great for large data sets, AI model evaluation, task automation, parallel processing and automated testing. +To get started, paste in a Gooey.AI workflow, upload a CSV of your test data (with header names!), check the mapping of headers to workflow inputs and tap Submit. +More tips in the Details below. """ def render_description(self): st.write( """ -Building complex AI workflows like copilot) and then evaluating each iteration is complex. -Workflows are affected by the particular LLM used (GPT4 vs PalM2), their vector DB knowledge sets (e.g. your google docs), how synthetic data creation happened (e.g. how you transformed your video transcript or PDF into structured data), which translation or speech engine you used and your LLM prompts. Every change can affect the quality of your outputs. +Building complex AI workflows like copilot) and then evaluating each iteration is complex. +Workflows are affected by the particular LLM used (GPT4 vs PalM2), their vector DB knowledge sets (e.g. your google docs), how synthetic data creation happened (e.g. how you transformed your video transcript or PDF into structured data), which translation or speech engine you used and your LLM prompts. Every change can affect the quality of your outputs. 1. This bulk tool enables you to do two incredible things: -2. Upload your own set of inputs (e.g. typical questions to your bot) to any gooey workflow (e.g. /copilot) and run them in bulk to generate outputs or answers. -3. 
Compare the results of competing workflows to determine which one generates better outputs. +2. Upload your own set of inputs (e.g. typical questions to your bot) to any gooey workflow (e.g. /copilot) and run them in bulk to generate outputs or answers. +3. Compare the results of competing workflows to determine which one generates better outputs. To get started: 1. Enter the Gooey.AI Workflow URLs that you'd like to run in bulk 2. Enter a csv of sample inputs to run in bulk -3. Ensure that the mapping between your inputs and API parameters of the Gooey.AI workflow are correctly mapped. -4. Tap Submit. +3. Ensure that the mapping between your inputs and API parameters of the Gooey.AI workflow are correctly mapped. +4. Tap Submit. 5. Wait for results -6. Make a change to your Gooey Workflow, copy its URL and repeat Step 1 (or just add the link to see the results of both workflows together) +6. Make a change to your Gooey Workflow, copy its URL and repeat Step 1 (or just add the link to see the results of both workflows together) """ ) -def build_requests_for_df(df, request, df_ix, arr_len): +def render_run_url_inputs(key: str, del_key: str, d: dict): + from daras_ai_v2.all_pages import all_home_pages + + _prefill_workflow(d, key) + + col1, col2, col3 = st.columns([10, 1, 1], responsive=False) + if not d.get("workflow") and d.get("url"): + with col1: + url = st.text_input( + "", + key=key + ":url", + value=d.get("url"), + placeholder="https://gooey.ai/.../?run_id=...", + ) + else: + with col1: + scol1, scol2, scol3 = st.columns([5, 6, 1], responsive=False) + with scol1: + with st.div(className="pt-1"): + options = { + page_cls.workflow: page_cls.get_recipe_title() + for page_cls in all_home_pages + } + last_workflow_key = "__last_run_url_workflow" + workflow = st.selectbox( + "", + key=key + ":workflow", + default_value=( + d.get("workflow") or st.session_state.get(last_workflow_key) + ), + options=options, + format_func=lambda x: options[x], + ) + d["workflow"] = 
workflow + # use this to set default for next time + st.session_state[last_workflow_key] = workflow + with scol2: + page_cls = Workflow(workflow).page_cls + options = _get_approved_example_options(page_cls, workflow) + with st.div(className="pt-1"): + url = st.selectbox( + "", + key=key + ":url", + options=options, + default_value=d.get("url"), + format_func=lambda x: options[x], + ) + with scol3: + edit_button(key + ":editmode") + with col2: + url_button(url) + with col3: + del_button(del_key) + + try: + url_to_runs(url) + except Exception as e: + st.error(repr(e)) + d["url"] = url + + +@st.cache_in_session_state +def _get_approved_example_options( + page_cls: typing.Type[BasePage], workflow: Workflow +) -> dict[str, str]: + options = { + # root recipe + page_cls.get_root_published_run().get_app_url(): "Default", + } | { + # approved examples + pr.get_app_url(): get_title_breadcrumbs(page_cls, pr.saved_run, pr).h1_title + for pr in PublishedRun.objects.filter( + workflow=workflow, + is_approved_example=True, + visibility=PublishedRunVisibility.PUBLIC, + ).exclude(published_run_id="") + } + return options + + +def render_eval_url_inputs(key: str, del_key: str, d: dict): + _prefill_workflow(d, key) + + col1, col2, col3 = st.columns([10, 1, 1], responsive=False) + if not d.get("workflow") and d.get("url"): + with col1: + url = st.text_input( + "", + key=key + ":url", + value=d.get("url"), + placeholder="https://gooey.ai/.../?run_id=...", + ) + else: + d["workflow"] = Workflow.BULK_EVAL + with col1: + scol1, scol2 = st.columns([11, 1], responsive=False) + with scol1: + from recipes.BulkEval import BulkEvalPage + + options = { + BulkEvalPage.get_root_published_run().get_app_url(): "Default", + } | { + pr.get_app_url(): pr.title + for pr in PublishedRun.objects.filter( + workflow=Workflow.BULK_EVAL, + is_approved_example=True, + visibility=PublishedRunVisibility.PUBLIC, + ).exclude(published_run_id="") + } + with st.div(className="pt-1"): + url = st.selectbox( + "", + 
key=key + ":url", + options=options, + default_value=d.get("url"), + format_func=lambda x: options[x], + ) + with scol2: + edit_button(key + ":editmode") + with col2: + url_button(url) + with col3: + del_button(del_key) + + try: + url_to_runs(url) + except Exception as e: + st.error(repr(e)) + d["url"] = url + + +def url_button(url): + st.html( + f""" + + + + """ + ) + + +def edit_button(key: str): + st.button( + '', + key=key, + type="tertiary", + ) + + +def del_button(key: str): + st.button( + '', + key=key, + type="tertiary", + ) + + +def _prefill_workflow(d: dict, key: str): + if st.session_state.get(key + ":editmode"): + d.pop("workflow", None) + elif not d.get("workflow") and d.get("url"): + try: + _, sr, pr = url_to_runs(str(d["url"])) + except Exception: + return + else: + if ( + pr + and pr.saved_run == sr + and pr.visibility == PublishedRunVisibility.PUBLIC + and (pr.is_approved_example or pr.is_root()) + ): + d["workflow"] = pr.workflow + d["url"] = pr.get_app_url() + + +def url_to_runs( + url: str, +) -> tuple[typing.Type[BasePage], SavedRun, PublishedRun | None]: from daras_ai_v2.all_pages import page_slug_map, normalize_slug + f = furl(url) + slug = f.path.segments[0] + page_cls = page_slug_map[normalize_slug(slug)] + example_id, run_id, uid = extract_query_params(f.query.params) + sr, pr = page_cls.get_runs_from_query_params(example_id, run_id, uid) + return page_cls, sr, pr + + +def build_requests_for_df(df, request, df_ix, arr_len): for url_ix, url in enumerate(request.run_urls): - f = furl(url) - slug = f.path.segments[0] - page_cls = page_slug_map[normalize_slug(slug)] + page_cls, sr, pr = url_to_runs(url) schema = page_cls.RequestModel.schema() properties = schema["properties"] @@ -385,9 +619,11 @@ def build_requests_for_df(df, request, df_ix, arr_len): else: request_body[field] = df.at[df_ix, col] # for validation - request_body = page_cls.RequestModel.parse_obj(request_body).dict() + request_body = 
page_cls.RequestModel.parse_obj(request_body).dict( + exclude_unset=True + ) - yield url_ix, f, request_body, page_cls + yield url_ix, request_body, page_cls, sr, pr def slice_request_df(df, request): @@ -437,7 +673,7 @@ def is_arr(field_props: dict) -> bool: @st.cache_in_session_state def get_columns(files: list[str]) -> list[str]: - dfs = map_parallel(_read_df, files) + dfs = map_parallel(read_df_any, files) return list( { col: None @@ -448,7 +684,7 @@ def get_columns(files: list[str]) -> list[str]: ) -def _read_df(f_url: str) -> "pd.DataFrame": +def read_df_any(f_url: str) -> "pd.DataFrame": import pandas as pd doc_meta = doc_url_to_metadata(f_url) @@ -470,3 +706,48 @@ def _read_df(f_url: str) -> "pd.DataFrame": raise ValueError(f"Unsupported file type: {f_url}") return df.dropna(how="all", axis=1).dropna(how="all", axis=0).fillna("") + + +def list_view_editor( + *, + add_btn_label: str, + key: str, + render_labels: typing.Callable = None, + render_inputs: typing.Callable[[str, str, dict], None], + flatten_dict_key: str = None, +): + if flatten_dict_key: + list_key = f"--list-view:{key}" + st.session_state.setdefault( + list_key, + [{flatten_dict_key: val} for val in st.session_state.get(key, [])], + ) + new_lst = list_view_editor( + add_btn_label=add_btn_label, + key=list_key, + render_labels=render_labels, + render_inputs=render_inputs, + ) + ret = [d[flatten_dict_key] for d in new_lst] + st.session_state[key] = ret + return ret + + old_lst = st.session_state.setdefault(key, []) + add_key = f"--{key}:add" + if st.session_state.get(add_key): + old_lst.append({}) + label_placeholder = st.div() + new_lst = [] + for d in old_lst: + entry_key = d.setdefault("__key__", f"--{key}:{uuid.uuid1()}") + del_key = entry_key + ":del" + if st.session_state.pop(del_key, None): + continue + render_inputs(entry_key, del_key, d) + new_lst.append(d) + if new_lst and render_labels: + with label_placeholder: + render_labels() + st.session_state[key] = new_lst + 
st.button(add_btn_label, key=add_key) + return new_lst diff --git a/recipes/ChyronPlant.py b/recipes/ChyronPlant.py index b4d3af5c0..116a8ad6a 100644 --- a/recipes/ChyronPlant.py +++ b/recipes/ChyronPlant.py @@ -10,6 +10,7 @@ class ChyronPlantPage(BasePage): title = "Chyron Plant Bot" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.CHYRON_PLANT slug_versions = ["ChyronPlant"] diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index b7361b20f..bb246d6c3 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -1,10 +1,9 @@ import random import typing - -import gooey_ui as st from pydantic import BaseModel +import gooey_ui as st from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_multiselect @@ -12,16 +11,18 @@ run_language_model, LargeLanguageModels, llm_price, + SUPERSCRIPT, ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.prompt_vars import prompt_vars_widget, render_prompt_vars -DEFAULT_COMPARE_LM_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/compare%20llm%20under%201%20mg%20gif.gif" +DEFAULT_COMPARE_LM_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5e4f4c58-93fc-11ee-a39e-02420a0001ce/LLMs.jpg.png" class CompareLLMPage(BasePage): title = "Large Language Models: GPT-3" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ae42015e-88d7-11ee-aac9-02420a00016b/Compare%20LLMs.png.png" workflow = Workflow.COMPARE_LLM slug_versions = ["CompareLLM", "llm", "compare-large-language-models"] @@ -93,9 +94,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["output_text"] = output_text = {} for selected_model in 
request.selected_models: - yield f"Running {LargeLanguageModels[selected_model].value}..." - - output_text[selected_model] = run_language_model( + model = LargeLanguageModels[selected_model] + yield f"Running {model.value}..." + ret = run_language_model( model=selected_model, quality=request.quality, num_outputs=request.num_outputs, @@ -103,7 +104,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]: prompt=prompt, max_tokens=request.max_tokens, avoid_repetition=request.avoid_repetition, + stream=True, ) + for i, entries in enumerate(ret): + output_text[selected_model] = [e["content"] for e in entries] + yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." def render_output(self): self._render_outputs(st.session_state, 450) diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 05fd37057..e79ef5d54 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -17,6 +17,7 @@ scheduler_setting, ) from daras_ai_v2.loom_video_widget import youtube_video +from daras_ai_v2.safety_checker import safety_checker from daras_ai_v2.stable_diffusion import ( Text2ImgModels, text2img, @@ -25,9 +26,12 @@ Schedulers, ) +DEFAULT_COMPARE_TEXT2IMG_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ae7b2940-93fc-11ee-8edc-02420a0001cc/Compare%20image%20generators.jpg.png" + class CompareText2ImgPage(BasePage): title = "Compare AI Image Generators" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d127484e-88d9-11ee-b549-02420a000167/Compare%20AI%20Image%20generators.png.png" workflow = Workflow.COMPARE_TEXT2IMG slug_versions = [ "CompareText2Img", @@ -40,6 +44,8 @@ class CompareText2ImgPage(BasePage): "seed": 42, "sd_2_upscaling": False, "image_guidance_scale": 1.2, + "dall_e_3_quality": "standard", + "dall_e_3_style": "vivid", } class RequestModel(BaseModel): @@ -51,6 +57,8 @@ class RequestModel(BaseModel): num_outputs: int | None quality: int | None + 
dall_e_3_quality: str | None + dall_e_3_style: str | None guidance_scale: float | None seed: int | None @@ -69,6 +77,9 @@ class ResponseModel(BaseModel): typing.Literal[tuple(e.name for e in Text2ImgModels)], list[str] ] + def preview_image(self, state: dict) -> str | None: + return DEFAULT_COMPARE_TEXT2IMG_META_IMG + def related_workflows(self) -> list: from recipes.FaceInpainting import FaceInpaintingPage from recipes.ObjectInpainting import ObjectInpaintingPage @@ -152,7 +163,7 @@ def render_settings(self): negative_prompt_setting() output_resolution_setting() - num_outputs_setting() + num_outputs_setting(st.session_state.get("selected_models", [])) sd_2_upscaling_setting() col1, col2 = st.columns(2) with col1: @@ -168,6 +179,10 @@ def render_output(self): def run(self, state: dict) -> typing.Iterator[str | None]: request: CompareText2ImgPage.RequestModel = self.RequestModel.parse_obj(state) + if not self.request.user.disable_safety_checker: + yield "Running safety checker..." + safety_checker(text=request.text_prompt) + state["output_images"] = output_images = {} for selected_model in request.selected_models: @@ -178,6 +193,8 @@ def run(self, state: dict) -> typing.Iterator[str | None]: prompt=request.text_prompt, num_outputs=request.num_outputs, num_inference_steps=request.quality, + dall_e_3_quality=request.dall_e_3_quality, + dall_e_3_style=request.dall_e_3_style, width=request.output_width, height=request.output_height, guidance_scale=request.guidance_scale, diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index e68ac65e1..f6f4e988b 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -8,12 +8,14 @@ from daras_ai_v2.enum_selector_widget import enum_multiselect from daras_ai_v2.face_restoration import UpscalerModels, run_upscaler_model from daras_ai_v2.stable_diffusion import SD_IMG_MAX_SIZE +from daras_ai_v2.safety_checker import safety_checker -DEFAULT_COMPARE_UPSCALER_META_IMG = 
"https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/COMPARE%20IMAGE%20UPSCALERS.jpg" +DEFAULT_COMPARE_UPSCALER_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/2e8ee512-93fe-11ee-a083-02420a0001c8/Image%20upscaler.jpg.png" class CompareUpscalerPage(BasePage): title = "Compare AI Image Upscalers" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/64393e0c-88db-11ee-b428-02420a000168/AI%20Image%20Upscaler.png.png" workflow = Workflow.COMPARE_UPSCALER slug_versions = ["compare-ai-upscalers"] @@ -75,6 +77,10 @@ def render_description(self): def run(self, state: dict) -> typing.Iterator[str | None]: request: CompareUpscalerPage.RequestModel = self.RequestModel.parse_obj(state) + if not self.request.user.disable_safety_checker: + yield "Running safety checker..." + safety_checker(image=request.input_image) + state["output_images"] = output_images = {} for selected_model in request.selected_models: diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index b0d81a25c..57b5bb9f7 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -1,5 +1,6 @@ import typing import uuid +from datetime import datetime, timedelta from django.db.models import TextChoices from pydantic import BaseModel @@ -12,6 +13,9 @@ from daras_ai_v2.gpu_server import call_celery_task_outfile from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.safety_checker import safety_checker +from daras_ai_v2.tabs_widget import MenuTabs + +DEFAULT_DEFORUMSD_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7dc25196-93fe-11ee-9e3a-02420a0001ce/AI%20Animation%20generator.jpg.png" class AnimationModels(TextChoices): @@ -27,6 +31,7 @@ class _AnimationPrompt(TypedDict): AnimationPrompts = list[_AnimationPrompt] CREDITS_PER_FRAME = 1.5 +MODEL_ESTIMATED_TIME_PER_FRAME = 2.4 # seconds def input_prompt_to_animation_prompts(input_prompt: str): @@ -77,7 +82,7 @@ def 
animation_prompts_editor( st.write("#### ๐Ÿ‘ฉโ€๐Ÿ’ป Animation Prompts") st.caption( """ - Describe the scenes or series of images that you want to generate into an animation. You can add as many prompts as you like. Mention the keyframe number for each prompt i.e. the transition point from the first prompt to the next. + Describe the scenes or series of images that you want to generate into an animation. You can add as many prompts as you like. Mention the keyframe number for each prompt i.e. the transition point from the first prompt to the next. View the โ€˜Detailsโ€™ drop down menu to get started. """ ) @@ -144,7 +149,7 @@ def animation_prompts_editor( ) st.caption( """ - Pro-tip: To avoid abrupt endings on your animation, ensure that the last keyframe prompt is set for a higher number of keyframes/time than the previous transition rate. There should be an ample number of frames between the last frame and the total frame count of the animation. + Pro-tip: To avoid abrupt endings on your animation, ensure that the last keyframe prompt is set for a higher number of keyframes/time than the previous transition rate. There should be an ample number of frames between the last frame and the total frame count of the animation. 
""" ) @@ -158,6 +163,7 @@ def get_last_frame(prompt_list: list) -> int: class DeforumSDPage(BasePage): title = "AI Animation Generator" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/media/users/kxmNIYAOJbfOURxHBKNCWeUSKiP2/dd88c110-88d6-11ee-9b4f-2b58bd50e819/animation.gif" workflow = Workflow.DEFORUM_SD slug_versions = ["DeforumSD", "animation-generator"] @@ -195,6 +201,9 @@ class RequestModel(BaseModel): class ResponseModel(BaseModel): output_video: str + def preview_image(self, state: dict) -> str | None: + return DEFAULT_DEFORUMSD_META_IMG + def related_workflows(self) -> list: from recipes.VideoBots import VideoBotsPage from recipes.LipsyncTTS import LipsyncTTSPage @@ -225,7 +234,7 @@ def render_form_v2(self): ) st.caption( """ -Pro-tip: The more frames you add, the longer it will take to render the animation. Test your prompts before adding more frames. +Pro-tip: The more frames you add, the longer it will take to render the animation. Test your prompts before adding more frames. """ ) @@ -270,14 +279,14 @@ def render_settings(self): st.text_input( """ ###### Zoom -How should the camera zoom in or out? This setting scales the canvas size, multiplicatively. -1 is static, with numbers greater than 1 moving forward (or zooming in) and numbers less than 1 moving backwards (or zooming out). +How should the camera zoom in or out? This setting scales the canvas size, multiplicatively. +1 is static, with numbers greater than 1 moving forward (or zooming in) and numbers less than 1 moving backwards (or zooming out). """, key="zoom", ) st.caption( """ - With 0 as the starting keyframe, the input of 0: (1.004) can be used to zoom in moderately, starting at frame 0 and continuing until the end. + With 0 as the starting keyframe, the input of 0: (1.004) can be used to zoom in moderately, starting at frame 0 and continuing until the end. 
""" ) st.text_input( @@ -318,7 +327,7 @@ def render_settings(self): ) st.slider( """ -###### FPS (Frames per second) +###### FPS (Frames per second) Choose fps for the video. """, min_value=10, @@ -358,13 +367,13 @@ def preview_description(self, state: dict) -> str: def render_description(self): st.markdown( f""" - - Every Submit will require approximately 3-5 minutes to render. + - Every Submit will require approximately 3-5 minutes to render. - - Animation is complex: Please watch the video and review our decks to help you. + - Animation is complex: Please watch the video and review our decks to help you. - - Test your image prompts BEFORE adding lots of frames e.g. Tweak key frame images with just 10 frames between them AND then increase the FPS or frame count between them once you like the outputs. This will save you time and credits. + - Test your image prompts BEFORE adding lots of frames e.g. Tweak key frame images with just 10 frames between them AND then increase the FPS or frame count between them once you like the outputs. This will save you time and credits. - - No lost work! All your animations or previously generated versions are in the History tab. If they don't appear here, it likely means they aren't done rendering yet. + - No lost work! All your animations or previously generated versions are in the History tab. If they don't appear here, it likely means they aren't done rendering yet. """ ) @@ -379,23 +388,23 @@ def render_description(self): Hereโ€™s a comprehensive style guide to assist you with different stylized animation prompts: [StableDiffusion CheatSheet](https://supagruen.github.io/StableDiffusion-CheatSheet/) - + """ ) st.write("---") st.markdown( """ - Animation Length: You can indicate how long you want your animation to be by increasing or decreasing your frame count. + Animation Length: You can indicate how long you want your animation to be by increasing or decreasing your frame count. 
- FPS: Every Animation is set at 12 frames per second by default. You can change this default frame rate/ frames per second (FPS) on the Settings menu. + FPS: Every Animation is set at 12 frames per second by default. You can change this default frame rate/ frames per second (FPS) on the Settings menu. - Prompts: Within your sequence you can input multiple text Prompts for your visuals. Each prompt can be defined for a specific keyframe number. + Prompts: Within your sequence you can input multiple text Prompts for your visuals. Each prompt can be defined for a specific keyframe number. - ##### What are keyframes? + ##### What are keyframes? Keyframes define the transition points from one prompt to the next, or the start and end points of a prompted action set in between the total frame count or sequence. These keyframes or markers are necessary to establish smooth transitions or jump cuts, whatever you prefer. - Use the Camera Settings to generate animations with depth and other 3D parameters. + Use the Camera Settings to generate animations with depth and other 3D parameters. 
""" ) st.markdown( @@ -417,13 +426,18 @@ def render_output(self): st.write("Output Video") st.video(output_video, autoplay=True) + def estimate_run_duration(self): + # in seconds + return st.session_state.get("max_frames", 100) * MODEL_ESTIMATED_TIME_PER_FRAME + def render_example(self, state: dict): display = self.preview_input(state) st.markdown("```lua\n" + display + "\n```") st.video(state.get("output_video"), autoplay=True) - def preview_input(self, state: dict) -> str: + @classmethod + def preview_input(cls, state: dict) -> str: input_prompt = state.get("input_prompt") if input_prompt: animation_prompts = input_prompt_to_animation_prompts(input_prompt) diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index cca9848f7..6c9489fff 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -26,6 +26,7 @@ from daras_ai_v2.base import BasePage from daras_ai_v2.doc_search_settings_widgets import document_uploader from daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import ( apply_parallel, @@ -40,7 +41,7 @@ from daras_ai_v2.vector_search import doc_url_to_metadata from recipes.DocSearch import render_documents -DEFAULT_YOUTUBE_BOT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6c8f6876-538c-11ee-bea7-02420a000195/youtube%20bot%201.png.png" +DEFAULT_YOUTUBE_BOT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ddc8ffac-93fb-11ee-89fb-02420a0001cb/Youtube%20transcripts.jpg.png" class Columns(IntegerChoices): @@ -56,6 +57,7 @@ class Columns(IntegerChoices): class DocExtractPage(BasePage): title = "Youtube Transcripts + GPT extraction to Google Sheets" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = 
Workflow.DOC_EXTRACT slug_versions = [ "doc-extract", @@ -289,7 +291,7 @@ def extract_info(url: str) -> list[dict | None]: headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, ) - r.raise_for_status() + raise_for_status(r) f_bytes = r.content inputpdf = PdfReader(io.BytesIO(f_bytes)) return [ diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 22a6dddb9..15cb6b063 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -23,8 +23,9 @@ from daras_ai_v2.search_ref import ( SearchReference, render_output_with_refs, - apply_response_template, CitationStyles, + apply_response_formattings_prefix, + apply_response_formattings_suffix, ) from daras_ai_v2.vector_search import ( DocSearchRequest, @@ -33,11 +34,12 @@ render_sources_widget, ) -DEFAULT_DOC_SEARCH_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/DOC%20SEARCH.gif" +DEFAULT_DOC_SEARCH_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/bcc7aa58-93fe-11ee-a083-02420a0001c8/Search%20your%20docs.jpg.png" class DocSearchPage(BasePage): title = "Search your Docs with GPT" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/cbbb4dc6-88d7-11ee-bf6c-02420a000166/Search%20your%20docs%20with%20gpt.png.png" workflow = Workflow.DOC_SEARCH slug_versions = ["doc-search"] @@ -193,9 +195,12 @@ def run_v2( citation_style = ( request.citation_style and CitationStyles[request.citation_style] ) or None - apply_response_template( + all_refs_list = apply_response_formattings_prefix( response.output_text, response.references, citation_style ) + apply_response_formattings_suffix( + all_refs_list, response.output_text, citation_style + ) def get_raw_price(self, state: dict) -> float: name = state.get("selected_model") diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 1e4873573..4b9283cde 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ 
-27,7 +27,7 @@ ) from recipes.GoogleGPT import render_output_with_refs, GoogleGPTPage -DEFAULT_DOC_SUMMARY_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/db70c56e-585a-11ee-990b-02420a00018f/doc%20summary.png.png" +DEFAULT_DOC_SUMMARY_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f35796d2-93fe-11ee-b86c-02420a0001c7/Summarize%20with%20GPT.jpg.png" class CombineDocumentsChains(Enum): @@ -38,6 +38,7 @@ class CombineDocumentsChains(Enum): class DocSummaryPage(BasePage): title = "Summarize your Docs with GPT" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1f858a7a-88d8-11ee-a658-02420a000163/Summarize%20your%20docs%20with%20gpt.png.png" workflow = Workflow.DOC_SUMMARY slug_versions = ["doc-summary"] diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index c4c7d6ba0..74749bdb5 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -9,6 +9,7 @@ from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import db, settings +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.send_email import send_email_via_postmark from daras_ai_v2.stable_diffusion import InpaintingModels @@ -17,9 +18,12 @@ email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" twitter_handle_regex = r"(@)?[A-Za-z0-9_]{1,15}" +DEFAULT_EMAIL_FACE_INPAINTING_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6937427a-9522-11ee-b6d3-02420a0001ea/Email%20photo.jpg.png" + class EmailFaceInpaintingPage(FaceInpaintingPage): title = "AI Generated Photo from Email Profile Lookup" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ec0df5aa-9521-11ee-93d3-02420a0001e5/Email%20Profile%20Lookup.png.png" workflow = Workflow.EMAIL_FACE_INPAINTING slug_versions 
= ["EmailFaceInpainting", "ai-image-from-email-lookup"] @@ -84,6 +88,9 @@ class ResponseModel(BaseModel): output_images: list[str] email_sent: bool = False + def preview_image(self, state: dict) -> str | None: + return DEFAULT_EMAIL_FACE_INPAINTING_META_IMG + def preview_description(self, state: dict) -> str: return "Find an email's public photo and then draw the face into an AI generated scene using your own prompt + the latest Stable Diffusion or DallE image generator." @@ -359,7 +366,7 @@ def get_photo_for_email(email_address): f"https://api.seon.io/SeonRestService/email-api/v2.2/{email_address}", headers={"X-API-KEY": settings.SEON_API_KEY}, ) - r.raise_for_status() + raise_for_status(r) account_details = glom.glom(r.json(), "data.account_details", default={}) for spec in [ @@ -392,7 +399,7 @@ def get_photo_for_twitter_handle(twitter_handle): f"https://api.twitter.com/2/users/by?usernames={twitter_handle}&user.fields=profile_image_url", headers={"Authorization": f"Bearer {settings.TWITTER_BEARER_TOKEN}"}, ) - r.raise_for_status() + raise_for_status(r) error = glom.glom(r.json(), "errors.0.title", default=None) if error: if error == "Not Found Error": diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index daddcb54a..ee377b53d 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -21,11 +21,15 @@ ) from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.repositioning import repositioning_preview_img +from daras_ai_v2.safety_checker import safety_checker from daras_ai_v2.stable_diffusion import InpaintingModels +DEFAULT_FACE_INPAINTING_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a146bfc0-93ff-11ee-b86c-02420a0001c7/Face%20in%20painting.jpg.png" + class FaceInpaintingPage(BasePage): title = "AI Image with a Face" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/10c2ce06-88da-11ee-b428-02420a000168/ai%20image%20with%20a%20face.png.png" 
workflow = Workflow.FACE_INPAINTING slug_versions = ["FaceInpainting", "face-in-ai-generated-photo"] @@ -76,6 +80,9 @@ class ResponseModel(BaseModel): diffusion_images: list[str] output_images: list[str] + def preview_image(self, state: dict) -> str | None: + return DEFAULT_FACE_INPAINTING_META_IMG + def preview_description(self, state: dict) -> str: return "Upload & extract a face into an AI-generated photo using your text + the latest Stable Diffusion or DallE image generator." @@ -244,6 +251,10 @@ def render_usage_guide(self): # loom_video("788dfdee763a4e329e28e749239f9810") def run(self, state: dict): + if not self.request.user.disable_safety_checker: + yield "Running safety checker..." + safety_checker(image=state["input_image"]) + yield "Extracting Face..." input_image_url = state["input_image"] diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 636798b71..57b1078ba 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -37,11 +37,12 @@ EmptySearchResults, ) -DEFAULT_GOOGLE_GPT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/WEBSEARCH%20%2B%20CHATGPT.jpg" +DEFAULT_GOOGLE_GPT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85ed60a2-9405-11ee-9747-02420a0001ce/Web%20search%20GPT.jpg.png" class GoogleGPTPage(BasePage): title = "Web Search + GPT3" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/28649544-9406-11ee-bba3-02420a0001cc/Websearch%20GPT%20option%202.png.png" workflow = Workflow.GOOGLE_GPT slug_versions = ["google-gpt"] diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index 795e716a6..c3dab52f0 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -29,9 +29,12 @@ instruct_pix2pix, ) +DEFAULT_GOOGLE_IMG_GEN_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/dcd82b68-9400-11ee-9e3a-02420a0001ce/Search%20result%20photo.jpg.png" + class 
GoogleImageGenPage(BasePage): title = "Render Image Search Results with AI" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/eb23c078-88da-11ee-aa86-02420a000165/web%20search%20render.png.png" workflow = Workflow.GOOGLE_IMAGE_GEN slug_versions = ["GoogleImageGen", "render-images-with-ai"] @@ -73,6 +76,9 @@ class ResponseModel(BaseModel): image_urls: list[str] selected_image: str | None + def preview_image(self, state: dict) -> str | None: + return DEFAULT_GOOGLE_IMG_GEN_META_IMG + def related_workflows(self): from recipes.ObjectInpainting import ObjectInpaintingPage from recipes.QRCodeGenerator import QRCodeGeneratorPage diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 80a8bd677..543664065 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -28,9 +28,12 @@ repositioning_preview_widget, ) +DEFAULT_IMG_SEGMENTATION_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/8363ed50-9401-11ee-878f-02420a0001cb/AI%20bg%20changer.jpg.png" + class ImageSegmentationPage(BasePage): title = "AI Background Changer" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/06fc595e-88db-11ee-b428-02420a000168/AI%20Background%20Remover.png.png" workflow = Workflow.IMAGE_SEGMENTATION slug_versions = ["ImageSegmentation", "remove-image-background-with-ai"] @@ -64,6 +67,9 @@ class ResponseModel(BaseModel): resized_image: str resized_mask: str + def preview_image(self, state: dict) -> str | None: + return DEFAULT_IMG_SEGMENTATION_META_IMG + def related_workflows(self) -> list: from recipes.ObjectInpainting import ObjectInpaintingPage from recipes.Img2Img import Img2ImgPage diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 03d4dd668..a97bc7c62 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -19,9 +19,12 @@ ) from daras_ai_v2.safety_checker import safety_checker +DEFAULT_IMG2IMG_META_IMG = 
"https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/cc2804ea-9401-11ee-940a-02420a0001c7/Edit%20an%20image.jpg.png" + class Img2ImgPage(BasePage): title = "Edit An Image with AI prompt" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/bcc9351a-88d9-11ee-bf6c-02420a000166/Edit%20an%20image%20with%20AI%201.png.png" workflow = Workflow.IMG_2_IMG slug_versions = ["Img2Img", "ai-photo-editor"] @@ -67,6 +70,9 @@ class RequestModel(BaseModel): class ResponseModel(BaseModel): output_images: list[str] + def preview_image(self, state: dict) -> str | None: + return DEFAULT_IMG2IMG_META_IMG + def related_workflows(self) -> list: from recipes.QRCodeGenerator import QRCodeGeneratorPage from recipes.ObjectInpainting import ObjectInpaintingPage @@ -165,7 +171,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: selected_controlnet_model=request.selected_controlnet_model, prompt=request.text_prompt, num_outputs=request.num_outputs, - init_image=init_image, + init_images=init_image, num_inference_steps=request.quality, negative_prompt=request.negative_prompt, guidance_scale=request.guidance_scale, diff --git a/recipes/LetterWriter.py b/recipes/LetterWriter.py index 23265c4c2..992cb2920 100644 --- a/recipes/LetterWriter.py +++ b/recipes/LetterWriter.py @@ -8,12 +8,14 @@ from bots.models import Workflow from daras_ai.text_format import daras_ai_format_str from daras_ai_v2.base import BasePage +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.language_model import run_language_model from daras_ai_v2.text_training_data_widget import text_training_data, TrainingDataModel class LetterWriterPage(BasePage): title = "Letter Writer" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.LETTER_WRITER slug_versions = ["LetterWriter"] @@ -244,7 +246,7 @@ def run(self, 
state: dict) -> typing.Iterator[str | None]: body = None r = requests.request(method=method, url=url, headers=headers, json=body) - r.raise_for_status() + raise_for_status(r) response_json = r.json() state["response_json"] = response_json diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 68ac496a0..5fe08a09b 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -15,11 +15,12 @@ CREDITS_PER_MB = 2 -DEFAULT_LIPSYNC_GIF = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/91acbbde-5857-11ee-920a-02420a000194/lipsync%20audio.png.png" +DEFAULT_LIPSYNC_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7fc4d302-9402-11ee-98dc-02420a0001ca/Lip%20Sync.jpg.png" class LipsyncPage(BasePage): title = "Lip Syncing" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f33e6332-88d8-11ee-89f9-02420a000169/Lipsync%20TTS.png.png" workflow = Workflow.LIPSYNC slug_versions = ["Lipsync"] @@ -36,7 +37,7 @@ class ResponseModel(BaseModel): output_video: str def preview_image(self, state: dict) -> str | None: - return DEFAULT_LIPSYNC_GIF + return DEFAULT_LIPSYNC_META_IMG def render_form_v2(self) -> bool: st.file_uploader( diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 5a1eb10d4..60886c017 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -9,11 +9,12 @@ from daras_ai_v2.safety_checker import safety_checker from daras_ai_v2.loom_video_widget import youtube_video -DEFAULT_LIPSYNC_TTS_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/lipsync_meta_img.gif" +DEFAULT_LIPSYNC_TTS_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/13b4d352-9456-11ee-8edd-02420a0001c7/Lipsync%20TTS.jpg.png" class LipsyncTTSPage(LipsyncPage, TextToSpeechPage): title = "Lipsync Video with Any Text" + explore_image = 
"https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1acfa370-88d9-11ee-bf6c-02420a000166/Lipsync%20with%20audio%201.png.png" workflow = Workflow.LIPSYNC_TTS slug_versions = ["LipsyncTTS", "lipsync-maker"] @@ -52,6 +53,8 @@ class RequestModel(BaseModel): elevenlabs_model: str | None elevenlabs_stability: float | None elevenlabs_similarity_boost: float | None + elevenlabs_style: float | None + elevenlabs_speaker_boost: bool | None class ResponseModel(BaseModel): output_video: str @@ -59,13 +62,12 @@ class ResponseModel(BaseModel): def related_workflows(self) -> list: from recipes.VideoBots import VideoBotsPage from recipes.DeforumSD import DeforumSDPage - from recipes.CompareText2Img import CompareText2ImgPage return [ VideoBotsPage, TextToSpeechPage, DeforumSDPage, - CompareText2ImgPage, + LipsyncPage, ] def render_form_v2(self): diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index ff4c81fde..d84c59069 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -24,9 +24,12 @@ ) from daras_ai_v2.stable_diffusion import InpaintingModels +DEFAULT_OBJECT_INPAINTING_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/4bca6982-9456-11ee-bc12-02420a0001cc/Product%20photo%20backgrounds.jpg.png" + class ObjectInpaintingPage(BasePage): title = "Generate Product Photo Backgrounds" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f07b731e-88d9-11ee-a658-02420a000163/W.I.3.png.png" workflow = Workflow.OBJECT_INPAINTING slug_versions = ["ObjectInpainting", "product-photo-background-generator"] @@ -73,6 +76,9 @@ class ResponseModel(BaseModel): # diffusion_images: list[str] output_images: list[str] + def preview_image(self, state: dict) -> str | None: + return DEFAULT_OBJECT_INPAINTING_META_IMG + def related_workflows(self) -> list: from recipes.ImageSegmentation import ImageSegmentationPage from recipes.GoogleImageGen import GoogleImageGenPage diff --git 
a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index ed504bbbc..6b7260368 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -20,6 +20,7 @@ ) from daras_ai_v2.base import BasePage from daras_ai_v2.descriptions import prompting101 +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.img_model_settings_widgets import ( output_resolution_setting, img_model_settings, @@ -37,20 +38,22 @@ from recipes.EmailFaceInpainting import get_photo_for_email from recipes.SocialLookupEmail import get_profile_for_email from url_shortener.models import ShortenedURL +from daras_ai_v2.enum_selector_widget import enum_multiselect ATTEMPTS = 1 -DEFAULT_QR_CODE_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f09c8cfa-5393-11ee-a837-02420a000190/ai%20art%20qr%20codes1%201.png.png" +DEFAULT_QR_CODE_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a679a410-9456-11ee-bd77-02420a0001ce/QR%20Code.jpg.png" class QrSources(Enum): qr_code_data = "๐Ÿ”— URL or Text" - qr_code_vcard = "๐Ÿ‘ฉโ€๐Ÿฆฐ Contact Info" + qr_code_vcard = "๐Ÿ“‡ Contact Card" qr_code_file = "๐Ÿ“„ Upload File" - qr_code_input_image = "๐Ÿ“ท Existing QR Code" + qr_code_input_image = "๐Ÿ Existing QR Code" class QRCodeGeneratorPage(BasePage): title = "AI Art QR Code" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/03d6538e-88d5-11ee-ad97-02420a00016c/W.I.2.png.png" workflow = Workflow.QR_CODE slug_versions = ["art-qr-code", "qr", "qr-code"] @@ -59,6 +62,15 @@ class QRCodeGeneratorPage(BasePage): obj_scale=0.65, obj_pos_x=0.5, obj_pos_y=0.5, + image_prompt_controlnet_models=[ + ControlNetModels.sd_controlnet_canny.name, + ControlNetModels.sd_controlnet_depth.name, + ControlNetModels.sd_controlnet_tile.name, + ], + image_prompt_strength=0.3, + image_prompt_scale=1.0, + image_prompt_pos_x=0.5, + image_prompt_pos_y=0.5, ) def __init__(self, *args, **kwargs): @@ -75,6 +87,14 @@ 
class RequestModel(BaseModel): text_prompt: str negative_prompt: str | None + image_prompt: str | None + image_prompt_controlnet_models: list[ + typing.Literal[tuple(e.name for e in ControlNetModels)], ... + ] | None + image_prompt_strength: float | None + image_prompt_scale: float | None + image_prompt_pos_x: float | None + image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None selected_controlnet_model: list[ @@ -104,6 +124,8 @@ class ResponseModel(BaseModel): cleaned_qr_code: str def preview_image(self, state: dict) -> str | None: + if len(state.get("output_images") or []) > 0: + return state["output_images"][0] return DEFAULT_QR_CODE_META_IMG def related_workflows(self) -> list: @@ -122,7 +144,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.text_area( """ - ### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt + ##### ๐Ÿ‘ฉโ€๐Ÿ’ป Prompt Describe the subject/scene of the QR Code. Choose clear prompts and distinguishable visuals to ensure optimal readability. """, @@ -136,7 +158,7 @@ def render_form_v2(self): if st.session_state.get(key): st.session_state[qr_code_source_key] = key break - source = st.radio( + source = st.horizontal_radio( "", options=QrSources._member_names_, key=qr_code_source_key, @@ -184,6 +206,15 @@ def render_form_v2(self): 'A shortened URL enables the QR code to be more beautiful and less "QR-codey" with fewer blocky pixels.' ) + st.file_uploader( + """ + ##### ๐Ÿž๏ธ Reference Image *[optional]* + This image will be used as inspiration to blend with the QR Code. + """, + key="image_prompt", + accept=["image/*"], + ) + def validate_form_v2(self): assert st.session_state.get("text_prompt"), "Please provide a prompt" assert any( @@ -272,7 +303,7 @@ def render_settings(self): st.write( """ - ##### โŒ– Positioning + ##### โŒ– QR Positioning Use this to control where the QR code is placed in the image, and how big it should be. 
""", className="gui-input", @@ -320,25 +351,91 @@ def render_settings(self): color=255, ) + if st.session_state.get("image_prompt"): + st.write("---") + st.write( + """ + ##### ๐ŸŽจ Inspiration + Use this to control how the image prompt should influence the output. + """, + className="gui-input", + ) + st.slider( + "Inspiration Strength", + min_value=0.0, + max_value=1.0, + step=0.05, + key="image_prompt_strength", + ) + enum_multiselect( + ControlNetModels, + label="Control Net Models", + key="image_prompt_controlnet_models", + checkboxes=False, + allow_none=False, + ) + st.write( + """ + ##### โŒ– Reference Image Positioning + Use this to control where the reference image is placed, and how big it should be. + """, + className="gui-input", + ) + col1, _ = st.columns(2) + with col1: + image_prompt_scale = st.slider( + "Scale", + min_value=0.1, + max_value=1.0, + step=0.05, + key="image_prompt_scale", + ) + col1, col2 = st.columns(2, responsive=False) + with col1: + image_prompt_pos_x = st.slider( + "Position X", + min_value=0.0, + max_value=1.0, + step=0.05, + key="image_prompt_pos_x", + ) + with col2: + image_prompt_pos_y = st.slider( + "Position Y", + min_value=0.0, + max_value=1.0, + step=0.05, + key="image_prompt_pos_y", + ) + + img_cv2 = mask_cv2 = bytes_to_cv2_img( + requests.get(st.session_state["image_prompt"]).content, + ) + repositioning_preview_widget( + img_cv2=img_cv2, + mask_cv2=mask_cv2, + obj_scale=image_prompt_scale, + pos_x=image_prompt_pos_x, + pos_y=image_prompt_pos_y, + out_size=( + st.session_state["output_width"], + st.session_state["output_height"], + ), + color=255, + ) + def render_output(self): state = st.session_state self._render_outputs(state) def render_example(self, state: dict): - col1, col2 = st.columns(2) - with col1: - st.markdown( - f""" - ```text - {state.get("text_prompt", "")} - ``` - """ - ) - with col2: - self._render_outputs(state) + self._render_outputs(state, max_count=1) - def _render_outputs(self, state: dict): - for 
img in state.get("output_images", []): + def _render_outputs(self, state: dict, max_count: int | None = None): + output_images = list(state.get("output_images", [])) + if max_count: + output_images = output_images[:max_count] + for img in output_images: st.image(img) qr_code_data = ( state.get(QrSources.qr_code_data.name) @@ -376,11 +473,36 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["raw_images"] = raw_images = [] yield f"Running {Text2ImgModels[request.selected_model].value}..." + if isinstance(request.selected_controlnet_model, str): + request.selected_controlnet_model = [request.selected_controlnet_model] + init_images = [image] * len(request.selected_controlnet_model) + if request.image_prompt: + image_prompt = bytes_to_cv2_img(requests.get(request.image_prompt).content) + repositioned_image_prompt, _ = reposition_object( + orig_img=image_prompt, + orig_mask=image_prompt, + out_size=(request.output_width, request.output_height), + out_obj_scale=request.image_prompt_scale, + out_pos_x=request.image_prompt_pos_x, + out_pos_y=request.image_prompt_pos_y, + color=255, + ) + request.image_prompt = upload_file_from_bytes( + "repositioned_image_prompt.png", + cv2_img_to_bytes(repositioned_image_prompt), + ) + init_images += [request.image_prompt] * len( + request.image_prompt_controlnet_models + ) + request.selected_controlnet_model += request.image_prompt_controlnet_models + request.controlnet_conditioning_scale += [ + request.image_prompt_strength + ] * len(request.image_prompt_controlnet_models) state["output_images"] = controlnet( selected_model=request.selected_model, selected_controlnet_model=request.selected_controlnet_model, prompt=request.text_prompt, - init_image=image, + init_images=init_images, num_outputs=request.num_outputs, num_inference_steps=request.quality, negative_prompt=request.negative_prompt, @@ -436,8 +558,8 @@ def vcard_form(*, key: str) -> VCARD: ) if vcard.email and st.button( - "Import other contact info from my 
email - magic!", - className="link-button", + "Import other contact info from my email - magic!", + type="link", ): imported_vcard = get_vcard_from_email(vcard.email) if not imported_vcard or not imported_vcard.format_name: @@ -474,8 +596,11 @@ def vcard_form(*, key: str) -> VCARD: st.session_state.setdefault("__vcard_data__urls_text", "\n".join(vcard.urls or [])) vcard.urls = ( st.text_area( - "Link(s)", - placeholder="https://www.gooey.ai\nhttps://farmer.chat", + """ + Website Links + *([calend.ly](https://calend.ly) works great!)* + """, + placeholder="https://www.gooey.ai\nhttps://calend.ly/seanblagsvedt", key="__vcard_data__urls_text", ) .strip() @@ -492,11 +617,6 @@ def vcard_form(*, key: str) -> VCARD: vcard.gender = st.text_input( "Gender", key="__vcard_data__gender", placeholder="F" ) - vcard.calendar_url = st.text_input( - "Calendar Link ([calend.ly](calend.ly))", - key="__vcard_data__calendar_url", - placeholder="https://calendar.google.com/calendar/u/0/r", - ) vcard.note = st.text_area( "Notes", key="__vcard_data__note", @@ -611,7 +731,7 @@ def generate_qr_code(qr_code_data: str) -> np.ndarray: def download_qr_code_data(url: str) -> str: r = requests.get(url) - r.raise_for_status() + raise_for_status(r) img = bytes_to_cv2_img(r.content, greyscale=True) return extract_qr_code_data(img) diff --git a/recipes/RelatedQnA.py b/recipes/RelatedQnA.py index 7e2e71271..efddba5ba 100644 --- a/recipes/RelatedQnA.py +++ b/recipes/RelatedQnA.py @@ -16,7 +16,7 @@ from recipes.GoogleGPT import GoogleGPTPage from recipes.RelatedQnADoc import render_qna_outputs -DEFAULT_SEO_CONTENT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/9b415768-5393-11ee-a837-02420a000190/RQnA%20SEO%20content%201.png.png" +DEFAULT_SEO_CONTENT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/cbd2c94e-9456-11ee-a95e-02420a0001cc/People%20also%20ask.jpg.png" class RelatedGoogleGPTResponse(GoogleGPTPage.ResponseModel): @@ -25,6 +25,7 
@@ class RelatedGoogleGPTResponse(GoogleGPTPage.ResponseModel): class RelatedQnAPage(BasePage): title = "Generate โ€œPeople Also Askโ€ SEO Content " + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/37b0ba22-88d6-11ee-b549-02420a000167/People%20also%20ask.png.png" workflow = Workflow.RELATED_QNA_MAKER slug_versions = ["related-qna-maker"] diff --git a/recipes/RelatedQnADoc.py b/recipes/RelatedQnADoc.py index 9f3f3d11e..7ebc87080 100644 --- a/recipes/RelatedQnADoc.py +++ b/recipes/RelatedQnADoc.py @@ -24,6 +24,7 @@ class RelatedDocSearchResponse(DocSearchPage.ResponseModel): class RelatedQnADocPage(BasePage): title = '"People Also Ask" Answers from a Doc' + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.RELATED_QNA_MAKER_DOC slug_versions = ["related-qna-maker-doc"] @@ -154,4 +155,4 @@ def render_qna_outputs(state, height, show_count=None): {"output_text": output_text, "references": references}, height ) render_sources_widget(references) - st.write("
") + st.html("
") diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index 674760eb8..a82424ae1 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -14,6 +14,7 @@ from bots.models import Workflow from recipes.GoogleGPT import GoogleSearchMixin from daras_ai_v2.base import BasePage +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import map_parallel from daras_ai_v2.language_model import ( @@ -36,7 +37,7 @@ KEYWORDS_SEP = re.compile(r"[\n,]") STOP_SEQ = "$" * 10 -SEO_SUMMARY_DEFAULT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/seo.png" +SEO_SUMMARY_DEFAULT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/13d3ab1e-9457-11ee-98a6-02420a0001c9/SEO.jpg.png" BANNED_HOSTS = [ # youtube generally returns garbage @@ -56,6 +57,7 @@ class SEOSummaryPage(BasePage): title = "Create a perfect SEO-optimized Title & Paragraph" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85f38b42-88d6-11ee-ad97-02420a00016c/Create%20SEO%20optimized%20content%20option%202.png.png" workflow = Workflow.SEO_SUMMARY slug_versions = ["SEOSummary", "seo-paragraph-generator"] @@ -445,6 +447,6 @@ def _call_summarize_url(url: str) -> (str, str): headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, timeout=EXTERNAL_REQUEST_TIMEOUT_SEC, ) - r.raise_for_status() + raise_for_status(r) doc = readability.Document(r.text) return doc.title(), doc.summary() diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 55e2a9f5d..554a74940 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -17,11 +17,12 @@ from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.pt import PromptTree -DEFAULT_SMARTGPT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/e02d1582-538a-11ee-9d7b-02420a000194/smartgpt%201.png.png" 
+DEFAULT_SMARTGPT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/3d71b434-9457-11ee-8edd-02420a0001c7/Smart%20GPT.jpg.png" class SmartGPTPage(BasePage): title = "SmartGPT" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ffd24ad8-88d7-11ee-a658-02420a000163/SmartGPT.png.png" workflow = Workflow.SMART_GPT slug_versions = ["SmartGPT"] price = 20 diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 243c9fa2d..082fdcdaa 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -9,21 +9,23 @@ from daras_ai.text_format import daras_ai_format_str from daras_ai_v2 import settings from daras_ai_v2.base import BasePage +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.language_model import run_language_model, LargeLanguageModels from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.redis_cache import redis_cache_decorator email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" -DEFAULT_SOCIAL_LOOKUP_EMAIL_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/email%20ver%202.png" +DEFAULT_SOCIAL_LOOKUP_EMAIL_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6729ea44-9457-11ee-bd77-02420a0001ce/Profile%20look%20up%20gpt%20email.jpg.png" class SocialLookupEmailPage(BasePage): title = "Profile Lookup + GPT3 for AI-Personalized Emails" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5fbd475a-88d7-11ee-aac9-02420a00016b/personalized%20email.png.png" workflow = Workflow.SOCIAL_LOOKUP_EMAIL slug_versions = ["SocialLookupEmail", "email-writer-with-profile-lookup"] sane_defaults = { - "selected_model": LargeLanguageModels.text_davinci_003.name, + "selected_model": LargeLanguageModels.gpt_4.name, } class RequestModel(BaseModel): @@ -239,7 +241,7 @@ def get_profile_for_email(email_address) -> dict | None: 
"https://api.apollo.io/v1/people/match", json={"api_key": settings.APOLLO_API_KEY, "email": email_address}, ) - r.raise_for_status() + raise_for_status(r) person = r.json().get("person") if not person: diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 074462257..77776ddf5 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -14,7 +14,7 @@ num_outputs_setting, ) -DEFAULT_TEXT2AUDIO_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ddc6e894-538b-11ee-a837-02420a000190/text2audio1%201.png.png" +DEFAULT_TEXT2AUDIO_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85cf8ea4-9457-11ee-bd77-02420a0001ce/Text%20guided%20audio.jpg.png" class Text2AudioModels(Enum): @@ -28,6 +28,7 @@ class Text2AudioModels(Enum): class Text2AudioPage(BasePage): title = "Text guided audio generator" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a4481d58-88d9-11ee-aa86-02420a000165/Text%20guided%20audio%20generator.png.png" workflow = Workflow.TEXT_2_AUDIO slug_versions = ["text2audio"] diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 0df5d1d3d..c37a8eb5f 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -12,6 +12,7 @@ from daras_ai.image_input import upload_file_from_bytes, storage_blob_for from daras_ai_v2 import settings from daras_ai_v2.base import BasePage +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.gpu_server import GpuEndpoints, call_celery_task_outfile from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.text_to_speech_settings_widgets import ( @@ -22,11 +23,12 @@ TextToSpeechProviders, ) -DEFAULT_TTS_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/assets/cropped_tts_compare_meta_img.gif" +DEFAULT_TTS_META_IMG = 
"https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a73181ce-9457-11ee-8edd-02420a0001c7/Voice%20generators.jpg.png" class TextToSpeechPage(BasePage): title = "Compare AI Voice Generators" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/3621e11a-88d9-11ee-b549-02420a000167/Compare%20AI%20voice%20generators.png.png" workflow = Workflow.TEXT_TO_SPEECH slug_versions = [ "TextToSpeech", @@ -70,6 +72,8 @@ class RequestModel(BaseModel): elevenlabs_model: str | None elevenlabs_stability: float | None elevenlabs_similarity_boost: float | None + elevenlabs_style: float | None + elevenlabs_speaker_boost: bool | None class ResponseModel(BaseModel): audio_url: str @@ -198,7 +202,7 @@ def run(self, state: dict): "pace": pace, }, ) - response.raise_for_status() + raise_for_status(response) file_uuid = json.loads(response.text)["uuid"] while True: data = requests.get( @@ -267,6 +271,14 @@ def run(self, state: dict): stability = state.get("elevenlabs_stability", 0.5) similarity_boost = state.get("elevenlabs_similarity_boost", 0.75) + voice_settings = dict( + stability=stability, similarity_boost=similarity_boost + ) + if voice_model == "eleven_multilingual_v2": + voice_settings["style"] = state.get("elevenlabs_style", 0.0) + voice_settings["speaker_boost"] = state.get( + "elevenlabs_speaker_boost", True + ) response = requests.post( f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}", @@ -277,13 +289,10 @@ def run(self, state: dict): json={ "text": text, "model_id": voice_model, - "voice_settings": { - "stability": stability, - "similarity_boost": similarity_boost, - }, + "voice_settings": voice_settings, }, ) - response.raise_for_status() + raise_for_status(response) yield "Uploading Audio file..." 
state["audio_url"] = upload_file_from_bytes( diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 6822f89a0..17d05e770 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1,6 +1,7 @@ +import json +import mimetypes import os import os.path -import re import typing from django.db.models import QuerySet @@ -17,12 +18,23 @@ run_google_translate, google_translate_language_selector, ) -from daras_ai_v2.base import BasePage, MenuTabs, StateKeys +from daras_ai_v2.azure_doc_extract import ( + azure_form_recognizer, +) +from daras_ai_v2.base import BasePage, MenuTabs +from daras_ai_v2.bot_integration_widgets import ( + general_integration_settings, + broadcast_input, + render_bot_test_link, +) from daras_ai_v2.doc_search_settings_widgets import ( doc_search_settings, document_uploader, ) -from daras_ai_v2.glossary import glossary_input +from daras_ai_v2.enum_selector_widget import enum_multiselect +from daras_ai_v2.field_render import field_title_desc +from daras_ai_v2.functions import LLMTools +from daras_ai_v2.glossary import glossary_input, validate_glossary_document from daras_ai_v2.language_model import ( run_language_model, calc_gpt_tokens, @@ -35,6 +47,10 @@ CHATML_ROLE_USER, CHATML_ROLE_SYSTEM, model_max_tokens, + get_entry_images, + get_entry_text, + format_chat_entry, + SUPERSCRIPT, ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.lipsync_settings_widgets import lipsync_settings @@ -42,7 +58,13 @@ from daras_ai_v2.prompt_vars import render_prompt_vars, prompt_vars_widget from daras_ai_v2.query_generator import generate_final_search_query from daras_ai_v2.query_params import gooey_get_query_params -from daras_ai_v2.search_ref import apply_response_template, parse_refs, CitationStyles +from daras_ai_v2.query_params_util import extract_query_params +from daras_ai_v2.search_ref import ( + parse_refs, + CitationStyles, + apply_response_formattings_prefix, + apply_response_formattings_suffix, +) 
from daras_ai_v2.text_output_widget import text_output from daras_ai_v2.text_to_speech_settings_widgets import ( TextToSpeechProviders, @@ -58,97 +80,37 @@ from recipes.TextToSpeech import TextToSpeechPage from url_shortener.models import ShortenedURL -DEFAULT_COPILOT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/c8b24b0c-538a-11ee-a1a3-02420a00018d/meta%20tags1%201.png.png" +DEFAULT_COPILOT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f454d64a-9457-11ee-b6d5-02420a0001cb/Copilot.jpg.png" -BOT_SCRIPT_RE = re.compile( - # start of line - r"^" - # name of bot / user - r"([\w\ \t]+)" - # colon - r"\:\ ", - flags=re.M, -) +# BOT_SCRIPT_RE = re.compile( +# # start of line +# r"^" +# # name of bot / user +# r"([\w\ \t]{3,30})" +# # colon +# r"\:\ ", +# flags=re.M, +# ) SAFETY_BUFFER = 100 -def show_landbot_widget(): - landbot_url = st.session_state.get("landbot_url") - if not landbot_url: - st.html("", **{"data-landbot-config-url": ""}) - return +def exec_tool_call(call: dict): + tool_name = call["function"]["name"] + tool = LLMTools[tool_name] + yield f"๐Ÿ›  {tool.label}..." 
+ kwargs = json.loads(call["function"]["arguments"]) + return tool.fn(**kwargs) - f = furl(landbot_url) - config_path = os.path.join(f.host, *f.path.segments[:2]) - config_url = f"https://storage.googleapis.com/{config_path}/index.json" - st.html( - # language=HTML - """ - - """, - **{"data-landbot-config-url": config_url}, - ) - - -def parse_script(bot_script: str) -> (str, list[ConversationEntry]): - # run regex to find scripted messages in script text - script_matches = list(BOT_SCRIPT_RE.finditer(bot_script)) - # extract system message from script - system_message = bot_script - if script_matches: - system_message = system_message[: script_matches[0].start()] - system_message = system_message.strip() - # extract pre-scripted messages from script - scripted_msgs: list[ConversationEntry] = [] - for idx in range(len(script_matches)): - match = script_matches[idx] - try: - next_match = script_matches[idx + 1] - except IndexError: - next_match_start = None - else: - next_match_start = next_match.start() - if (len(script_matches) - idx) % 2 == 0: - role = CHATML_ROLE_USER - else: - role = CHATML_ROLE_ASSISTANT - scripted_msgs.append( - { - "role": role, - "display_name": match.group(1).strip(), - "content": bot_script[match.end() : next_match_start].strip(), - } - ) - return system_message, scripted_msgs +class ReplyButton(typing.TypedDict): + id: str + title: str class VideoBotsPage(BasePage): title = "Copilot for your Enterprise" # "Create Interactive Video Bots" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/8c014530-88d4-11ee-aac9-02420a00016b/Copilot.png.png" workflow = Workflow.VIDEO_BOTS slug_versions = ["video-bots", "bots", "copilot"] @@ -191,9 +153,15 @@ class VideoBotsPage(BasePage): } class RequestModel(BaseModel): - input_prompt: str bot_script: str | None + input_prompt: str + input_images: list[str] | None + input_documents: list[str] | None + + # conversation history/context + messages: list[ConversationEntry] 
| None + # tts settings tts_provider: typing.Literal[ tuple(e.name for e in TextToSpeechProviders) @@ -215,6 +183,11 @@ class RequestModel(BaseModel): selected_model: typing.Literal[ tuple(e.name for e in LargeLanguageModels) ] | None + document_model: str | None = Field( + title="๐Ÿฉป Photo / Document Intelligence", + description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " + "(via [Azure](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/how-to-guides/use-sdk-rest-api?view=doc-intel-3.1.0&tabs=linux&pivots=programming-language-rest-api))", + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None @@ -228,9 +201,6 @@ class RequestModel(BaseModel): face_padding_left: int | None face_padding_right: int | None - # conversation history/context - messages: list[ConversationEntry] | None - # doc search task_instructions: str | None query_instructions: str | None @@ -246,25 +216,33 @@ class RequestModel(BaseModel): citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None use_url_shortener: bool | None - user_language: str | None + user_language: str | None = Field( + title="๐Ÿ”  User Language", + description="If provided, the copilot will translate user messages to English and the copilot's response back to the selected language.", + ) # llm_language: str | None = "en" <-- implicit since this is hardcoded everywhere in the code base (from facebook and bots to slack and copilot etc.) 
input_glossary_document: str | None = Field( title="Input Glossary", description=""" -Translation Glossary for User Langauge -> LLM Language (English) +Translation Glossary for User Langauge -> LLM Language (English) """, ) output_glossary_document: str | None = Field( title="Output Glossary", description=""" -Translation Glossary for LLM Language (English) -> User Langauge +Translation Glossary for LLM Language (English) -> User Langauge """, ) variables: dict[str, typing.Any] | None + tools: list[LLMTools] | None = Field( + title="๐Ÿ› ๏ธ Tools", + description="Give your copilot superpowers by giving it access to tools. Powered by [Function calling](https://platform.openai.com/docs/guides/function-calling).", + ) + class ResponseModel(BaseModel): - final_prompt: str + final_prompt: str | list[ConversationEntry] output_text: list[str] @@ -282,7 +260,11 @@ class ResponseModel(BaseModel): # doc search references: list[SearchReference] | None final_search_query: str | None - final_keyword_query: str | None + final_keyword_query: str | list[str] | None + + # function calls + output_documents: list[str] | None + reply_buttons: list[ReplyButton] | None def preview_image(self, state: dict) -> str | None: return DEFAULT_COPILOT_META_IMG @@ -310,20 +292,20 @@ def get_submit_container_props(self): def render_description(self): st.write( """ -Have you ever wanted to create a bot that you could talk to about anything? Ever wanted to create your own https://dara.network/RadBots or https://Farmer.CHAT? This is how. +Have you ever wanted to create a bot that you could talk to about anything? Ever wanted to create your own https://dara.network/RadBots or https://Farmer.CHAT? This is how. + +This workflow takes a dialog LLM prompt describing your character, a collection of docs & links and optional an video clip of your botโ€™s face and voice settings. 
-This workflow takes a dialog LLM prompt describing your character, a collection of docs & links and optional an video clip of your botโ€™s face and voice settings. - -We use all these to build a bot that anyone can speak to about anything and you can host directly in your own site or app, or simply connect to your Facebook, WhatsApp or Instagram page. +We use all these to build a bot that anyone can speak to about anything and you can host directly in your own site or app, or simply connect to your Facebook, WhatsApp or Instagram page. How It Works: -1. Appends the user's question to the bottom of your dialog script. +1. Appends the user's question to the bottom of your dialog script. 2. Sends the appended script to OpenAIโ€™s GPT3 asking it to respond to the question in the style of your character 3. Synthesizes your character's response as audio using your voice settings (using Google Text-To-Speech or Uberduck) 4. Lip syncs the face video clip to the voice clip 5. Shows the resulting video to the user -PS. This is the workflow that we used to create RadBots - a collection of Turing-test videobots, authored by leading international writers, singers and playwrights - and really inspired us to create Gooey.AI so that every person and organization could create their own fantastic characters, in any personality of their choosing. It's also the workflow that powers https://Farmer.CHAT and was demo'd at the UN General Assembly in April 2023 as a multi-lingual WhatsApp bot for Indian, Ethiopian and Kenyan farmers. +PS. This is the workflow that we used to create RadBots - a collection of Turing-test videobots, authored by leading international writers, singers and playwrights - and really inspired us to create Gooey.AI so that every person and organization could create their own fantastic characters, in any personality of their choosing. 
It's also the workflow that powers https://Farmer.CHAT and was demo'd at the UN General Assembly in April 2023 as a multi-lingual WhatsApp bot for Indian, Ethiopian and Kenyan farmers. """ ) @@ -331,7 +313,7 @@ def render_form_v2(self): st.text_area( """ ##### ๐Ÿ“ Prompt - High-level system instructions to the copilot + optional example conversations between the bot and the user. + High-level system instructions. """, key="bot_script", height=300, @@ -351,6 +333,14 @@ def render_form_v2(self): "keyword_instructions", ) + def validate_form_v2(self): + input_glossary = st.session_state.get("input_glossary_document", "") + output_glossary = st.session_state.get("output_glossary_document", "") + if input_glossary: + validate_glossary_document(input_glossary) + if output_glossary: + validate_glossary_document(output_glossary) + def render_usage_guide(self): youtube_video("-j2su1r8pEg") @@ -375,31 +365,39 @@ def render_settings(self): st.write("---") doc_search_settings(keyword_instructions_allowed=True) st.write("---") - language_model_settings() + + language_model_settings(show_document_model=True) st.write("---") google_translate_language_selector( - """ - ##### ๐Ÿ”  User Language - If provided, the copilot will translate user messages to English and the copilot's response back to the selected language. - """, + f"##### {field_title_desc(self.RequestModel, 'user_language')}", key="user_language", ) + enable_glossary = st.checkbox( + "๐Ÿ“– Customize with Glossary", + value=bool( + st.session_state.get("input_glossary_document") + or st.session_state.get("output_glossary_document") + ), + ) st.markdown( """ - ###### ๐Ÿ“– Customize with Glossary Provide a glossary to customize translation and improve accuracy of domain-specific terms. - If not specified or invalid, no glossary will be used. Read about the expected format [here](https://docs.google.com/document/d/1TwzAvFmFYekloRKql2PXNPIyqCbsHRL8ZtnWkzAYrh8/edit?usp=sharing). 
+ If not specified or invalid, no glossary will be used. Read about the expected format [here](https://docs.google.com/document/d/1TwzAvFmFYekloRKql2PXNPIyqCbsHRL8ZtnWkzAYrh8/edit?usp=sharing). """ ) - glossary_input( - f"##### {self.RequestModel.__fields__['input_glossary_document'].field_info.title}\n{self.RequestModel.__fields__['input_glossary_document'].field_info.description or ''}", - key="input_glossary_document", - ) - glossary_input( - f"##### {self.RequestModel.__fields__['output_glossary_document'].field_info.title}\n{self.RequestModel.__fields__['output_glossary_document'].field_info.description or ''}", - key="output_glossary_document", - ) + if enable_glossary: + glossary_input( + f"##### {field_title_desc(self.RequestModel, 'input_glossary_document')}", + key="input_glossary_document", + ) + glossary_input( + f"##### {field_title_desc(self.RequestModel, 'output_glossary_document')}", + key="output_glossary_document", + ) + else: + st.session_state["input_glossary_document"] = None + st.session_state["output_glossary_document"] = None st.write("---") if not "__enable_audio" in st.session_state: @@ -425,13 +423,20 @@ def render_settings(self): st.file_uploader( """ #### ๐Ÿ‘ฉโ€๐Ÿฆฐ Input Face - Upload a video/image that contains faces to use - *Recommended - mp4 / mov / png / jpg / gif* + Upload a video/image that contains faces to use + *Recommended - mp4 / mov / png / jpg / gif* """, key="input_face", ) lipsync_settings() + st.write("---") + enum_multiselect( + enum_cls=LLMTools, + label="##### " + field_title_desc(self.RequestModel, "tools"), + key="tools", + ) + def fields_to_save(self) -> [str]: fields = super().fields_to_save() + ["landbot_url"] if "elevenlabs_api_key" in fields: @@ -458,81 +463,32 @@ def render_example(self, state: dict): st.write(truncate_text_words(output_text[0], maxlen=200)) def render_output(self): + # chat window with st.div(className="pb-3"): - with st.div( - className="pb-1", - style=dict( - maxHeight="80vh", - 
overflowY="scroll", - display="flex", - flexDirection="column-reverse", - border="1px solid #c9c9c9", - ), - ): - with msg_container_widget(CHATML_ROLE_ASSISTANT): - output_text = st.session_state.get("output_text", []) - output_video = st.session_state.get("output_video", []) - output_audio = st.session_state.get("output_audio", []) - if output_text: - st.write(f"**Assistant**") - for idx, text in enumerate(output_text): - st.write(text) - try: - st.video(output_video[idx], autoplay=True) - except IndexError: - try: - st.audio(output_audio[idx]) - except IndexError: - pass - - input_prompt = st.session_state.get("input_prompt") - if input_prompt: - with msg_container_widget(CHATML_ROLE_USER): - st.write(f"**User** \\\n{input_prompt}") - - for entry in reversed(st.session_state.get("messages", [])): - with msg_container_widget(entry["role"]): - display_name = entry.get("display_name") or entry["role"] - display_name = display_name.capitalize() - st.write(f'**{display_name}** \\\n{entry["content"]}') - - with st.div( - className="px-3 pt-3 d-flex gap-1", - style=dict(background="rgba(239, 239, 239, 0.6)"), - ): - with st.div(className="flex-grow-1"): - new_input = st.text_area( - "", placeholder="Send a message", height=50 - ) - - if st.button("โœˆ Send", style=dict(height="3.2rem")): - messsages = st.session_state.get("messages", []) - raw_input_text = st.session_state.get("raw_input_text") or "" - raw_output_text = (st.session_state.get("raw_output_text") or [""])[ - 0 - ] - if raw_input_text and raw_output_text: - messsages += [ - { - "role": CHATML_ROLE_USER, - "content": raw_input_text, - }, - { - "role": CHATML_ROLE_ASSISTANT, - "content": raw_output_text, - }, - ] - st.session_state["messages"] = messsages - st.session_state["input_prompt"] = new_input - self.on_submit() - + chat_list_view() + ( + pressed_send, + new_input, + new_input_images, + new_input_documents, + ) = chat_input_view() + + if pressed_send: + self.on_send(new_input, new_input_images, 
new_input_documents) + + # clear chat inputs if st.button("๐Ÿ—‘๏ธ Clear"): st.session_state["messages"] = [] st.session_state["input_prompt"] = "" + st.session_state["input_images"] = None + st.session_state["new_input_documents"] = None st.session_state["raw_input_text"] = "" self.clear_outputs() + st.session_state["final_keyword_query"] = "" + st.session_state["final_search_query"] = "" st.experimental_rerun() + # render sources references = st.session_state.get("references", []) if not references: return @@ -545,6 +501,43 @@ def render_output(self): label_visibility="collapsed", ) + def on_send( + self, + new_input: str, + new_input_images: list[str], + new_input_documents: list[str], + ): + prev_input = st.session_state.get("raw_input_text") or "" + prev_output = (st.session_state.get("raw_output_text") or [""])[0] + prev_input_images = st.session_state.get("input_images") + prev_input_documents = st.session_state.get("input_documents") + + if (prev_input or prev_input_images or prev_input_documents) and prev_output: + # append previous input to the history + st.session_state["messages"] = st.session_state.get("messages", []) + [ + format_chat_entry( + role=CHATML_ROLE_USER, + content=prev_input, + images=prev_input_images, + ), + format_chat_entry( + role=CHATML_ROLE_ASSISTANT, + content=prev_output, + ), + ] + + # add new input to the state + if new_input_documents: + filenames = ", ".join( + furl(url.strip("/")).path.segments[-1] for url in new_input_documents + ) + new_input = f"Files: {filenames}\n\n{new_input}" + st.session_state["input_prompt"] = new_input + st.session_state["input_images"] = new_input_images or None + st.session_state["input_documents"] = new_input_documents or None + + self.on_submit() + def render_steps(self): if st.session_state.get("tts_provider"): st.video(st.session_state.get("input_face"), caption="Input Face") @@ -557,9 +550,15 @@ def render_steps(self): final_keyword_query = st.session_state.get("final_keyword_query") if 
final_keyword_query: - st.text_area( - "**Final Keyword Query**", value=final_keyword_query, disabled=True - ) + if isinstance(final_keyword_query, list): + st.write("**Final Keyword Query**") + st.json(final_keyword_query) + else: + st.text_area( + "**Final Keyword Query**", + value=str(final_keyword_query), + disabled=True, + ) references = st.session_state.get("references", []) if references: @@ -568,11 +567,11 @@ def render_steps(self): final_prompt = st.session_state.get("final_prompt") if final_prompt: - text_output( - "**Final Prompt**", - value=final_prompt, - height=300, - ) + if isinstance(final_prompt, str): + text_output("**Final Prompt**", value=final_prompt, height=300) + else: + st.write("**Final Prompt**") + st.json(final_prompt) for idx, text in enumerate(st.session_state.get("raw_output_text", [])): st.text_area( @@ -629,16 +628,29 @@ def run(self, state: dict) -> typing.Iterator[str | None]: """ user_input = request.input_prompt.strip() - if not user_input: + if not (user_input or request.input_images or request.input_documents): return model = LargeLanguageModels[request.selected_model] is_chat_model = model.is_chat_model() saved_msgs = request.messages.copy() bot_script = request.bot_script + ocr_texts = [] + if request.document_model and (request.input_images or request.input_documents): + yield "Running Azure Form Recognizer..." + for url in (request.input_images or []) + (request.input_documents or []): + ocr_text = ( + azure_form_recognizer(url, model_id="prebuilt-read") + .get("content", "") + .strip() + ) + if not ocr_text: + continue + ocr_texts.append(ocr_text) + # translate input text if request.user_language and request.user_language != "en": - yield f"Translating input to english..." + yield f"Translating Input to English..." 
user_input = run_google_translate( texts=[user_input], source_language=request.user_language, @@ -646,8 +658,20 @@ def run(self, state: dict) -> typing.Iterator[str | None]: glossary_url=request.input_glossary_document, )[0] + if ocr_texts: + yield f"Translating Image Text to English..." + ocr_texts = run_google_translate( + texts=ocr_texts, + source_language="auto", + target_language="en", + ) + for text in ocr_texts: + user_input = f"Extracted Text: {text!r}\n\n{user_input}" + # parse the bot script - system_message, scripted_msgs = parse_script(bot_script) + # system_message, scripted_msgs = parse_script(bot_script) + system_message = bot_script.strip() + scripted_msgs = [] # construct the system prompt if system_message: @@ -657,41 +681,39 @@ def run(self, state: dict) -> typing.Iterator[str | None]: else: system_prompt = None - # get user/bot display names - try: - bot_display_name = scripted_msgs[-1]["display_name"] - except IndexError: - bot_display_name = CHATML_ROLE_ASSISTANT - try: - user_display_name = scripted_msgs[-2]["display_name"] - except IndexError: - user_display_name = CHATML_ROLE_USER - - # construct user prompt + # # get user/bot display names + # try: + # bot_display_name = scripted_msgs[-1]["display_name"] + # except IndexError: + # bot_display_name = CHATML_ROLE_ASSISTANT + # try: + # user_display_name = scripted_msgs[-2]["display_name"] + # except IndexError: + # user_display_name = CHATML_ROLE_USER + + # save raw input for reference state["raw_input_text"] = user_input - user_prompt = { - "role": CHATML_ROLE_USER, - "display_name": user_display_name, - "content": user_input, - } # if documents are provided, run doc search on the saved msgs and get back the references references = None if request.documents: # formulate the search query as a history of all the messages - query_msgs = saved_msgs + [user_prompt] + query_msgs = saved_msgs + [ + format_chat_entry(role=CHATML_ROLE_USER, content=user_input) + ] clip_idx = convo_window_clipper( 
query_msgs, model_max_tokens[model] // 2, sep=" " ) query_msgs = query_msgs[clip_idx:] chat_history = "\n".join( - f'{msg["role"]}: """{msg["content"]}"""' for msg in query_msgs + f'{entry["role"]}: """{get_entry_text(entry)}"""' + for entry in query_msgs ) query_instructions = (request.query_instructions or "").strip() if query_instructions: - yield "Generating search query..." + yield "Creating search query..." state["final_search_query"] = generate_final_search_query( request=request, instructions=query_instructions, @@ -700,17 +722,25 @@ def run(self, state: dict) -> typing.Iterator[str | None]: else: query_msgs.reverse() state["final_search_query"] = "\n---\n".join( - msg["content"] for msg in query_msgs + get_entry_text(entry) for entry in query_msgs ) keyword_instructions = (request.keyword_instructions or "").strip() if keyword_instructions: - yield "Extracting keywords..." - state["final_keyword_query"] = generate_final_search_query( - request=request, + yield "Finding keywords..." 
+ k_request = request.copy() + # other models dont support JSON mode + k_request.selected_model = LargeLanguageModels.gpt_4_turbo.name + keyword_query = generate_final_search_query( + request=k_request, instructions=keyword_instructions, context={**state, "messages": chat_history}, + response_format_type="json_object", ) + if keyword_query and isinstance(keyword_query, dict): + keyword_query = list(keyword_query.values())[0] + state["final_keyword_query"] = keyword_query + # return # perform doc search references = yield from get_top_k_references( @@ -734,17 +764,22 @@ def run(self, state: dict) -> typing.Iterator[str | None]: if references: # add task instructions task_instructions = render_prompt_vars(request.task_instructions, state) - user_prompt["content"] = ( + user_input = ( references_as_prompt(references) + f"\n**********\n{task_instructions.strip()}\n**********\n" - + user_prompt["content"] + + user_input ) + # construct user prompt + user_prompt = format_chat_entry( + role=CHATML_ROLE_USER, content=user_input, images=request.input_images + ) + # truncate the history to fit the model's max tokens history_window = scripted_msgs + saved_msgs max_history_tokens = ( model_max_tokens[model] - - calc_gpt_tokens([system_prompt, user_prompt], is_chat_model=is_chat_model) + - calc_gpt_tokens([system_prompt, user_input], is_chat_model=is_chat_model) - request.max_tokens - SAFETY_BUFFER ) @@ -760,14 +795,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]: prompt_messages.append( { "role": CHATML_ROLE_ASSISTANT, - "display_name": bot_display_name, "content": "", } ) - # final prompt to display - prompt = "\n".join(format_chatml_message(entry) for entry in prompt_messages) - state["final_prompt"] = prompt + state["final_prompt"] = prompt_messages # ensure input script is not too big max_allowed_tokens = model_max_tokens[model] - calc_gpt_tokens( @@ -777,9 +809,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: if max_allowed_tokens < 0: raise 
ValueError("Input Script is too long! Please reduce the script size.") - yield f"Running {model.value}..." + yield f"Summarizing with {model.value}..." if is_chat_model: - output_text = run_language_model( + chunks = run_language_model( model=request.selected_model, messages=[ {"role": s["role"], "content": s["content"]} @@ -789,9 +821,14 @@ def run(self, state: dict) -> typing.Iterator[str | None]: num_outputs=request.num_outputs, temperature=request.sampling_temperature, avoid_repetition=request.avoid_repetition, + tools=request.tools, + stream=True, ) else: - output_text = run_language_model( + prompt = "\n".join( + format_chatml_message(entry) for entry in prompt_messages + ) + chunks = run_language_model( model=request.selected_model, prompt=prompt, max_tokens=max_allowed_tokens, @@ -800,34 +837,59 @@ def run(self, state: dict) -> typing.Iterator[str | None]: temperature=request.sampling_temperature, avoid_repetition=request.avoid_repetition, stop=[CHATML_END_TOKEN, CHATML_START_TOKEN], + stream=True, ) - # save model response - state["raw_output_text"] = [ - "".join(snippet for snippet, _ in parse_refs(text, references)) - for text in output_text - ] - - # translate response text - if request.user_language and request.user_language != "en": - yield f"Translating response to {request.user_language}..." 
- output_text = run_google_translate( - texts=output_text, - source_language="en", - target_language=request.user_language, - glossary_url=request.output_glossary_document, - ) - state["raw_tts_text"] = [ + citation_style = ( + request.citation_style and CitationStyles[request.citation_style] + ) or None + for i, entries in enumerate(chunks): + if not entries: + continue + output_text = [entry["content"] for entry in entries] + if request.tools: + # output_text, tool_call_choices = output_text + state["output_documents"] = output_documents = [] + for call in entries[0].get("tool_calls") or []: + result = yield from exec_tool_call(call) + output_documents.append(result) + + # save model response + state["raw_output_text"] = [ "".join(snippet for snippet, _ in parse_refs(text, references)) for text in output_text ] - if references: - citation_style = ( - request.citation_style and CitationStyles[request.citation_style] - ) or None - apply_response_template(output_text, references, citation_style) + # translate response text + if request.user_language and request.user_language != "en": + yield f"Translating response to {request.user_language}..." 
+ output_text = run_google_translate( + texts=output_text, + source_language="en", + target_language=request.user_language, + glossary_url=request.output_glossary_document, + ) + state["raw_tts_text"] = [ + "".join(snippet for snippet, _ in parse_refs(text, references)) + for text in output_text + ] + + if references: + all_refs_list = apply_response_formattings_prefix( + output_text, references, citation_style + ) + else: + all_refs_list = None - state["output_text"] = output_text + state["output_text"] = output_text + if all(entry.get("finish_reason") for entry in entries): + if all_refs_list: + apply_response_formattings_suffix( + all_refs_list, state["output_text"], citation_style + ) + finish_reason = entries[0]["finish_reason"] + yield f"Completed with {finish_reason=}" # avoid changing this message since it's used to detect end of stream + else: + yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." state["output_audio"] = [] state["output_video"] = [] @@ -837,7 +899,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: tts_state = dict(state) for text in state.get("raw_tts_text", state["raw_output_text"]): tts_state["text_prompt"] = text - yield from TextToSpeechPage().run(tts_state) + yield from TextToSpeechPage( + request=self.request, run_user=self.run_user + ).run(tts_state) state["output_audio"].append(tts_state["audio_url"]) if not request.input_face: @@ -845,7 +909,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: lip_state = dict(state) for audio_url in state["output_audio"]: lip_state["input_audio"] = audio_url - yield from LipsyncPage().run(lip_state) + yield from LipsyncPage(request=self.request, run_user=self.run_user).run( + lip_state + ) state["output_video"].append(lip_state["output_video"]) def get_tabs(self): @@ -891,7 +957,7 @@ def render_selected_tab(self, selected_tab): with col2: st.write( """ - + #### Part 2: [Interactive Chatbots for your Content - Part 2: Make your Chatbot - How to use 
Gooey.AI Workflows ](https://youtu.be/h817RolPjq4) """ @@ -900,7 +966,7 @@ def render_selected_tab(self, selected_tab): """
+
""", unsafe_allow_html=True, @@ -910,17 +976,17 @@ def render_selected_tab(self, selected_tab): st.text_input( "###### ๐Ÿค– [Landbot](https://landbot.io/) URL", key="landbot_url" ) - - show_landbot_widget() + show_landbot_widget() def messenger_bot_integration(self): - from routers.facebook import ig_connect_url, fb_connect_url - from routers.slack import slack_connect_url + from routers.facebook_api import ig_connect_url, fb_connect_url + from routers.slack_api import slack_connect_url + from recipes.VideoBotsStats import VideoBotsStatsPage st.markdown( # language=html f""" -

Connect this bot to your Website, Instagram, Whatsapp & More

+

Connect this bot to your Website, Instagram, Whatsapp & More

You can connect your FB Messenger account and Slack Workspace here directly.
If you ping us at support@gooey.ai, we'll add your other accounts too! @@ -929,26 +995,26 @@ def messenger_bot_integration(self): --> @@ -965,63 +1031,84 @@ def messenger_bot_integration(self): if not integrations: return - current_sr = self.get_sr_from_query_params_dict(gooey_get_query_params()) + current_run, published_run = self.get_runs_from_query_params( + *extract_query_params(gooey_get_query_params()) + ) for bi in integrations: - is_connected = bi.saved_run == current_sr + is_connected = (bi.saved_run == current_run) or ( + ( + bi.saved_run + and published_run + and bi.saved_run.example_id == published_run.published_run_id + ) + or ( + bi.published_run + and published_run + and bi.published_run == published_run + ) + ) col1, col2, col3, *_ = st.columns([1, 1, 2]) with col1: favicon = Platform(bi.platform).get_favicon() - st.markdown( - f'  ' - f'{bi}' - if bi.saved_run - else f"{bi}", - unsafe_allow_html=True, - ) + if bi.published_run: + url = self.app_url( + example_id=bi.published_run.published_run_id, + tab_name=MenuTabs.paths[MenuTabs.integrations], + ) + elif bi.saved_run: + url = self.app_url( + run_id=bi.saved_run.run_id, + uid=bi.saved_run.uid, + example_id=bi.saved_run.example_id, + tab_name=MenuTabs.paths[MenuTabs.integrations], + ) + else: + url = None + if url: + href = f'{bi}' + else: + href = f"{bi}" + with st.div(className="mt-2"): + st.markdown( + f'  {href}', + unsafe_allow_html=True, + ) with col2: - pressed = st.button( + pressed_connect = st.button( "๐Ÿ”Œ๐Ÿ’”๏ธ Disconnect" if is_connected else "๐Ÿ–‡๏ธ Connect", key=f"btn_connect_{bi.id}", + type="tertiary", ) - with col3: - if bi.platform == Platform.SLACK: - with st.expander("๐Ÿ“จ Slack Settings"): - read_receipt_key = "slack_read_receipt_" + str(bi.id) - st.session_state.setdefault( - read_receipt_key, bi.slack_read_receipt_msg - ) - read_msg = st.text_input( - "Read Receipt (leave blank to disable)", - key=read_receipt_key, - placeholder=bi.slack_read_receipt_msg, - ) - 
bot_name_key = "slack_bot_name_" + str(bi.id) - st.session_state.setdefault(bot_name_key, bi.name) - bot_name = st.text_input( - "Channel Specific Bot Name (to be displayed in Slack)", - key=bot_name_key, - placeholder=bi.name, - ) - if st.button("Reset to Default"): - bi.name = st.session_state.get( - StateKeys.page_title, bi.name - ) - bi.slack_read_receipt_msg = BotIntegration._meta.get_field( - "slack_read_receipt_msg" - ).default - bi.save() - st.experimental_rerun() - if st.button("Update"): - bi.slack_read_receipt_msg = read_msg - bi.name = bot_name - bi.save() - st.experimental_rerun() - if not pressed: + render_bot_test_link(bi) + stats_url = furl(VideoBotsStatsPage.app_url(), args={"bi_id": bi.id}) + st.html( + f""" + ๐Ÿ“Š Analytics + """ + ) + if is_connected: + with col3, st.expander(f"๐Ÿ“จ {bi.get_platform_display()} Settings"): + if bi.platform == Platform.SLACK: + self.slack_specific_settings(bi) + general_integration_settings(bi) + if bi.platform in [Platform.SLACK, Platform.WHATSAPP]: + st.write("---") + broadcast_input(bi) + if not pressed_connect: continue if is_connected: bi.saved_run = None + bi.published_run = None else: - bi.saved_run = current_sr + # set bot language from state + bi.user_language = ( + st.session_state.get("user_language") or bi.user_language + ) + bi.saved_run = current_run + if published_run and published_run.saved_run_id == current_run.id: + bi.published_run = published_run + else: + bi.published_run = None if bi.platform == Platform.SLACK: from daras_ai_v2.slack_bot import send_confirmation_msg @@ -1031,6 +1118,220 @@ def messenger_bot_integration(self): st.write("---") + def slack_specific_settings(self, bi: BotIntegration): + if st.session_state.get(f"_bi_reset_{bi.id}"): + pr = self.get_current_published_run() + st.session_state[f"_bi_name_{bi.id}"] = ( + pr and pr.title + ) or self.get_recipe_title() + st.session_state[ + f"_bi_slack_read_receipt_msg_{bi.id}" + ] = 
BotIntegration._meta.get_field("slack_read_receipt_msg").default + + bi.slack_read_receipt_msg = st.text_input( + """ + ##### ✅ Read Receipt + This message is sent immediately after receiving a user message and replaced with the copilot's response once it's ready. + (leave blank to disable) + """, + placeholder=bi.slack_read_receipt_msg, + value=bi.slack_read_receipt_msg, + key=f"_bi_slack_read_receipt_msg_{bi.id}", + ) + bi.name = st.text_input( + """ + ##### 🪪 Channel Specific Bot Name + This is the name the bot will post as in this specific channel (to be displayed in Slack) + """, + placeholder=bi.name, + value=bi.name, + key=f"_bi_name_{bi.id}", + ) + st.caption("Enable streaming messages to Slack in real-time.") + + +def show_landbot_widget(): + landbot_url = st.session_state.get("landbot_url") + if not landbot_url: + st.html("", **{"data-landbot-config-url": ""}) + return + + f = furl(landbot_url) + config_path = os.path.join(f.host, *f.path.segments[:2]) + config_url = f"https://storage.googleapis.com/{config_path}/index.json" + + st.html( + # language=HTML + """ + + """, + **{"data-landbot-config-url": config_url}, + ) + + +# def parse_script(bot_script: str) -> (str, list[ConversationEntry]): +# # run regex to find scripted messages in script text +# script_matches = list(BOT_SCRIPT_RE.finditer(bot_script)) +# # extract system message from script +# system_message = bot_script +# if script_matches: +# system_message = system_message[: script_matches[0].start()] +# system_message = system_message.strip() +# # extract pre-scripted messages from script +# scripted_msgs: list[ConversationEntry] = [] +# for idx in range(len(script_matches)): +# match = script_matches[idx] +# try: +# next_match = script_matches[idx + 1] +# except IndexError: +# next_match_start = None +# else: +# next_match_start = next_match.start() +# if (len(script_matches) - idx) % 2 == 0: +# role = CHATML_ROLE_USER +# else: +# role = CHATML_ROLE_ASSISTANT +# scripted_msgs.append( +# { 
+# "role": role, +# "display_name": match.group(1).strip(), +# "content": bot_script[match.end() : next_match_start].strip(), +# } +# ) +# return system_message, scripted_msgs + + +def chat_list_view(): + # render a reversed list view + with st.div( + className="pb-1", + style=dict( + maxHeight="80vh", + overflowY="scroll", + display="flex", + flexDirection="column-reverse", + border="1px solid #c9c9c9", + ), + ): + with st.div(className="px-3"): + show_raw_msgs = st.checkbox("_Show Raw Output_") + # render the last output + with msg_container_widget(CHATML_ROLE_ASSISTANT): + if show_raw_msgs: + output_text = st.session_state.get("raw_output_text", []) + else: + output_text = st.session_state.get("output_text", []) + output_video = st.session_state.get("output_video", []) + output_audio = st.session_state.get("output_audio", []) + if output_text: + st.write(f"**Assistant**") + for idx, text in enumerate(output_text): + st.write(text) + try: + st.video(output_video[idx], autoplay=True) + except IndexError: + try: + st.audio(output_audio[idx]) + except IndexError: + pass + output_documents = st.session_state.get("output_documents", []) + if output_documents: + for doc in output_documents: + st.write(doc) + messages = st.session_state.get("messages", []).copy() + # add last input to history if present + if show_raw_msgs: + input_prompt = st.session_state.get("raw_input_text") + else: + input_prompt = st.session_state.get("input_prompt") + input_images = st.session_state.get("input_images") + if input_prompt or input_images: + messages += [ + format_chat_entry( + role=CHATML_ROLE_USER, content=input_prompt, images=input_images + ), + ] + # render history + for entry in reversed(messages): + with msg_container_widget(entry["role"]): + images = get_entry_images(entry) + text = get_entry_text(entry) + if text or images: + st.write(f"**{entry['role'].capitalize()}** \n{text}") + if images: + for im in images: + st.image(im, style={"maxHeight": "200px"}) + + +def 
chat_input_view() -> tuple[bool, str, list[str], list[str]]: + with st.div( + className="px-3 pt-3 d-flex gap-1", + style=dict(background="rgba(239, 239, 239, 0.6)"), + ): + show_uploader_key = "--show-file-uploader" + show_uploader = st.session_state.setdefault(show_uploader_key, False) + if st.button( + "๐Ÿ“Ž", + style=dict(height="3.2rem", backgroundColor="white"), + ): + show_uploader = not show_uploader + st.session_state[show_uploader_key] = show_uploader + + with st.div(className="flex-grow-1"): + new_input = st.text_area("", placeholder="Send a message", height=50) + + pressed_send = st.button("โœˆ Send", style=dict(height="3.2rem")) + + if show_uploader: + uploaded_files = st.file_uploader("", accept_multiple_files=True) + new_input_images = [] + new_input_documents = [] + for f in uploaded_files: + mime_type = mimetypes.guess_type(f)[0] or "" + if mime_type.startswith("image/"): + new_input_images.append(f) + else: + new_input_documents.append(f) + else: + new_input_images = None + new_input_documents = None + + return pressed_send, new_input, new_input_images, new_input_documents + + +def msg_container_widget(role: str): + return st.div( + className="px-3 py-1 pt-2", + style=dict( + background="rgba(239, 239, 239, 0.6)" + if role == CHATML_ROLE_USER + else "#fff", + ), + ) + def convo_window_clipper( window: list[ConversationEntry], @@ -1047,14 +1348,3 @@ def convo_window_clipper( ): return i + step return 0 - - -def msg_container_widget(role: str): - return st.div( - className="px-3 py-1 pt-2", - style=dict( - background="rgba(239, 239, 239, 0.6)" - if role == CHATML_ROLE_USER - else "#fff", - ), - ) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py new file mode 100644 index 000000000..c94b847f2 --- /dev/null +++ b/recipes/VideoBotsStats.py @@ -0,0 +1,673 @@ +from daras_ai_v2.base import BasePage, MenuTabs +import gooey_ui as st +from furl import furl + +from app_users.models import AppUser + +from bots.models import ( + Workflow, + 
Platform, + BotIntegration, + Conversation, + Message, + Feedback, + ConversationQuerySet, + FeedbackQuerySet, + MessageQuerySet, +) +from datetime import datetime, timedelta +import pandas as pd +import plotly.graph_objects as go +import base64 +from daras_ai_v2.language_model import ( + CHATML_ROLE_ASSISTANT, + CHATML_ROLE_USER, +) +from recipes.VideoBots import VideoBotsPage +from django.db.models.functions import ( + TruncMonth, + TruncDay, + TruncWeek, + TruncYear, + Concat, +) +from django.db.models import Count + +ID_COLUMNS = [ + "conversation__fb_page_id", + "conversation__ig_account_id", + "conversation__wa_phone_number", + "conversation__slack_user_id", +] + + +class VideoBotsStatsPage(BasePage): + title = "Copilot Analytics" # "Create Interactive Video Bots" + slug_versions = ["analytics", "stats"] + workflow = ( + Workflow.VIDEO_BOTS + ) # this is a hidden page, so this isn't used but type checking requires a workflow + + def _get_current_app_url(self): + # this is overwritten to include the query params in the copied url for the share button + args = dict(self.request.query_params) + return furl(self.app_url(), args=args).tostr() + + def show_title_breadcrumb_share(self, run_title, run_url, bi): + with st.div(className="d-flex justify-content-between mt-4"): + with st.div(className="d-lg-flex d-block align-items-center"): + with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): + with st.breadcrumbs(): + st.breadcrumb_item( + VideoBotsPage.title, + link_to=VideoBotsPage.app_url(), + className="text-muted", + ) + st.breadcrumb_item( + run_title, + link_to=run_url, + className="text-muted", + ) + st.breadcrumb_item( + "Integrations", + link_to=VideoBotsPage().get_tab_url(MenuTabs.integrations), + ) + + author = ( + AppUser.objects.filter(uid=bi.billing_account_uid).first() + or self.request.user + ) + self.render_author( + author, + show_as_link=self.is_current_user_admin(), + ) + + with st.div(className="d-flex align-items-center"): + with 
st.div(className="d-flex align-items-start right-action-icons"): + self._render_social_buttons(show_button_text=True) + + st.markdown(f"# ๐Ÿ“Š {bi.name} Analytics") + + def render(self): + self.setup_render() + + if not self.request.user or self.request.user.is_anonymous: + st.write("**Please Login to view stats for your bot integrations**") + return + if not self.request.user.is_paying and not self.is_current_user_admin(): + st.write( + "**Please upgrade to a paid plan to view stats for your bot integrations**" + ) + return + if self.is_current_user_admin(): + bot_integrations = BotIntegration.objects.all().order_by( + "platform", "-created_at" + ) + else: + bot_integrations = BotIntegration.objects.filter( + billing_account_uid=self.request.user.uid + ).order_by("platform", "-created_at") + + if bot_integrations.count() == 0: + st.write( + "**Please connect a bot to a platform to view stats for your bot integrations or login to an account with connected bot integrations**" + ) + return + + allowed_bids = [bi.id for bi in bot_integrations] + bid = self.request.query_params.get("bi_id", allowed_bids[0]) + if int(bid) not in allowed_bids: + bid = allowed_bids[0] + bi = BotIntegration.objects.get(id=bid) + run_title, run_url = self.parse_run_info(bi) + + self.show_title_breadcrumb_share(run_title, run_url, bi) + + col1, col2 = st.columns([1, 2]) + + with col1: + conversations, messages = self.calculate_overall_stats( + bid, bi, run_title, run_url + ) + + ( + start_date, + end_date, + view, + factor, + trunc_fn, + ) = self.render_date_view_inputs() + + df = self.calculate_stats_binned_by_time( + bi, start_date, end_date, factor, trunc_fn + ) + + if df.empty or "date" not in df.columns: + st.write("No data to show yet.") + return + + with col2: + self.plot_graphs(view, df) + + st.write("---") + st.session_state.setdefault("details", self.request.query_params.get("details")) + details = st.horizontal_radio( + "### Details", + options=[ + "All Conversations", + "All 
Messages", + "Feedback Positive", + "Feedback Negative", + "Answered Successfully", + "Answered Unsuccessfully", + ], + key="details", + ) + + if details == "All Conversations": + options = [ + "Messages", + "Correct Answers", + "Thumbs up", + "Name", + ] + elif details == "Feedback Positive": + options = [ + "Name", + "Rating", + "Question Answered", + ] + elif details == "Feedback Negative": + options = [ + "Name", + "Rating", + "Question Answered", + ] + elif details == "Answered Successfully": + options = [ + "Name", + "Sent", + ] + elif details == "Answered Unsuccessfully": + options = [ + "Name", + "Sent", + ] + else: + options = [] + + sort_by = None + if options: + query_sort_by = self.request.query_params.get("sort_by") + st.session_state.setdefault( + "sort_by", query_sort_by if query_sort_by in options else options[0] + ) + st.selectbox( + "Sort by", + options=options, + key="sort_by", + ) + sort_by = st.session_state["sort_by"] + + df = self.get_tabular_data( + bi, run_url, conversations, messages, details, sort_by, rows=500 + ) + + if not df.empty: + columns = df.columns.tolist() + st.data_table( + [columns] + + [ + [ + dict( + readonly=True, + displayData=str(df.iloc[idx, col]), + data=str(df.iloc[idx, col]), + ) + for col in range(len(columns)) + ] + for idx in range(min(500, len(df))) + ] + ) + # download as csv button + st.html("
") + if st.checkbox("Export"): + df = self.get_tabular_data( + bi, run_url, conversations, messages, details, sort_by + ) + csv = df.to_csv() + b64 = base64.b64encode(csv.encode()).decode() + st.html( + f'Download CSV File' + ) + st.caption("Includes full data (UI only shows first 500 rows)") + else: + st.write("No data to show yet.") + + # we store important inputs in the url so the user can return to the same view (e.g. bookmark it) + # this also allows them to share the url (once organizations are supported) + # and allows us to share a url to a specific view with users + new_url = furl( + self.app_url(), + args={ + "bi_id": bid, + "view": view, + "details": details, + "start_date": start_date.strftime("%Y-%m-%d"), + "end_date": end_date.strftime("%Y-%m-%d"), + "sort_by": sort_by, + }, + ).tostr() + st.change_url(new_url, self.request) + + def render_date_view_inputs(self): + start_of_year_date = datetime.now().replace(month=1, day=1) + st.session_state.setdefault( + "start_date", + self.request.query_params.get( + "start_date", start_of_year_date.strftime("%Y-%m-%d") + ), + ) + start_date: datetime = ( + st.date_input("Start date", key="start_date") or start_of_year_date + ) + st.session_state.setdefault( + "end_date", + self.request.query_params.get( + "end_date", datetime.now().strftime("%Y-%m-%d") + ), + ) + end_date: datetime = st.date_input("End date", key="end_date") or datetime.now() + st.session_state.setdefault( + "view", self.request.query_params.get("view", "Weekly") + ) + view = st.horizontal_radio( + "### View", + options=["Daily", "Weekly", "Monthly"], + key="view", + label_visibility="collapsed", + ) + factor = 1 + if view == "Weekly": + trunc_fn = TruncWeek + elif view == "Daily": + if end_date - start_date > timedelta(days=31): + st.write( + "**Note: Date ranges greater than 31 days show weekly averages in daily view**" + ) + factor = 1.0 / 7.0 + trunc_fn = TruncWeek + else: + trunc_fn = TruncDay + elif view == "Monthly": + trunc_fn = 
TruncMonth + else: + trunc_fn = TruncYear + return start_date, end_date, view, factor, trunc_fn + + def parse_run_info(self, bi): + saved_run = bi.get_active_saved_run() + run_title = ( + bi.published_run.title + if bi.published_run + else saved_run.page_title + if saved_run and saved_run.page_title + else "This Copilot Run" + if saved_run + else "No Run Connected" + ) + run_url = furl(saved_run.get_app_url()).tostr() if saved_run else "" + return run_title, run_url + + def calculate_overall_stats(self, bid, bi, run_title, run_url): + conversations: ConversationQuerySet = Conversation.objects.filter( + bot_integration__id=bid + ).order_by() # type: ignore + # due to things like personal convos for slack, each user can have multiple conversations + users = conversations.get_unique_users().order_by() + messages: MessageQuerySet = Message.objects.filter(conversation__in=conversations).order_by() # type: ignore + user_messages = messages.filter(role=CHATML_ROLE_USER).order_by() + bot_messages = messages.filter(role=CHATML_ROLE_ASSISTANT).order_by() + num_active_users_last_7_days = ( + user_messages.filter( + conversation__in=users, + created_at__gte=datetime.now() - timedelta(days=7), + ) + .distinct( + *ID_COLUMNS, + ) + .count() + ) + num_active_users_last_30_days = ( + user_messages.filter( + conversation__in=users, + created_at__gte=datetime.now() - timedelta(days=30), + ) + .distinct( + *ID_COLUMNS, + ) + .count() + ) + positive_feedbacks = Feedback.objects.filter( + message__conversation__bot_integration=bi, + rating=Feedback.Rating.RATING_THUMBS_UP, + ).count() + negative_feedbacks = Feedback.objects.filter( + message__conversation__bot_integration=bi, + rating=Feedback.Rating.RATING_THUMBS_DOWN, + ).count() + run_link = f'Powered By: {run_title}' + connection_detail = ( + bi.fb_page_name + or bi.wa_phone_number + or bi.ig_username + or (bi.slack_team_name + " - " + bi.slack_channel_name) + ) + st.markdown( + f""" + - Platform: 
{Platform(bi.platform).name.capitalize()} + - Created on: {bi.created_at.strftime("%b %d, %Y")} + - Last Updated: {bi.updated_at.strftime("%b %d, %Y")} + - {run_link} + - Connected to: {connection_detail} + * {users.count()} Users + * {num_active_users_last_7_days} Active Users (Last 7 Days) + * {num_active_users_last_30_days} Active Users (Last 30 Days) + * {conversations.count()} Conversations + * {user_messages.count()} User Messages + * {bot_messages.count()} Bot Messages + * {messages.count()} Total Messages + * {positive_feedbacks} Positive Feedbacks + * {negative_feedbacks} Negative Feedbacks + """, + unsafe_allow_html=True, + ) + + return conversations, messages + + def calculate_stats_binned_by_time( + self, bi, start_date, end_date, factor, trunc_fn + ): + messages_received = ( + Message.objects.filter( + created_at__date__gte=start_date, + created_at__date__lte=end_date, + conversation__bot_integration=bi, + role=CHATML_ROLE_USER, + ) + .order_by() + .annotate(date=trunc_fn("created_at")) + .values("date") + .annotate(Messages_Sent=Count("id")) + .annotate(Convos=Count("conversation_id", distinct=True)) + .annotate( + Senders=Count( + Concat( + *ID_COLUMNS, + ), + distinct=True, + ) + ) + .annotate(Unique_feedback_givers=Count("feedbacks", distinct=True)) + .values( + "date", + "Messages_Sent", + "Convos", + "Senders", + "Unique_feedback_givers", + ) + ) + + positive_feedbacks = ( + Feedback.objects.filter( + created_at__date__gte=start_date, + created_at__date__lte=end_date, + message__conversation__bot_integration=bi, + rating=Feedback.Rating.RATING_THUMBS_UP, + ) + .order_by() + .annotate(date=trunc_fn("created_at")) + .values("date") + .annotate(Pos_feedback=Count("id")) + .values("date", "Pos_feedback") + ) + + negative_feedbacks = ( + Feedback.objects.filter( + created_at__date__gte=start_date, + created_at__date__lte=end_date, + message__conversation__bot_integration=bi, + rating=Feedback.Rating.RATING_THUMBS_DOWN, + ) + .order_by() + 
.annotate(date=trunc_fn("created_at")) + .values("date") + .annotate(Neg_feedback=Count("id")) + .values("date", "Neg_feedback") + ) + + df = pd.DataFrame( + messages_received, + columns=[ + "date", + "Messages_Sent", + "Convos", + "Senders", + "Unique_feedback_givers", + ], + ) + df = df.merge( + pd.DataFrame(positive_feedbacks, columns=["date", "Pos_feedback"]), + how="outer", + left_on="date", + right_on="date", + ) + df = df.merge( + pd.DataFrame(negative_feedbacks, columns=["date", "Neg_feedback"]), + how="outer", + left_on="date", + right_on="date", + ) + df["Messages_Sent"] = df["Messages_Sent"] * factor + df["Convos"] = df["Convos"] * factor + df["Senders"] = df["Senders"] * factor + df["Unique_feedback_givers"] = df["Unique_feedback_givers"] * factor + df["Pos_feedback"] = df["Pos_feedback"] * factor + df["Neg_feedback"] = df["Neg_feedback"] * factor + df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] + df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] + df.fillna(0, inplace=True) + df = df.round(0).astype("int32", errors="ignore") + return df + + def plot_graphs(self, view, df): + fig = go.Figure( + data=[ + go.Bar( + x=list(df["date"]), + y=list(df["Messages_Sent"]), + text=list(df["Messages_Sent"]), + texttemplate="%{text}", + insidetextanchor="middle", + insidetextfont=dict(size=24), + hovertemplate="Messages Sent: %{y:.0f}", + ), + ], + layout=dict( + margin=dict(l=0, r=0, t=28, b=0), + yaxis=dict( + title="User Messages Sent", + range=[ + 0, + df["Messages_Sent"].max() + 10, + ], + tickvals=[ + *range( + int(df["Messages_Sent"].max() / 10), + int(df["Messages_Sent"].max()) + 1, + int(df["Messages_Sent"].max() / 10) + 1, + ) + ], + ), + title=dict( + text=f"{view} Messages Sent", + ), + height=300, + template="plotly_white", + ), + ) + st.plotly_chart(fig) + st.write("---") + fig = go.Figure( + data=[ + go.Scatter( + name="Senders", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Senders"]), + text=list(df["Senders"]), + 
hovertemplate="Active Users: %{y:.0f}", + ), + go.Scatter( + name="Conversations", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Convos"]), + text=list(df["Convos"]), + hovertemplate="Conversations: %{y:.0f}", + ), + go.Scatter( + name="Feedback Givers", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Unique_feedback_givers"]), + text=list(df["Unique_feedback_givers"]), + hovertemplate="Feedback Givers: %{y:.0f}", + ), + go.Scatter( + name="Positive Feedbacks", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Pos_feedback"]), + text=list(df["Pos_feedback"]), + hovertemplate="Positive Feedbacks: %{y:.0f}", + ), + go.Scatter( + name="Negative Feedbacks", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Neg_feedback"]), + text=list(df["Neg_feedback"]), + hovertemplate="Negative Feedbacks: %{y:.0f}", + ), + go.Scatter( + name="Messages per Convo", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Msgs_per_convo"]), + text=list(df["Msgs_per_convo"]), + hovertemplate="Messages per Convo: %{y:.0f}", + ), + go.Scatter( + name="Messages per Sender", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Msgs_per_user"]), + text=list(df["Msgs_per_user"]), + hovertemplate="Messages per User: %{y:.0f}", + ), + ], + layout=dict( + margin=dict(l=0, r=0, t=28, b=0), + yaxis=dict( + title="Unique Count", + ), + title=dict( + text=f"{view} Usage Trends", + ), + height=300, + template="plotly_white", + ), + ) + max_value = ( + df[ + [ + "Senders", + "Convos", + "Unique_feedback_givers", + "Pos_feedback", + "Neg_feedback", + "Msgs_per_convo", + "Msgs_per_user", + ] + ] + .max() + .max() + ) + if max_value < 10: + # only set fixed axis scale if the data doesn't have enough values to make it look good + # otherwise default behaviour is better since it adapts when user deselects a line + fig.update_yaxes( + range=[ + -0.2, + max_value + 10, + ], + tickvals=[ + *range( + int(max_value / 10), + int(max_value) + 10, + int(max_value / 
10) + 1, + ) + ], + ) + st.plotly_chart(fig) + + def get_tabular_data( + self, bi, run_url, conversations, messages, details, sort_by, rows=10000 + ): + df = pd.DataFrame() + if details == "All Conversations": + df = conversations.to_df_format(row_limit=rows) + elif details == "All Messages": + df = messages.order_by("-created_at", "conversation__id").to_df_format( + row_limit=rows + ) + df = df.sort_values(by=["Name", "Sent"], ascending=False).reset_index() + df.drop(columns=["index"], inplace=True) + elif details == "Feedback Positive": + pos_feedbacks: FeedbackQuerySet = Feedback.objects.filter( + message__conversation__bot_integration=bi, + rating=Feedback.Rating.RATING_THUMBS_UP, + ) # type: ignore + df = pos_feedbacks.to_df_format(row_limit=rows) + df["Run URL"] = run_url + df["Bot"] = bi.name + elif details == "Feedback Negative": + neg_feedbacks: FeedbackQuerySet = Feedback.objects.filter( + message__conversation__bot_integration=bi, + rating=Feedback.Rating.RATING_THUMBS_DOWN, + ) # type: ignore + df = neg_feedbacks.to_df_format(row_limit=rows) + df["Run URL"] = run_url + df["Bot"] = bi.name + elif details == "Answered Successfully": + successful_messages: MessageQuerySet = Message.objects.filter( + conversation__bot_integration=bi, + analysis_result__contains={"Answered": True}, + ) # type: ignore + df = successful_messages.to_df_analysis_format(row_limit=rows) + df["Run URL"] = run_url + df["Bot"] = bi.name + elif details == "Answered Unsuccessfully": + unsuccessful_messages: MessageQuerySet = Message.objects.filter( + conversation__bot_integration=bi, + analysis_result__contains={"Answered": False}, + ) # type: ignore + df = unsuccessful_messages.to_df_analysis_format(row_limit=rows) + df["Run URL"] = run_url + df["Bot"] = bi.name + + if sort_by: + df.sort_values(by=[sort_by], ascending=False, inplace=True) + + return df diff --git a/recipes/asr.py b/recipes/asr.py index 172f57a7c..a297bf9c3 100644 --- a/recipes/asr.py +++ b/recipes/asr.py @@ -25,11 
+25,12 @@ from daras_ai_v2.text_output_widget import text_outputs from recipes.DocSearch import render_documents -DEFAULT_ASR_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/3b98d906-538b-11ee-9c77-02420a000193/Speech1%201.png.png" +DEFAULT_ASR_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1916825c-93fa-11ee-97be-02420a0001c8/Speech.jpg.png" class AsrPage(BasePage): title = "Speech Recognition & Translation" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5fb7e5f6-88d9-11ee-aa86-02420a000165/Speech.png.png" workflow = Workflow.ASR slug_versions = ["asr", "speech"] diff --git a/recipes/embeddings_page.py b/recipes/embeddings_page.py index f2474f7ea..76efe16d9 100644 --- a/recipes/embeddings_page.py +++ b/recipes/embeddings_page.py @@ -39,6 +39,7 @@ class EmbeddingModels(models.TextChoices): class EmbeddingsPage(BasePage): title = "Embeddings" + explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = Workflow.EMBEDDINGS slug_versions = ["embeddings", "embed", "text-embedings"] price = 1 diff --git a/routers/api.py b/routers/api.py index 245cc048b..e58ae3b2b 100644 --- a/routers/api.py +++ b/routers/api.py @@ -121,7 +121,7 @@ def run_api_json( return call_api( page_cls=page_cls, user=user, - request_body=page_request.dict(), + request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), ) @@ -175,7 +175,7 @@ def run_api_json_async( ret = call_api( page_cls=page_cls, user=user, - request_body=page_request.dict(), + request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), run_async=True, ) @@ -332,14 +332,11 @@ def submit_api_call( state = self.get_sr_from_query_params_dict(query_params).to_dict() if state is None: raise HTTPException(status_code=404) - # set sane 
defaults for k, v in self.sane_defaults.items(): state.setdefault(k, v) - - # remove None values & insert request data - request_dict = {k: v for k, v in request_body.items() if v is not None} - state.update(request_dict) + # insert request data + state.update(request_body) # set streamlit session state st.set_session_state(state) @@ -370,7 +367,7 @@ def build_api_response( run_async: bool, created_at: str, ): - web_url = str(furl(self.app_url(run_id=run_id, uid=uid))) + web_url = self.app_url(run_id=run_id, uid=uid) if run_async: status_url = str( furl(settings.API_BASE_URL, query_params=dict(run_id=run_id)) diff --git a/routers/billing.py b/routers/billing.py index 87403b673..5f2f2cb06 100644 --- a/routers/billing.py +++ b/routers/billing.py @@ -7,7 +7,7 @@ from furl import furl from starlette.datastructures import FormData -from app_users.models import AppUser +from app_users.models import AppUser, PaymentProvider from daras_ai_v2 import settings from daras_ai_v2.settings import templates @@ -239,9 +239,13 @@ def _handle_invoice_paid(uid: str, invoice_data): "get", "/v1/invoices/{invoice}/lines".format(invoice=quote_plus(invoice_id)), ) - amount = line_items.data[0].quantity user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance(amount, invoice_id) + user.add_balance( + payment_provider=PaymentProvider.STRIPE, + invoice_id=invoice_id, + amount=line_items.data[0].quantity, + charged_amount=line_items.data[0].amount, + ) if not user.is_paying: user.is_paying = True user.save(update_fields=["is_paying"]) diff --git a/routers/broadcast_api.py b/routers/broadcast_api.py new file mode 100644 index 000000000..30fa9e0ca --- /dev/null +++ b/routers/broadcast_api.py @@ -0,0 +1,100 @@ +import typing + +from django.db.models import Q +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from pydantic import BaseModel, Field + +from app_users.models import AppUser +from auth.token_authentication import 
api_auth_header +from bots.models import BotIntegration +from bots.tasks import send_broadcast_msgs_chunked +from recipes.VideoBots import ReplyButton, VideoBotsPage + +app = APIRouter() + + +class BotBroadcastFilters(BaseModel): + wa_phone_number__in: list[str] | None = Field( + description="A list of WhatsApp phone numbers to broadcast to." + ) + slack_user_id__in: list[str] | None = Field( + description="A list of Slack user IDs to broadcast to." + ) + slack_user_name__icontains: str | None = Field( + description="Filter by the Slack user's name. Case insensitive." + ) + slack_channel_is_personal: bool | None = Field( + description="Filter by whether the Slack channel is personal. By default, will broadcast to both public and personal slack channels." + ) + + +class BotBroadcastRequestModel(BaseModel): + text: str = Field(description="Message to broadcast to all users") + audio: str | None = Field(description="Audio URL to send to all users") + video: str | None = Field(description="Video URL to send to all users") + documents: list[str] | None = Field(description="Video URL to send to all users") + buttons: list[ReplyButton] | None = Field( + description="Buttons to send to all users" + ) + filters: BotBroadcastFilters | None = Field( + description="Filters to select users to broadcast to. If not provided, will broadcast to all users of this bot." 
+ ) + + +R = typing.TypeVar("R", bound=BotBroadcastRequestModel) + + +@app.post( + f"/v2/{VideoBotsPage.slug_versions[0]}/broadcast/send/", + operation_id=VideoBotsPage.slug_versions[0] + "__broadcast", + name=f"Send Broadcast Message", +) +@app.post( + f"/v2/{VideoBotsPage.slug_versions[0]}/broadcast/send", + include_in_schema=False, +) +def broadcast_api_json( + bot_request: BotBroadcastRequestModel, + user: AppUser = Depends(api_auth_header), + example_id: str | None = None, + run_id: str | None = None, +): + bi_qs = BotIntegration.objects.filter(billing_account_uid=user.uid) + if example_id: + bi_qs = bi_qs.filter( + Q(published_run__published_run_id=example_id) + | Q(saved_run__example_id=example_id) + ) + elif run_id: + bi_qs = bi_qs.filter(saved_run__run_id=run_id, saved_run__uid=user.uid) + else: + return HTTPException( + status_code=400, + detail="Must provide either example_id or run_id as a query parameter.", + ) + if not bi_qs.exists(): + return HTTPException( + status_code=404, + detail=f"Could not find a bot in your account with the given {example_id=} & {run_id=}. 
" + "Please use the same account that you used to create the bot.", + ) + + total = 0 + for bi in bi_qs: + convo_qs = bi.conversations.all() + if bot_request.filters: + convo_qs = convo_qs.filter(**bot_request.filters.dict(exclude_unset=True)) + total += convo_qs.count() + send_broadcast_msgs_chunked( + text=bot_request.text, + audio=bot_request.audio, + video=bot_request.video, + documents=bot_request.documents, + buttons=bot_request.buttons, + bi=bi, + convo_qs=convo_qs, + ) + + return {"status": "success", "count": total} diff --git a/routers/facebook.py b/routers/facebook_api.py similarity index 97% rename from routers/facebook.py rename to routers/facebook_api.py index ca98d65a9..af409c8fd 100644 --- a/routers/facebook.py +++ b/routers/facebook_api.py @@ -9,9 +9,10 @@ from bots.models import BotIntegration from daras_ai_v2 import settings, db +from daras_ai_v2.bots import _on_msg, request_json +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.facebook_bots import WhatsappBot, FacebookBot from daras_ai_v2.functional import map_parallel -from daras_ai_v2.bots import _on_msg, request_json router = APIRouter() @@ -61,7 +62,7 @@ def get_currently_connected_fb_pages(user_access_token): "fields": "id,name,access_token,instagram_business_account{id,username}", }, ) - r.raise_for_status() + raise_for_status(r) fb_pages = r.json()["data"] return fb_pages @@ -171,7 +172,7 @@ def _get_access_token_from_code(code: str) -> str: "code": code, }, ) - r.raise_for_status() + raise_for_status(r) return r.json()["access_token"] @@ -185,4 +186,4 @@ def _subscribe_to_page(fb_page: dict): "access_token": fb_page["access_token"], }, ) - r.raise_for_status() + raise_for_status(r) diff --git a/routers/paypal.py b/routers/paypal.py new file mode 100644 index 000000000..5e6ebb94e --- /dev/null +++ b/routers/paypal.py @@ -0,0 +1,137 @@ +import base64 + +import requests +from fastapi import APIRouter, Depends +from fastapi.requests import Request +from fastapi.responses 
import JSONResponse +from furl import furl + +from app_users.models import AppUser, PaymentProvider +from daras_ai_v2 import settings +from daras_ai_v2.bots import request_json +from daras_ai_v2.exceptions import raise_for_status +from routers.billing import available_subscriptions + +router = APIRouter() + + +def generate_auth_header() -> str: + """ + Generate an OAuth 2.0 access token for authenticating with PayPal REST APIs. + @see https://developer.paypal.com/api/rest/authentication/ + """ + assert ( + settings.PAYPAL_CLIENT_ID and settings.PAYPAL_SECRET + ), "Missing API Credentials" + auth = base64.b64encode( + (settings.PAYPAL_CLIENT_ID + ":" + settings.PAYPAL_SECRET).encode() + ).decode() + response = requests.post( + str(furl(settings.PAYPAL_BASE) / "v1/oauth2/token"), + data="grant_type=client_credentials", + headers={"Authorization": f"Basic {auth}"}, + ) + raise_for_status(response) + data = response.json() + access_token = data.get("access_token") + assert access_token, "Missing access token in response" + return f"Bearer " + access_token + + +# Create an order to start the transaction. +# @see https://developer.paypal.com/docs/api/orders/v2/#orders_create +@router.post("/__/paypal/orders/create/") +def create_order(request: Request, payload=Depends(request_json)): + if not request.user or request.user.is_anonymous: + return JSONResponse({}, status_code=401) + + lookup_key = "addon" + quantity = payload["quantity"] + unit_amount = ( + available_subscriptions[lookup_key]["stripe"]["price_data"]["unit_amount"] / 100 + ) + value = int(quantity * unit_amount) + + response = requests.post( + str(furl(settings.PAYPAL_BASE) / "v2/checkout/orders"), + headers={ + "Content-Type": "application/json", + "Authorization": generate_auth_header(), + # Uncomment one of these to force an error for negative testing (in sandbox mode only). 
Documentation: + # https://developer.paypal.com/tools/sandbox/negative-testing/request-headers/ + # "PayPal-Mock-Response": '{"mock_application_codes": "MISSING_REQUIRED_PARAMETER"}' + # "PayPal-Mock-Response": '{"mock_application_codes": "PERMISSION_DENIED"}' + # "PayPal-Mock-Response": '{"mock_application_codes": "INTERNAL_SERVER_ERROR"}' + }, + json={ + "intent": "CAPTURE", + "purchase_units": [ + { + "amount": { + "currency_code": "USD", + "value": str(value), + "breakdown": { + "item_total": { + "currency_code": "USD", + "value": str(value), + } + }, + }, + "items": [ + { + "name": "Top up Credits", + "quantity": str(quantity), + "unit_amount": { + "currency_code": "USD", + "value": str(unit_amount), + }, + } + ], + "custom_id": request.user.uid, + }, + ], + }, + ) + raise_for_status(response) + return JSONResponse(response.json(), response.status_code) + + +# Capture payment for the created order to complete the transaction. +# @see https://developer.paypal.com/docs/api/orders/v2/#orders_capture +@router.post("/__/paypal/orders/{order_id}/capture/") +def capture_order(order_id: str): + response = requests.post( + str(furl(settings.PAYPAL_BASE) / f"v2/checkout/orders/{order_id}/capture"), + headers={ + "Content-Type": "application/json", + "Authorization": generate_auth_header(), + # Uncomment one of these to force an error for negative testing (in sandbox mode only). 
Documentation: + # https://developer.paypal.com/tools/sandbox/negative-testing/request-headers/ + # "PayPal-Mock-Response": '{"mock_application_codes": "INSTRUMENT_DECLINED"}' + # "PayPal-Mock-Response": '{"mock_application_codes": "TRANSACTION_REFUSED"}' + # "PayPal-Mock-Response": '{"mock_application_codes": "INTERNAL_SERVER_ERROR"}' + }, + ) + _handle_invoice_paid(order_id) + return JSONResponse(response.json(), response.status_code) + + +def _handle_invoice_paid(order_id: str): + response = requests.get( + str(furl(settings.PAYPAL_BASE) / f"v2/checkout/orders/{order_id}"), + headers={"Authorization": generate_auth_header()}, + ) + raise_for_status(response) + order = response.json() + purchase_unit = order["purchase_units"][0] + uid = purchase_unit["payments"]["captures"][0]["custom_id"] + user = AppUser.objects.get_or_create_from_uid(uid)[0] + user.add_balance( + payment_provider=PaymentProvider.PAYPAL, + invoice_id=order_id, + amount=int(purchase_unit["items"][0]["quantity"]), + charged_amount=int(float(purchase_unit["amount"]["value"]) * 100), + ) + if not user.is_paying: + user.is_paying = True + user.save(update_fields=["is_paying"]) diff --git a/routers/root.py b/routers/root.py index e28947be9..e136840a9 100644 --- a/routers/root.py +++ b/routers/root.py @@ -21,20 +21,25 @@ import gooey_ui as st from app_users.models import AppUser +from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes, safe_filename from daras_ai_v2 import settings from daras_ai_v2.all_pages import all_api_pages, normalize_slug, page_slug_map +from daras_ai_v2.api_examples_widget import api_example_generator from daras_ai_v2.asr import FFMPEG_WAV_ARGS, check_wav_audio_format from daras_ai_v2.base import ( RedirectException, + get_example_request_body, ) from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_scripts from daras_ai_v2.db import FIREBASE_SESSION_COOKIE from daras_ai_v2.manage_api_keys_widget import manage_api_keys from 
daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags +from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.secrets_widget import secrets_widget from daras_ai_v2.settings import templates +from daras_ai_v2.tabs_widget import MenuTabs from routers.api import request_form_files app = APIRouter() @@ -195,7 +200,7 @@ def file_upload(request: Request, form_data: FormData = Depends(request_form_fil if content_type.startswith("image/"): with Image(blob=data) as img: - if img.format not in ["png", "jpeg", "jpg", "gif"]: + if img.format.lower() not in ["png", "jpeg", "jpg", "gif"]: img.format = "png" content_type = "image/png" filename += ".png" @@ -227,48 +232,141 @@ def explore_page(request: Request, json_data: dict = Depends(request_json)): return ret -@app.post("/api-keys/") -def api_keys(request: Request, json_data=Depends(request_json)): - def render_fn(): - st.write("---") - st.write("### ๐Ÿ” API keys", className="mt-2") - manage_api_keys(request.user) +@app.post("/api/") +def api_docs_page(request: Request, json_data: dict = Depends(request_json)): + ret = st.runner( + lambda: page_wrapper( + request=request, render_fn=lambda: _api_docs_page(request) + ), + **json_data, + ) + ret |= { + "meta": raw_build_meta_tags( + url=get_og_url_path(request), + title="Gooey.AI API Platform", + description="Explore resources, tutorials, API docs, and dynamic examples to get the most out of GooeyAI's developer platform.", + image=meta_preview_url( + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/e48d59be-aaee-11ee-b112-02420a000175/API%20Docs.png.png" + ), + ), + } + return ret - st.write("### ๐Ÿ›๏ธ Non-Gooey API keys", className="mt-2") - secrets_widget(request.user) - json_data.setdefault("state", {}) - return st.runner( - lambda: page_wrapper(request=request, render_fn=render_fn), - **json_data, +def _api_docs_page(request): + api_docs_url = 
str(furl(settings.API_BASE_URL) / "docs") + + st.markdown( + f""" +# Gooey.AI API Platform + +##### ๐Ÿ“– Introduction +You can interact with the API through HTTP requests from any language. + +If you're comfortable with OpenAPI specs, jump straight to our complete API + +##### ๐Ÿ” Authentication +The Gooey.AI API uses API keys for authentication. Visit the [API Keys](#api-keys) section to retrieve the API key you'll use in your requests. + +Remember that your API key is a secret! Do not share it with others or expose it in any client-side code (browsers, apps). Production requests must be routed through your own backend server where your API key can be securely loaded from an environment variable or key management service. + +All API requests should include your API key in an Authorization HTTP header as follows: + +```bash +Authorization: Bearer GOOEY_API_KEY +``` + """, + unsafe_allow_html=True, ) + st.write("---") + options = { + page_cls.workflow.value: page_cls().get_recipe_title() + for page_cls in all_api_pages + } + workflow = Workflow( + st.selectbox( + "##### โš• API Generator\nChoose a workflow to see how you can interact with it via the API", + options=options, + format_func=lambda x: options[x], + ) + ) + + st.write("###### ๐Ÿ“ค Example Request") + + include_all = st.checkbox("Show all fields") + as_async = st.checkbox("Run Async") + as_form_data = st.checkbox("Upload Files via Form Data") + + page = workflow.page_cls(request=request) + state = page.get_root_published_run().saved_run.to_dict() + request_body = get_example_request_body( + page.RequestModel, state, include_all=include_all + ) + response_body = page.get_example_response_body( + state, as_async=as_async, include_all=include_all + ) + + api_example_generator( + api_url=page._get_current_api_url(), + request_body=request_body, + as_form_data=as_form_data, + as_async=as_async, + ) + st.write("") + + st.write("###### ๐ŸŽ Example Response") + st.json(response_body, expanded=True) + + 
st.write("---") + with st.tag("a", id="api-keys"): + st.write("##### ๐Ÿ” API keys") + + if not page.request.user or page.request.user.is_anonymous: + st.write( + "**Please [Login](/login/?next=/api/) to generate the `$GOOEY_API_KEY`**" + ) + return + + manage_api_keys(page.request.user) + @app.post("/") @app.post("/{page_slug}/") -@app.post("/{page_slug}/{tab}/") +@app.post("/{page_slug}/{run_slug_or_tab}/") +@app.post("/{page_slug}/{run_slug_or_tab}/{tab}/") def st_page( request: Request, page_slug="", + run_slug_or_tab="", tab="", json_data: dict = Depends(request_json), ): + run_slug, tab = _extract_run_slug_and_tab(run_slug_or_tab, tab) + try: + selected_tab = MenuTabs.paths_reverse[tab] + except KeyError: + raise HTTPException(status_code=404) + try: page_cls = page_slug_map[normalize_slug(page_slug)] except KeyError: raise HTTPException(status_code=404) + # ensure the latest slug is used latest_slug = page_cls.slug_versions[-1] if latest_slug != page_slug: return RedirectResponse( - request.url.replace(path=os.path.join("/", latest_slug, tab, "")) + request.url.replace(path=os.path.join("/", latest_slug, run_slug, tab, "")) ) example_id, run_id, uid = extract_query_params(request.query_params) - page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid)) + page = page_cls( + tab=selected_tab, request=request, run_user=get_run_user(request, uid) + ) - state = json_data.setdefault("state", {}) + state = json_data.get("state", {}) if not state: db_state = page.get_sr_from_query_params(example_id, run_id, uid).to_dict() if db_state is not None: @@ -282,24 +380,11 @@ def st_page( ret = st.runner( lambda: page_wrapper(request, page.render), query_params=dict(request.query_params), - **json_data, + state=state, ) except RedirectException as e: return RedirectResponse(e.url, status_code=e.status_code) - # Canonical URLs should not include uid or run_id (don't index specific runs). 
- # In the case of examples, all tabs other than "Run" are duplicates of the page - # without the `example_id`, and so their canonical shouldn't include `example_id` - canonical_url = str( - furl( - str(settings.APP_BASE_URL), - query_params={"example_id": example_id} if not tab and example_id else {}, - ) - / latest_slug - / tab - / "/" # preserve trailing slash - ) - ret |= { "meta": build_meta_tags( url=get_og_url_path(request), @@ -309,11 +394,6 @@ def st_page( uid=uid, example_id=example_id, ) - + [dict(tagName="link", rel="canonical", href=canonical_url)] - # + [ - # dict(tagName="link", rel="icon", href="/static/favicon.ico"), - # dict(tagName="link", rel="stylesheet", href="/static/css/app.css"), - # ], } return ret @@ -355,3 +435,12 @@ def page_wrapper(request: Request, render_fn: typing.Callable, **kwargs): st.html(templates.get_template("footer.html").render(**context)) st.html(templates.get_template("login_scripts.html").render(**context)) + + +def _extract_run_slug_and_tab(run_slug_or_tab, tab) -> tuple[str, str]: + if run_slug_or_tab and tab: + return run_slug_or_tab, tab + elif run_slug_or_tab in MenuTabs.paths_reverse: + return "", run_slug_or_tab + else: + return run_slug_or_tab, "" diff --git a/routers/slack.py b/routers/slack_api.py similarity index 99% rename from routers/slack.py rename to routers/slack_api.py index 20e19fec6..4186a474e 100644 --- a/routers/slack.py +++ b/routers/slack_api.py @@ -14,6 +14,7 @@ from bots.tasks import create_personal_channels_for_all_members from daras_ai_v2 import settings from daras_ai_v2.bots import _on_msg, request_json, request_urlencoded_body +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.slack_bot import ( SlackBot, invite_bot_account_to_channel, @@ -83,7 +84,7 @@ def slack_connect_redirect(request: Request): ).url, auth=HTTPBasicAuth(settings.SLACK_CLIENT_ID, settings.SLACK_CLIENT_SECRET), ) - r.raise_for_status() + raise_for_status(r) print("> slack_connect_redirect:", r.text) data 
= r.json() @@ -255,7 +256,7 @@ def slack_connect_redirect_shortcuts( ).url, auth=HTTPBasicAuth(settings.SLACK_CLIENT_ID, settings.SLACK_CLIENT_SECRET), ) - res.raise_for_status() + raise_for_status(res) res = res.json() print("> slack_connect_redirect_shortcuts:", res) diff --git a/scripts/create_fixture.py b/scripts/create_fixture.py index 9482cef64..d22ec9c8f 100644 --- a/scripts/create_fixture.py +++ b/scripts/create_fixture.py @@ -2,23 +2,52 @@ from django.core import serializers -from bots.models import SavedRun +from app_users.models import AppUser +from bots.models import BotIntegration, PublishedRun def run(): - qs = SavedRun.objects.filter(run_id__isnull=True) with open("fixture.json", "w") as f: + objs = list(get_objects()) serializers.serialize( "json", - get_objects(qs), + objs, indent=2, stream=f, progress_output=sys.stdout, - object_count=qs.count(), + object_count=len(objs), ) -def get_objects(qs): - for obj in qs: - obj.parent = None +def get_objects(): + for pr in PublishedRun.objects.all(): + set_fk_null(pr.saved_run) + if pr.saved_run_id: + yield pr.saved_run + if pr.created_by_id: + yield pr.created_by + if pr.last_edited_by_id: + yield pr.last_edited_by + yield pr + + for version in pr.versions.all(): + set_fk_null(version.saved_run) + yield version.saved_run + if version.changed_by_id: + yield version.changed_by + yield version + + for obj in BotIntegration.objects.all(): + if not obj.saved_run_id: + continue + set_fk_null(obj.saved_run) + yield obj.saved_run + + yield AppUser.objects.get(uid=obj.billing_account_uid) yield obj + + +def set_fk_null(obj): + for field in obj._meta.get_fields(): + if field.is_relation and field.many_to_one: + setattr(obj, field.name, None) diff --git a/scripts/run_all_diffusion.py b/scripts/run_all_diffusion.py index eabb9b1f5..183260c5d 100644 --- a/scripts/run_all_diffusion.py +++ b/scripts/run_all_diffusion.py @@ -43,7 +43,7 @@ # guidance_scale=7, # ) # # r = requests.get(GpuEndpoints.sd_multi / "magic") -# # 
r.raise_for_status() +# # raise_for_status(r) # # img2img( # # selected_model=Img2ImgModels.sd_1_5.name, # # prompt=get_random_string(100, string.ascii_letters), @@ -108,7 +108,7 @@ selected_controlnet_model=controlnet_model.name, prompt=get_random_string(100, string.ascii_letters), num_outputs=4, - init_image=random_img, + init_images=random_img, num_inference_steps=1, guidance_scale=7, ), diff --git a/scripts/test_wa_msg_send.py b/scripts/test_wa_msg_send.py new file mode 100644 index 000000000..c31fca0af --- /dev/null +++ b/scripts/test_wa_msg_send.py @@ -0,0 +1,124 @@ +from time import sleep + +from daras_ai_v2.bots import _feedback_start_buttons +from daras_ai_v2.facebook_bots import WhatsappBot + + +def run(bot_number: str, user_number: str): + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="", + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Text With Buttons", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Audio with text", + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + 
bot_number=bot_number, + user_number=user_number, + text="Audio + Video", + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Video with text", + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Some Docs", + documents=[ + 
"https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Some Docs", + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Audio + Docs", + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + buttons=_feedback_start_buttons(), + ) diff --git a/server.py b/server.py index 414f55ff6..2d66bf7d7 100644 --- a/server.py +++ b/server.py @@ -25,7 +25,7 @@ SessionAuthBackend, ) from daras_ai_v2 import settings -from routers import billing, facebook, api, root, slack +from routers import billing, facebook_api, api, root, slack_api, paypal, broadcast_api import url_shortener.routers as url_shortener app = FastAPI(title="GOOEY.AI", docs_url=None, redoc_url="/docs") @@ -33,11 +33,13 @@ app.mount("/static", StaticFiles(directory="static"), name="static") app.include_router(api.app) +app.include_router(broadcast_api.app) app.include_router(billing.router, include_in_schema=False) -app.include_router(facebook.router, include_in_schema=False) -app.include_router(slack.router, include_in_schema=False) +app.include_router(facebook_api.router, include_in_schema=False) +app.include_router(slack_api.router, include_in_schema=False) app.include_router(root.app, include_in_schema=False) app.include_router(url_shortener.app, 
include_in_schema=False) +app.include_router(paypal.router, include_in_schema=False) app.add_middleware( CORSMiddleware, diff --git a/templates/account.html b/templates/account.html index f06ba7e53..2755facc1 100644 --- a/templates/account.html +++ b/templates/account.html @@ -2,6 +2,8 @@ {% block content %} +
-
\ No newline at end of file + diff --git a/templates/login_options.html b/templates/login_options.html index bc137935a..c287f13f0 100644 --- a/templates/login_options.html +++ b/templates/login_options.html @@ -13,7 +13,7 @@

Sign in to Gooey.AI

Loading...

-

💰 Sign up via Email does not activate credits. Please choose one of the top options instead.

+

💰 Sign up via non-Gmail accounts does not activate credits.

diff --git a/tests/test_public_endpoints.py b/tests/test_public_endpoints.py index 02006b92f..be18647d8 100644 --- a/tests/test_public_endpoints.py +++ b/tests/test_public_endpoints.py @@ -5,14 +5,14 @@ from bots.models import SavedRun from daras_ai_v2.all_pages import all_api_pages from daras_ai_v2.tabs_widget import MenuTabs -from routers import facebook -from routers.slack import slack_connect_redirect_shortcuts, slack_connect_redirect +from routers import facebook_api +from routers.slack_api import slack_connect_redirect_shortcuts, slack_connect_redirect from server import app client = TestClient(app) excluded_endpoints = [ - facebook.fb_webhook_verify.__name__, # gives 403 + facebook_api.fb_webhook_verify.__name__, # gives 403 slack_connect_redirect.__name__, slack_connect_redirect_shortcuts.__name__, "get_run_status", # needs query params diff --git a/url_shortener/admin.py b/url_shortener/admin.py index f4d6bb4c0..e7f102a7a 100644 --- a/url_shortener/admin.py +++ b/url_shortener/admin.py @@ -137,3 +137,6 @@ class VisitorClickInfoAdmin(admin.ModelAdmin): ordering = ["-created_at"] autocomplete_fields = ["shortened_url"] actions = [export_to_csv, export_to_excel] + readonly_fields = [ + "created_at", + ] diff --git a/url_shortener/routers.py b/url_shortener/routers.py index fba39ffa9..1b3991cf8 100644 --- a/url_shortener/routers.py +++ b/url_shortener/routers.py @@ -21,10 +21,11 @@ def url_shortener(hashid: str, request: Request): return Response(status_code=410, content="This link has expired") # increment the click count ShortenedURL.objects.filter(id=surl.id).update(clicks=F("clicks") + 1) - if surl.enable_analytics: - save_click_info.delay( - surl.id, request.client.host, request.headers.get("user-agent", "") - ) + # disable because iplist.cc is down + # if surl.enable_analytics: + # save_click_info.delay( + # surl.id, request.client.host, request.headers.get("user-agent", "") + # ) if surl.url: return RedirectResponse( url=surl.url, status_code=303 # 
because youtu.be redirects are 303