From 70555b5f5fcdacec4c399e9f0002f23e961ea09b Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 13 Dec 2023 08:30:29 -0800 Subject: [PATCH 01/85] Added cost table --- .../0049_alter_savedrun_workflow.py | 52 +++++++++++++++++++ cost/__init__.py | 0 cost/admin.py | 18 +++++++ cost/apps.py | 7 +++ cost/migrations/__init__.py | 0 cost/models.py | 31 +++++++++++ cost/tests.py | 3 ++ cost/views.py | 3 ++ daras_ai_v2/settings.py | 1 + 9 files changed, 115 insertions(+) create mode 100644 bots/migrations/0049_alter_savedrun_workflow.py create mode 100644 cost/__init__.py create mode 100644 cost/admin.py create mode 100644 cost/apps.py create mode 100644 cost/migrations/__init__.py create mode 100644 cost/models.py create mode 100644 cost/tests.py create mode 100644 cost/views.py diff --git a/bots/migrations/0049_alter_savedrun_workflow.py b/bots/migrations/0049_alter_savedrun_workflow.py new file mode 100644 index 000000000..82780828c --- /dev/null +++ b/bots/migrations/0049_alter_savedrun_workflow.py @@ -0,0 +1,52 @@ +# Generated by Django 4.2.5 on 2023-12-12 08:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0048_alter_messageattachment_url"), + ] + + operations = [ + migrations.AlterField( + model_name="savedrun", + name="workflow", + field=models.IntegerField( + choices=[ + (1, "Doc Search"), + (2, "Doc Summary"), + (3, "Google GPT"), + (4, "Copilot"), + (5, "Lipysnc + TTS"), + (6, "Text to Speech"), + (7, "Speech Recognition"), + (8, "Lipsync"), + (9, "Deforum Animation"), + (10, "Compare Text2Img"), + (11, "Text2Audio"), + (12, "Img2Img"), + (13, "Face Inpainting"), + (14, "Google Image Gen"), + (15, "Compare AI Upscalers"), + (16, "SEO Summary"), + (17, "Email Face Inpainting"), + (18, "Social Lookup Email"), + (19, "Object Inpainting"), + (20, "Image Segmentation"), + (21, "Compare LLM"), + (22, "Chyron Plant"), + (23, "Letter Writer"), + (24, "Smart GPT"), + (25, "AI QR Code"), + (26, "Doc Extract"), + (27, "Related QnA Maker"), + (28, "Related QnA Maker Doc"), + (29, "Embeddings"), + (30, "Bulk Runner"), + (31, "Bulk Evaluator"), + ], + default=4, + ), + ), + ] diff --git a/cost/__init__.py b/cost/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cost/admin.py b/cost/admin.py new file mode 100644 index 000000000..a7ae9e50f --- /dev/null +++ b/cost/admin.py @@ -0,0 +1,18 @@ +from django.contrib import admin +from cost import models + +# Register your models here. + + +@admin.register(models.Cost) +class CostAdmin(admin.ModelAdmin): + list_display = [ + "run_id", + "provider", + "model", + "param", + "notes", + "quantity", + "calculation_notes", + "cost", + ] diff --git a/cost/apps.py b/cost/apps.py new file mode 100644 index 000000000..72b24d71a --- /dev/null +++ b/cost/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class CostConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "cost" + verbose_name = "Cost" diff --git a/cost/migrations/__init__.py b/cost/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cost/models.py b/cost/models.py new file mode 100644 index 000000000..9207f4b94 --- /dev/null +++ b/cost/models.py @@ -0,0 +1,31 @@ +import requests +from django.db import models, IntegrityError, transaction +from django.utils import timezone +from firebase_admin import auth + +from bots.custom_fields import CustomURLField +from daras_ai.image_input import upload_file_from_bytes, guess_ext_from_response +from daras_ai_v2 import settings, db +from gooeysite.bg_db_conn import db_middleware + + +# class CostQuerySet(models.QuerySet): + + +class Cost(models.Model): + run_id = models.ForeignKey( + "bots.SavedRun", + on_delete=models.SET_NULL, + related_name="cost", + null=True, + default=None, + blank=True, + help_text="The run that was last saved by the user.", + ) + provider = models.CharField(max_length=255, default="", blank=True) + model = models.TextField("model", blank=True) + param = models.TextField("param", blank=True) + notes = models.TextField(default="", blank=True) + quantity = models.IntegerField(default=1) + calculation_notes = models.TextField(default="", blank=True) + cost = models.FloatField(default=0.0) diff --git a/cost/tests.py b/cost/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/cost/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/cost/views.py b/cost/views.py new file mode 100644 index 000000000..91ea44a21 --- /dev/null +++ b/cost/views.py @@ -0,0 +1,3 @@ +from django.shortcuts import render + +# Create your views here. diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index a13565115..12de870a7 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -55,6 +55,7 @@ "files", "url_shortener", "glossary_resources", + "cost", ] MIDDLEWARE = [ From ea4542b34b8a1a79769f6f06f884c3208c8c949e Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 14 Dec 2023 08:39:20 -0800 Subject: [PATCH 02/85] Initial table --- cost/migrations/0001_initial.py | 48 +++++++++++++++++++++++++++++++++ cost/models.py | 1 + 2 files changed, 49 insertions(+) create mode 100644 cost/migrations/0001_initial.py diff --git a/cost/migrations/0001_initial.py b/cost/migrations/0001_initial.py new file mode 100644 index 000000000..dddbf1dfc --- /dev/null +++ b/cost/migrations/0001_initial.py @@ -0,0 +1,48 @@ +# Generated by Django 4.2.5 on 2023-12-13 20:14 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ("bots", "0049_alter_savedrun_workflow"), + ] + + operations = [ + migrations.CreateModel( + name="Cost", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("provider", models.CharField(blank=True, default="", max_length=255)), + ("model", models.TextField(blank=True, verbose_name="model")), + ("param", models.TextField(blank=True, verbose_name="param")), + ("notes", models.TextField(blank=True, default="")), + ("quantity", models.IntegerField(default=1)), + ("calculation_notes", models.TextField(blank=True, default="")), + ("cost", models.FloatField(default=0.0)), + ( + "run_id", + models.ForeignKey( + blank=True, + default=None, + help_text="The run that was last saved by the user.", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="cost", + to="bots.savedrun", + ), + ), + ], + ), + ] diff --git a/cost/models.py b/cost/models.py index 9207f4b94..19e20cf05 100644 --- a/cost/models.py +++ b/cost/models.py @@ -29,3 +29,4 @@ class Cost(models.Model): quantity = models.IntegerField(default=1) calculation_notes = models.TextField(default="", blank=True) cost = models.FloatField(default=0.0) + input = models. From d57b5c1488432b2c994c62001b8842761e3d8d02 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 14 Dec 2023 12:41:20 -0800 Subject: [PATCH 03/85] Incorrect names --- app_users/models.py | 1 + cost/models.py | 32 -------------------- {cost => costs}/__init__.py | 0 {cost => costs}/admin.py | 6 ++-- {cost => costs}/apps.py | 6 ++-- {cost => costs}/migrations/0001_initial.py | 0 {cost => costs}/migrations/__init__.py | 0 costs/models.py | 35 ++++++++++++++++++++++ {cost => costs}/tests.py | 0 {cost => costs}/views.py | 0 daras_ai_v2/settings.py | 2 +- 11 files changed, 43 insertions(+), 39 deletions(-) delete mode 100644 cost/models.py rename {cost => costs}/__init__.py (100%) rename {cost => costs}/admin.py (72%) rename {cost => costs}/apps.py (55%) rename {cost => costs}/migrations/0001_initial.py (100%) rename {cost => costs}/migrations/__init__.py (100%) create mode 100644 costs/models.py rename {cost => costs}/tests.py (100%) rename {cost => costs}/views.py (100%) diff --git a/app_users/models.py b/app_users/models.py index 5e373f414..8594a44fb 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -225,6 +225,7 @@ class AppUserTransaction(models.Model): amount = models.IntegerField() end_balance = models.IntegerField() created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now) + dollar_amt = models.DecimalField(default=decimal(0.0), blank=True) class Meta: verbose_name = "Transaction" diff --git a/cost/models.py b/cost/models.py deleted file mode 100644 index 19e20cf05..000000000 --- a/cost/models.py +++ /dev/null @@ -1,32 +0,0 @@ -import requests -from django.db import models, IntegrityError, transaction -from django.utils import timezone -from firebase_admin import auth - -from bots.custom_fields import CustomURLField -from daras_ai.image_input import upload_file_from_bytes, guess_ext_from_response -from daras_ai_v2 import settings, db -from gooeysite.bg_db_conn import db_middleware - - -# class CostQuerySet(models.QuerySet): - - -class Cost(models.Model): - run_id = models.ForeignKey( - "bots.SavedRun", - on_delete=models.SET_NULL, - related_name="cost", - null=True, - default=None, - blank=True, - help_text="The run that was last saved by the user.", - ) - provider = models.CharField(max_length=255, default="", blank=True) - model = models.TextField("model", blank=True) - param = models.TextField("param", blank=True) - notes = models.TextField(default="", blank=True) - quantity = models.IntegerField(default=1) - calculation_notes = models.TextField(default="", blank=True) - cost = models.FloatField(default=0.0) - input = models. diff --git a/cost/__init__.py b/costs/__init__.py similarity index 100% rename from cost/__init__.py rename to costs/__init__.py diff --git a/cost/admin.py b/costs/admin.py similarity index 72% rename from cost/admin.py rename to costs/admin.py index a7ae9e50f..4b3ed79a3 100644 --- a/cost/admin.py +++ b/costs/admin.py @@ -1,11 +1,11 @@ from django.contrib import admin -from cost import models +from costs import models # Register your models here. -@admin.register(models.Cost) -class CostAdmin(admin.ModelAdmin): +@admin.register(models.UsageCost) +class CostsAdmin(admin.ModelAdmin): list_display = [ "run_id", "provider", diff --git a/cost/apps.py b/costs/apps.py similarity index 55% rename from cost/apps.py rename to costs/apps.py index 72b24d71a..55bed0ca1 100644 --- a/cost/apps.py +++ b/costs/apps.py @@ -1,7 +1,7 @@ from django.apps import AppConfig -class CostConfig(AppConfig): +class CostsConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "cost" - verbose_name = "Cost" + name = "costs" + verbose_name = "Costs" diff --git a/cost/migrations/0001_initial.py b/costs/migrations/0001_initial.py similarity index 100% rename from cost/migrations/0001_initial.py rename to costs/migrations/0001_initial.py diff --git a/cost/migrations/__init__.py b/costs/migrations/__init__.py similarity index 100% rename from cost/migrations/__init__.py rename to costs/migrations/__init__.py diff --git a/costs/models.py b/costs/models.py new file mode 100644 index 000000000..bc75b8bc2 --- /dev/null +++ b/costs/models.py @@ -0,0 +1,35 @@ +from django.db import models + + +# class CostQuerySet(models.QuerySet): + + +class UsageCost(models.Model): + saved_run = models.ForeignKey( + "bots.SavedRun", + on_delete=models.CASCADE, + related_name="usage_costs", + null=True, + default=None, + blank=True, + help_text="The run that was last saved by the user.", + ) + + class Provider(models.IntegerChoices): + OpenAI = 1, "OpenAI" + GPT_4 = 2, "GPT-4" + dalle_e = 3, "dalle-e" + whisper = 4, "whisper" + GPT_3_5 = 5, "GPT-3.5" + + provider = models.IntegerField(choices=Provider.choices) + model = models.TextField("model", blank=True) + param = models.TextField("param", blank=True) # combine with quantity as JSON obj + notes = models.TextField(default="", blank=True) + quantity = models.JSONField(default=dict, blank=True) + calculation_notes = models.TextField(default="", blank=True) + dollar_amt = models.DecimalField( + max_digits=13, + decimal_places=8, + ) + created_at = models.DateTimeField(auto_now_add=True) diff --git a/cost/tests.py b/costs/tests.py similarity index 100% rename from cost/tests.py rename to costs/tests.py diff --git a/cost/views.py b/costs/views.py similarity index 100% rename from cost/views.py rename to costs/views.py diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 12de870a7..812ec32bc 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -55,7 +55,7 @@ "files", "url_shortener", "glossary_resources", - "cost", + "costs", ] MIDDLEWARE = [ From 956ebe14b09e05e7133bab91ad5c18607b3d8607 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:07:35 -0800 Subject: [PATCH 04/85] Change names --- app_users/models.py | 1 - costs/admin.py | 6 +++--- costs/migrations/0001_initial.py | 27 +++++++++++++++++++-------- costs/models.py | 8 +++----- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/app_users/models.py b/app_users/models.py index 8594a44fb..5e373f414 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -225,7 +225,6 @@ class AppUserTransaction(models.Model): amount = models.IntegerField() end_balance = models.IntegerField() created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now) - dollar_amt = models.DecimalField(default=decimal(0.0), blank=True) class Meta: verbose_name = "Transaction" diff --git a/costs/admin.py b/costs/admin.py index 4b3ed79a3..db3485d79 100644 --- a/costs/admin.py +++ b/costs/admin.py @@ -7,12 +7,12 @@ @admin.register(models.UsageCost) class CostsAdmin(admin.ModelAdmin): list_display = [ - "run_id", + "saved_run", "provider", "model", "param", "notes", - "quantity", "calculation_notes", - "cost", + "dollar_amt", + "created_at", ] diff --git a/costs/migrations/0001_initial.py b/costs/migrations/0001_initial.py index dddbf1dfc..4051d1109 100644 --- a/costs/migrations/0001_initial.py +++ b/costs/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.5 on 2023-12-13 20:14 +# Generated by Django 4.2.5 on 2023-12-14 20:52 from django.db import migrations, models import django.db.models.deletion @@ -13,7 +13,7 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name="Cost", + name="UsageCost", fields=[ ( "id", @@ -24,22 +24,33 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("provider", models.CharField(blank=True, default="", max_length=255)), + ( + "provider", + models.IntegerField( + choices=[ + (1, "OpenAI"), + (2, "GPT-4"), + (3, "dalle-e"), + (4, "whisper"), + (5, "GPT-3.5"), + ] + ), + ), ("model", models.TextField(blank=True, verbose_name="model")), ("param", models.TextField(blank=True, verbose_name="param")), ("notes", models.TextField(blank=True, default="")), - ("quantity", models.IntegerField(default=1)), ("calculation_notes", models.TextField(blank=True, default="")), - ("cost", models.FloatField(default=0.0)), + ("dollar_amt", models.DecimalField(decimal_places=8, max_digits=13)), + ("created_at", models.DateTimeField(auto_now_add=True)), ( - "run_id", + "saved_run", models.ForeignKey( blank=True, default=None, help_text="The run that was last saved by the user.", null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="cost", + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", to="bots.savedrun", ), ), diff --git a/costs/models.py b/costs/models.py index bc75b8bc2..92dcb5ce9 100644 --- a/costs/models.py +++ b/costs/models.py @@ -1,9 +1,6 @@ from django.db import models -# class CostQuerySet(models.QuerySet): - - class UsageCost(models.Model): saved_run = models.ForeignKey( "bots.SavedRun", @@ -24,9 +21,10 @@ class Provider(models.IntegerChoices): provider = models.IntegerField(choices=Provider.choices) model = models.TextField("model", blank=True) - param = models.TextField("param", blank=True) # combine with quantity as JSON obj + param = models.TextField( + "param", blank=True + ) # contains input/output tokens and quantity notes = models.TextField(default="", blank=True) - quantity = models.JSONField(default=dict, blank=True) calculation_notes = models.TextField(default="", blank=True) dollar_amt = models.DecimalField( max_digits=13, From 7cc9f676745c77a7142e5d906ba38e69edc55c6a Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 3 Jan 2024 20:01:35 -0800 Subject: [PATCH 05/85] Fixed migration errors --- .../0049_alter_savedrun_workflow.py | 52 ------------ celeryapp/tasks.py | 1 + costs/admin.py | 16 ++++ costs/migrations/0001_initial.py | 4 +- costs/migrations/0002_providerpricing.py | 79 +++++++++++++++++++ costs/models.py | 39 +++++++-- daras_ai_v2/language_model.py | 4 + 7 files changed, 134 insertions(+), 61 deletions(-) delete mode 100644 bots/migrations/0049_alter_savedrun_workflow.py create mode 100644 costs/migrations/0002_providerpricing.py diff --git a/bots/migrations/0049_alter_savedrun_workflow.py b/bots/migrations/0049_alter_savedrun_workflow.py deleted file mode 100644 index 82780828c..000000000 --- a/bots/migrations/0049_alter_savedrun_workflow.py +++ /dev/null @@ -1,52 +0,0 @@ -# Generated by Django 4.2.5 on 2023-12-12 08:26 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("bots", "0048_alter_messageattachment_url"), - ] - - operations = [ - migrations.AlterField( - model_name="savedrun", - name="workflow", - field=models.IntegerField( - choices=[ - (1, "Doc Search"), - (2, "Doc Summary"), - (3, "Google GPT"), - (4, "Copilot"), - (5, "Lipysnc + TTS"), - (6, "Text to Speech"), - (7, "Speech Recognition"), - (8, "Lipsync"), - (9, "Deforum Animation"), - (10, "Compare Text2Img"), - (11, "Text2Audio"), - (12, "Img2Img"), - (13, "Face Inpainting"), - (14, "Google Image Gen"), - (15, "Compare AI Upscalers"), - (16, "SEO Summary"), - (17, "Email Face Inpainting"), - (18, "Social Lookup Email"), - (19, "Object Inpainting"), - (20, "Image Segmentation"), - (21, "Compare LLM"), - (22, "Chyron Plant"), - (23, "Letter Writer"), - (24, "Smart GPT"), - (25, "AI QR Code"), - (26, "Doc Extract"), - (27, "Related QnA Maker"), - (28, "Related QnA Maker Doc"), - (29, "Embeddings"), - (30, "Bulk Runner"), - (31, "Bulk Evaluator"), - ], - default=4, - ), - ), - ] diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..7bd144091 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,6 +8,7 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun +from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/costs/admin.py b/costs/admin.py index db3485d79..12c106a14 100644 --- a/costs/admin.py +++ b/costs/admin.py @@ -16,3 +16,19 @@ class CostsAdmin(admin.ModelAdmin): "dollar_amt", "created_at", ] + + +@admin.register(models.ProviderPricing) +class ProviderAdmin(admin.ModelAdmin): + list_display = [ + "type", + "provider", + "product", + "param", + "cost", + "unit", + "created_at", + "last_updated", + "updated_by", + "pricing_url", + ] diff --git a/costs/migrations/0001_initial.py b/costs/migrations/0001_initial.py index 4051d1109..d3f6d1948 100644 --- a/costs/migrations/0001_initial.py +++ b/costs/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.5 on 2023-12-14 20:52 +# Generated by Django 4.2.5 on 2024-01-04 03:59 from django.db import migrations, models import django.db.models.deletion @@ -8,7 +8,7 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ("bots", "0049_alter_savedrun_workflow"), + ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), ] operations = [ diff --git a/costs/migrations/0002_providerpricing.py b/costs/migrations/0002_providerpricing.py new file mode 100644 index 000000000..6b0bb17a1 --- /dev/null +++ b/costs/migrations/0002_providerpricing.py @@ -0,0 +1,79 @@ +# Generated by Django 4.2.5 on 2024-01-04 04:00 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="ProviderPricing", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("type", models.TextField(choices=[("LLM", "Llm")])), + ( + "provider", + models.IntegerField( + choices=[ + (1, "OpenAI"), + (2, "GPT-4"), + (3, "dalle-e"), + (4, "whisper"), + (5, "GPT-3.5"), + ] + ), + ), + ( + "product", + models.TextField( + choices=[ + ("GPT-4 Vision (openai)", "gpt_4_vision"), + ("GPT-4 Turbo (openai)", "gpt_4_turbo"), + ("GPT-4 (openai)", "gpt_4"), + ("GPT-4 32K (openai)", "gpt_4_32k"), + ("ChatGPT (openai)", "gpt_3_5_turbo"), + ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), + ("Llama 2 (Meta AI)", "llama2_70b_chat"), + ("PaLM 2 Text (Google)", "palm2_chat"), + ("PaLM 2 Chat (Google)", "palm2_text"), + ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), + ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), + ("Curie (openai)", "text_curie_001"), + ("Babbage (openai)", "text_babbage_001"), + ("Ada (openai)", "text_ada_001"), + ("Codex [Deprecated] (openai)", "code_davinci_002"), + ] + ), + ), + ( + "param", + models.TextField( + choices=[ + ("input", "Input"), + ("output", "Output"), + ("input image", "Input Image"), + ("output image", "Output Image"), + ] + ), + ), + ("cost", models.DecimalField(decimal_places=8, max_digits=13)), + ("unit", models.TextField(default="")), + ("notes", models.TextField(default="")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("last_updated", models.DateTimeField(auto_now=True)), + ("updated_by", models.TextField(default="")), + ("pricing_url", models.TextField(default="")), + ], + ), + ] diff --git a/costs/models.py b/costs/models.py index 92dcb5ce9..ae04baf83 100644 --- a/costs/models.py +++ b/costs/models.py @@ -1,4 +1,13 @@ from django.db import models +from daras_ai_v2.language_model import LargeLanguageModels + + +class Provider(models.IntegerChoices): + OpenAI = 1, "OpenAI" + GPT_4 = 2, "GPT-4" + dalle_e = 3, "dalle-e" + whisper = 4, "whisper" + GPT_3_5 = 5, "GPT-3.5" class UsageCost(models.Model): @@ -12,13 +21,6 @@ class UsageCost(models.Model): help_text="The run that was last saved by the user.", ) - class Provider(models.IntegerChoices): - OpenAI = 1, "OpenAI" - GPT_4 = 2, "GPT-4" - dalle_e = 3, "dalle-e" - whisper = 4, "whisper" - GPT_3_5 = 5, "GPT-3.5" - provider = models.IntegerField(choices=Provider.choices) model = models.TextField("model", blank=True) param = models.TextField( @@ -31,3 +33,26 @@ class Provider(models.IntegerChoices): decimal_places=8, ) created_at = models.DateTimeField(auto_now_add=True) + + +class ProviderPricing(models.Model): + class Type(models.TextChoices): + LLM = "LLM" + + class Param(models.TextChoices): + input = "input" + output = "output" + input_image = "input image" + output_image = "output image" + + type = models.TextField(choices=Type.choices) + provider = models.IntegerField(choices=Provider.choices) + product = models.TextField(choices=LargeLanguageModels.choices()) + param = models.TextField(choices=Param.choices) + cost = models.DecimalField(max_digits=13, decimal_places=8) + unit = models.TextField(default="") + notes = models.TextField(default="") + created_at = models.DateTimeField(auto_now_add=True) + last_updated = models.DateTimeField(auto_now=True) + updated_by = models.TextField(default="") + pricing_url = models.TextField(default="") diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index f63c7e01d..77d04d01a 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -65,6 +65,10 @@ class LargeLanguageModels(Enum): code_davinci_002 = "Codex [Deprecated] (openai)" + @classmethod + def choices(cls): + return tuple((e.value, e.name) for e in cls) + @classmethod def _deprecated(cls): return {cls.code_davinci_002} From 9c25a857bbfe36b714a52d2ef92e46f8cdaae7ae Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 3 Jan 2024 22:40:55 -0800 Subject: [PATCH 06/85] Return token count for openai chat and palm chat --- celeryapp/tasks.py | 1 - daras_ai_v2/language_model.py | 56 ++++++++++++++++++++++++----------- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 7bd144091..1868c7a18 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,7 +8,6 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun -from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 77d04d01a..676b9b9b6 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -350,7 +350,7 @@ def run_language_model( format_chat_entry(role=entry["role"], content=get_entry_text(entry)) for entry in messages ] - result = _run_chat_model( + result, output_token, input_token = _run_chat_model( api=api, model=model_name, messages=messages, # type: ignore @@ -372,10 +372,14 @@ def run_language_model( else (entry.get("content") or "").strip() for entry in result ] + print("out_contentttt", out_content, input_token, output_token) if tools: - return out_content, [(entry.get("tool_calls") or []) for entry in result] + return ( + out_content, + [(entry.get("tool_calls") or []) for entry in result], + ) else: - return out_content + return out_content, input_token, output_token else: if tools: raise ValueError("Only OpenAI chat models support Tools") @@ -442,7 +446,7 @@ def _run_chat_model( avoid_repetition: bool, tools: list[LLMTools] | None, response_format_type: typing.Literal["text", "json_object"] | None, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: match api: case LLMApis.openai: return _run_openai_chat( @@ -493,7 +497,7 @@ def _run_openai_chat( avoid_repetition: bool, tools: list[LLMTools] | None, response_format_type: typing.Literal["text", "json_object"] | None, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: from openai._types import NOT_GIVEN if avoid_repetition: @@ -524,7 +528,11 @@ def _run_openai_chat( for model_str in model ], ) - return [choice.message.dict() for choice in r.choices] + return ( + [choice.message.dict() for choice in r.choices], + r.usage.completion_tokens, + r.usage.prompt_tokens, + ) @retry_if(openai_should_retry) @@ -630,7 +638,7 @@ def _run_palm_chat( max_output_tokens: int, candidate_count: int, temperature: float, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: """ Args: model_id: The model id to use for the request. See available models: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models @@ -670,14 +678,23 @@ def _run_palm_chat( ) r.raise_for_status() - return [ - { - "role": msg["author"], - "content": msg["content"], - } - for pred in r.json()["predictions"] - for msg in pred["candidates"] - ] + print( + "real r.json()", + r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + ) + + return ( + [ + { + "role": msg["author"], + "content": msg["content"], + } + for pred in r.json()["predictions"] + for msg in pred["candidates"] + ], + r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + r.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) @retry_if(vertex_ai_should_retry) @@ -688,7 +705,7 @@ def _run_palm_text( max_output_tokens: int, candidate_count: int, temperature: float, -) -> list[str]: +) -> tuple[list[str], int, int]: """ Args: model_id: The model id to use for the request. See available models: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models @@ -714,7 +731,12 @@ def _run_palm_text( }, ) res.raise_for_status() - return [prediction["content"] for prediction in res.json()["predictions"]] + print("res.json()", res.json()) + return ( + [prediction["content"] for prediction in res.json()["predictions"]], + res.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + res.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) def format_chatml_message(entry: ConversationEntry) -> str: From d28125ecbc7716efcc99ae422baffce4719618e4 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 4 Jan 2024 10:35:08 -0800 Subject: [PATCH 07/85] Get tokens from response --- daras_ai_v2/language_model.py | 38 +++++++++++++++++++++-------------- recipes/CompareLLM.py | 1 + 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 676b9b9b6..3e9400ff5 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -329,7 +329,11 @@ def run_language_model( avoid_repetition: bool = False, tools: list[LLMTools] = None, response_format_type: typing.Literal["text", "json_object"] = None, -) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]: +) -> ( + tuple[list[str], int, int] + | tuple[list[str], list[list[dict]], int, int] + | tuple[list[dict], int, int] +): assert bool(prompt) != bool( messages ), "Pleave provide exactly one of { prompt, messages }" @@ -372,19 +376,20 @@ def run_language_model( else (entry.get("content") or "").strip() for entry in result ] - print("out_contentttt", out_content, input_token, output_token) if tools: return ( out_content, [(entry.get("tool_calls") or []) for entry in result], + output_token, + input_token, ) else: - return out_content, input_token, output_token + return out_content, output_token, input_token else: if tools: raise ValueError("Only OpenAI chat models support Tools") logger.info(f"{model_name=}, {len(prompt)=}, {max_tokens=}, {temperature=}") - result = _run_text_model( + result, output_token, input_token = _run_text_model( api=api, model=model_name, prompt=prompt, @@ -395,7 +400,7 @@ def run_language_model( avoid_repetition=avoid_repetition, quality=quality, ) - return [msg.strip() for msg in result] + return [msg.strip() for msg in result], output_token, input_token def _run_text_model( @@ -409,7 +414,7 @@ def _run_text_model( stop: list[str] | None, avoid_repetition: bool, quality: float, -) -> list[str]: +) -> tuple[list[str], int, int]: match api: case LLMApis.openai: return _run_openai_text( @@ -528,6 +533,7 @@ def _run_openai_chat( for model_str in model ], ) + print("entire r", r) return ( [choice.message.dict() for choice in r.choices], r.usage.completion_tokens, @@ -557,7 +563,11 @@ def _run_openai_text( frequency_penalty=0.1 if avoid_repetition else 0, presence_penalty=0.25 if avoid_repetition else 0, ) - return [choice.text for choice in r.choices] + return ( + [choice.text for choice in r.choices], + r.usage.completion_tokens, + r.usage.prompt_tokens, + ) def _get_openai_client(model: str): @@ -586,7 +596,7 @@ def _run_together_chat( temperature: float, repetition_penalty: float, num_outputs: int, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: """ Args: model: The model version to use for the request. @@ -614,11 +624,15 @@ def _run_together_chat( range(num_outputs), ) ret = [] + total_out_tokens = 0 + total_in_tokens = 0 for r in results: r.raise_for_status() data = r.json() output = data["output"] error = output.get("error") + total_out_tokens += output.get("usage", {}).get("completion_tokens", 0) + total_in_tokens += output.get("usage", {}).get("prompt_tokens", 0) if error: raise ValueError(error) ret.append( @@ -627,7 +641,7 @@ def _run_together_chat( "content": output["choices"][0]["text"], } ) - return ret + return ret, total_out_tokens, total_in_tokens @retry_if(vertex_ai_should_retry) @@ -678,11 +692,6 @@ def _run_palm_chat( ) r.raise_for_status() - print( - "real r.json()", - r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], - ) - return ( [ { @@ -731,7 +740,6 @@ def _run_palm_text( }, ) res.raise_for_status() - print("res.json()", res.json()) return ( [prediction["content"] for prediction in res.json()["predictions"]], res.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 583317ddc..c9ade07d8 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -105,6 +105,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: max_tokens=request.max_tokens, avoid_repetition=request.avoid_repetition, ) + print("final output", output_text) def render_output(self): self._render_outputs(st.session_state, 450) From 20e4fba79a1e9d2ba97bb2bdbb0af237d05c3eb4 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Fri, 5 Jan 2024 19:11:00 -0800 Subject: [PATCH 08/85] Saving LLM works --- celeryapp/tasks.py | 1 + costs/admin.py | 7 +- costs/cost_utils.py | 36 +++++++++++ ...t_model_remove_usagecost_param_and_more.py | 64 +++++++++++++++++++ .../0004_remove_usagecost_saved_run.py | 16 +++++ .../0005_remove_usagecost_provider_pricing.py | 16 +++++ ...st_provider_pricing_usagecost_saved_run.py | 38 +++++++++++ .../0007_alter_providerpricing_provider.py | 25 ++++++++ ...08_alter_providerpricing_param_and_more.py | 33 ++++++++++ .../0009_alter_providerpricing_product.py | 19 ++++++ ...0010_remove_usagecost_calculation_notes.py | 16 +++++ costs/models.py | 46 +++++++------ daras_ai_v2/language_model.py | 45 +++++++++---- recipes/CompareLLM.py | 1 - 14 files changed, 326 insertions(+), 37 deletions(-) create mode 100644 costs/cost_utils.py create mode 100644 costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py create mode 100644 costs/migrations/0004_remove_usagecost_saved_run.py create mode 100644 costs/migrations/0005_remove_usagecost_provider_pricing.py create mode 100644 costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py create mode 100644 costs/migrations/0007_alter_providerpricing_provider.py create mode 100644 costs/migrations/0008_alter_providerpricing_param_and_more.py create mode 100644 costs/migrations/0009_alter_providerpricing_product.py create mode 100644 costs/migrations/0010_remove_usagecost_calculation_notes.py diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..7bd144091 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,6 +8,7 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun +from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/costs/admin.py b/costs/admin.py index 12c106a14..f3d99d472 100644 --- a/costs/admin.py +++ b/costs/admin.py @@ -8,11 +8,9 @@ class CostsAdmin(admin.ModelAdmin): list_display = [ "saved_run", - "provider", - "model", - "param", + "provider_pricing", + "quantity", "notes", - "calculation_notes", "dollar_amt", "created_at", ] @@ -27,6 +25,7 @@ class ProviderAdmin(admin.ModelAdmin): "param", "cost", "unit", + "notes", "created_at", "last_updated", "updated_by", diff --git a/costs/cost_utils.py b/costs/cost_utils.py new file mode 100644 index 000000000..c04a30c92 --- /dev/null +++ b/costs/cost_utils.py @@ -0,0 +1,36 @@ +from costs.models import UsageCost, ProviderPricing +from bots.models import SavedRun +from django.forms.models import model_to_dict + + +def get_provider_pricing( + type: str, + provider: str, + product: str, + param: str, +) -> ProviderPricing: + return ProviderPricing.objects.get( + type=type, + provider=provider, + product=product, + param=param, + ) + + +def record_cost( + run_id: str | None, + uid: str | None, + provider_pricing: ProviderPricing, + quantity: int, +) -> UsageCost: + saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) + cost = UsageCost( + saved_run=saved_run, + provider_pricing=provider_pricing, + quantity=quantity, + notes="", + dollar_amt=provider_pricing.cost * quantity, + created_at=saved_run.created_at, + ) + cost.save() + return cost diff --git a/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py b/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py new file mode 100644 index 000000000..4aa89cadd --- /dev/null +++ b/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py @@ -0,0 +1,64 @@ +# Generated by Django 4.2.5 on 2024-01-05 06:02 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0002_providerpricing"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="model", + ), + migrations.RemoveField( + model_name="usagecost", + name="param", + ), + migrations.RemoveField( + model_name="usagecost", + name="provider", + ), + migrations.AddField( + model_name="usagecost", + name="provider_pricing", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", + to="costs.providerpricing", + ), + ), + migrations.AddField( + model_name="usagecost", + name="quantity", + field=models.IntegerField(default=1), + ), + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("GPT-4 Vision (openai)", "gpt_4_vision"), + ("GPT-4 Turbo (openai)", "gpt_4_turbo"), + ("GPT-4 (openai)", "gpt_4"), + ("GPT-4 32K (openai)", "gpt_4_32k"), + ("ChatGPT (openai)", "gpt_3_5_turbo"), + ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), + ("Llama 2 (Meta AI)", "llama2_70b_chat"), + ("PaLM 2 Chat (Google)", "palm2_chat"), + ("PaLM 2 Text (Google)", "palm2_text"), + ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), + ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), + ("Curie (openai)", "text_curie_001"), + ("Babbage (openai)", "text_babbage_001"), + ("Ada (openai)", "text_ada_001"), + ("Codex [Deprecated] (openai)", "code_davinci_002"), + ] + ), + ), + ] diff --git a/costs/migrations/0004_remove_usagecost_saved_run.py b/costs/migrations/0004_remove_usagecost_saved_run.py new file mode 100644 index 000000000..2e68a519e --- /dev/null +++ b/costs/migrations/0004_remove_usagecost_saved_run.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.5 on 2024-01-05 07:33 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0003_remove_usagecost_model_remove_usagecost_param_and_more"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="saved_run", + ), + ] diff --git a/costs/migrations/0005_remove_usagecost_provider_pricing.py b/costs/migrations/0005_remove_usagecost_provider_pricing.py new file mode 100644 index 000000000..2ff682ab0 --- /dev/null +++ b/costs/migrations/0005_remove_usagecost_provider_pricing.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.5 on 2024-01-05 08:06 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0004_remove_usagecost_saved_run"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="provider_pricing", + ), + ] diff --git a/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py b/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py new file mode 100644 index 000000000..ba207c963 --- /dev/null +++ b/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2.5 on 2024-01-05 08:06 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), + ("costs", "0005_remove_usagecost_provider_pricing"), + ] + + operations = [ + migrations.AddField( + model_name="usagecost", + name="provider_pricing", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", + to="costs.providerpricing", + ), + ), + migrations.AddField( + model_name="usagecost", + name="saved_run", + field=models.ForeignKey( + blank=True, + default=None, + help_text="The run that was last saved by the user.", + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", + to="bots.savedrun", + ), + ), + ] diff --git a/costs/migrations/0007_alter_providerpricing_provider.py b/costs/migrations/0007_alter_providerpricing_provider.py new file mode 100644 index 000000000..e6220e391 --- /dev/null +++ b/costs/migrations/0007_alter_providerpricing_provider.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.5 on 2024-01-06 00:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0006_usagecost_provider_pricing_usagecost_saved_run"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="provider", + field=models.TextField( + choices=[ + ("OpenAI", "Openai"), + ("GPT-4", "Gpt 4"), + ("dalle-e", "Dalle E"), + ("whisper", "Whisper"), + ("GPT-3.5", "Gpt 3 5"), + ] + ), + ), + ] diff --git a/costs/migrations/0008_alter_providerpricing_param_and_more.py b/costs/migrations/0008_alter_providerpricing_param_and_more.py new file mode 100644 index 000000000..bcc9a0030 --- /dev/null +++ b/costs/migrations/0008_alter_providerpricing_param_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.5 on 2024-01-06 01:10 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0007_alter_providerpricing_provider"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="param", + field=models.TextField(choices=[("Input", "input"), ("Output", "output")]), + ), + migrations.AlterField( + model_name="providerpricing", + name="provider", + field=models.TextField( + choices=[ + ("vertex_ai", "Vertex AI"), + ("openai", "OpenAI"), + ("together", "Together"), + ] + ), + ), + migrations.AlterField( + model_name="providerpricing", + name="type", + field=models.TextField(choices=[("LLM", "LLM")]), + ), + ] diff --git a/costs/migrations/0009_alter_providerpricing_product.py b/costs/migrations/0009_alter_providerpricing_product.py new file mode 100644 index 000000000..5e5a51019 --- /dev/null +++ b/costs/migrations/0009_alter_providerpricing_product.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.5 on 2024-01-06 01:28 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0008_alter_providerpricing_param_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[("gpt-4-vision-preview", "gpt-4-vision-preview")] + ), + ), + ] diff --git a/costs/migrations/0010_remove_usagecost_calculation_notes.py b/costs/migrations/0010_remove_usagecost_calculation_notes.py new file mode 100644 index 000000000..3eba7ace1 --- /dev/null +++ b/costs/migrations/0010_remove_usagecost_calculation_notes.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.5 on 2024-01-06 02:02 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0009_alter_providerpricing_product"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="calculation_notes", + ), + ] diff --git a/costs/models.py b/costs/models.py index ae04baf83..cc997eb83 100644 --- a/costs/models.py +++ b/costs/models.py @@ -2,12 +2,17 @@ from daras_ai_v2.language_model import LargeLanguageModels -class Provider(models.IntegerChoices): - OpenAI = 1, "OpenAI" - GPT_4 = 2, "GPT-4" - dalle_e = 3, "dalle-e" - whisper = 4, "whisper" - GPT_3_5 = 5, "GPT-3.5" +class Provider(models.TextChoices): + vertex_ai = "vertex_ai", "Vertex AI" + openai = "openai", "OpenAI" + together = "together", "Together" + + +class Product(models.TextChoices): + gpt_4_vision = ( + "gpt-4-vision-preview", + "gpt-4-vision-preview", + ) class UsageCost(models.Model): @@ -21,13 +26,15 @@ class UsageCost(models.Model): help_text="The run that was last saved by the user.", ) - provider = models.IntegerField(choices=Provider.choices) - model = models.TextField("model", blank=True) - param = models.TextField( - "param", blank=True - ) # contains input/output tokens and quantity + provider_pricing = models.ForeignKey( + "costs.ProviderPricing", + on_delete=models.CASCADE, + related_name="usage_costs", + null=True, + default=None, + ) + quantity = models.IntegerField(default=1) notes = models.TextField(default="", blank=True) - calculation_notes = models.TextField(default="", blank=True) dollar_amt = models.DecimalField( max_digits=13, decimal_places=8, @@ -37,17 +44,15 @@ class UsageCost(models.Model): class ProviderPricing(models.Model): class Type(models.TextChoices): - LLM = "LLM" + LLM = "LLM", "LLM" class Param(models.TextChoices): - input = "input" - output = "output" - input_image = "input image" - output_image = "output image" + input = "Input", "input" + output = "Output", "output" type = models.TextField(choices=Type.choices) - provider = models.IntegerField(choices=Provider.choices) - product = models.TextField(choices=LargeLanguageModels.choices()) + provider = models.TextField(choices=Provider.choices) + product = models.TextField(choices=Product.choices) param = models.TextField(choices=Param.choices) cost = models.DecimalField(max_digits=13, decimal_places=8) unit = models.TextField(default="") @@ -56,3 +61,6 @@ class Param(models.TextChoices): last_updated = models.DateTimeField(auto_now=True) updated_by = models.TextField(default="") pricing_url = models.TextField(default="") + + def __str__(self): + return self.type + " " + self.provider + " " + self.product + " " + self.param diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index d6e93da89..f78998de0 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -27,6 +27,8 @@ get_redis_cache, ) from daras_ai_v2.text_splitter import default_length_function +from daras_ai_v2.query_params_util import extract_query_params +from daras_ai_v2.query_params import gooey_get_query_params DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible." @@ -65,10 +67,6 @@ class LargeLanguageModels(Enum): code_davinci_002 = "Codex [Deprecated] (openai)" - @classmethod - def choices(cls): - return tuple((e.value, e.name) for e in cls) - @classmethod def _deprecated(cls): return {cls.code_davinci_002} @@ -329,11 +327,9 @@ def run_language_model( avoid_repetition: bool = False, tools: list[LLMTools] = None, response_format_type: typing.Literal["text", "json_object"] = None, -) -> ( - tuple[list[str], int, int] - | tuple[list[str], list[list[dict]], int, int] - | tuple[list[dict], int, int] -): +) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]: + from costs.cost_utils import record_cost, get_provider_pricing + assert bool(prompt) != bool( messages ), "Pleave provide exactly one of { prompt, messages }" @@ -380,11 +376,35 @@ def run_language_model( return ( out_content, [(entry.get("tool_calls") or []) for entry in result], - output_token, - input_token, ) else: - return out_content, output_token, input_token + provider_pricing_in = get_provider_pricing( + type="LLM", + provider=api.name, + product=model_name, + param="Input", + ) + + provider_pricing_out = get_provider_pricing( + type="LLM", + provider=api.name, + product=model_name, + param="Output", + ) + example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + record_cost( + run_id=run_id, + uid=uid, + provider_pricing=provider_pricing_in, + quantity=input_token, + ) + record_cost( + run_id=run_id, + uid=uid, + provider_pricing=provider_pricing_out, + quantity=output_token, + ) + return out_content else: if tools: raise ValueError("Only OpenAI chat models support Tools") @@ -533,7 +553,6 @@ def _run_openai_chat( for model_str in model ], ) - print("entire r", r) return ( [choice.message.dict() for choice in r.choices], r.usage.completion_tokens, diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index c9ade07d8..583317ddc 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -105,7 +105,6 @@ def run(self, state: dict) -> typing.Iterator[str | None]: max_tokens=request.max_tokens, avoid_repetition=request.avoid_repetition, ) - print("final output", output_text) def render_output(self): self._render_outputs(st.session_state, 450) From ee0e058c72a1aafced67c060285f44e95d827236 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Mon, 8 Jan 2024 09:02:21 -0800 Subject: [PATCH 09/85] Temp --- costs/cost_utils.py | 1 - costs/models.py | 12 ++++++------ daras_ai_v2/language_model.py | 4 ++++ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/costs/cost_utils.py b/costs/cost_utils.py index c04a30c92..f7799a970 100644 --- a/costs/cost_utils.py +++ b/costs/cost_utils.py @@ -1,6 +1,5 @@ from costs.models import UsageCost, ProviderPricing from bots.models import SavedRun -from django.forms.models import model_to_dict def get_provider_pricing( diff --git a/costs/models.py b/costs/models.py index cc997eb83..ece1500f8 100644 --- a/costs/models.py +++ b/costs/models.py @@ -1,11 +1,11 @@ from django.db import models -from daras_ai_v2.language_model import LargeLanguageModels +from daras_ai_v2.language_model import LLMApis -class Provider(models.TextChoices): - vertex_ai = "vertex_ai", "Vertex AI" - openai = "openai", "OpenAI" - together = "together", "Together" +# class Provider(models.TextChoices): +# vertex_ai = "vertex_ai", "Vertex AI" +# openai = "openai", "OpenAI" +# together = "together", "Together" class Product(models.TextChoices): @@ -51,7 +51,7 @@ class Param(models.TextChoices): output = "Output", "output" type = models.TextField(choices=Type.choices) - provider = models.TextField(choices=Provider.choices) + provider = models.TextField(choices=LLMApis.choices()) product = models.TextField(choices=Product.choices) param = models.TextField(choices=Param.choices) cost = models.DecimalField(max_digits=13, decimal_places=8) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index f78998de0..d1a4eb184 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -45,6 +45,10 @@ class LLMApis(Enum): openai = "OpenAI" together = "Together" + @classmethod + def choices(cls): + return tuple((api.name, api.value) for api in cls) + class LargeLanguageModels(Enum): gpt_4_vision = "GPT-4 Vision (openai)" From a5a64f24e8f1f6e238b2e36b3fc78cc89a714622 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 11 Jan 2024 01:26:36 -0800 Subject: [PATCH 10/85] Added choices --- costs/cost_utils.py | 1 + .../0011_alter_providerpricing_product.py | 35 +++++++++++ .../0012_alter_providerpricing_product.py | 37 +++++++++++ .../0013_alter_providerpricing_product.py | 37 +++++++++++ .../0014_alter_providerpricing_product.py | 40 ++++++++++++ .../0015_alter_providerpricing_product.py | 43 +++++++++++++ costs/models.py | 63 ++++++++++++++++--- 7 files changed, 248 insertions(+), 8 deletions(-) create mode 100644 costs/migrations/0011_alter_providerpricing_product.py create mode 100644 costs/migrations/0012_alter_providerpricing_product.py create mode 100644 costs/migrations/0013_alter_providerpricing_product.py create mode 100644 costs/migrations/0014_alter_providerpricing_product.py create mode 100644 costs/migrations/0015_alter_providerpricing_product.py diff --git a/costs/cost_utils.py b/costs/cost_utils.py index f7799a970..25f6ec28b 100644 --- a/costs/cost_utils.py +++ b/costs/cost_utils.py @@ -8,6 +8,7 @@ def get_provider_pricing( product: str, param: str, ) -> ProviderPricing: + print("get_provider_pricing", type, provider, product, param) return ProviderPricing.objects.get( type=type, provider=provider, diff --git a/costs/migrations/0011_alter_providerpricing_product.py b/costs/migrations/0011_alter_providerpricing_product.py new file mode 100644 index 000000000..b49d3f308 --- /dev/null +++ b/costs/migrations/0011_alter_providerpricing_product.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2.5 on 2024-01-10 08:50 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0010_remove_usagecost_calculation_notes"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ("gpt-4-1106-preview", "gpt-4-1106-preview"), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0012_alter_providerpricing_product.py b/costs/migrations/0012_alter_providerpricing_product.py new file mode 100644 index 000000000..3bc92c01c --- /dev/null +++ b/costs/migrations/0012_alter_providerpricing_product.py @@ -0,0 +1,37 @@ +# Generated by Django 4.2.5 on 2024-01-10 16:45 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0011_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ("openai-gpt-4-turbo-prod-ca-1", "openai-gpt-4-turbo-prod-ca-1"), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0013_alter_providerpricing_product.py b/costs/migrations/0013_alter_providerpricing_product.py new file mode 100644 index 000000000..9449fbf91 --- /dev/null +++ b/costs/migrations/0013_alter_providerpricing_product.py @@ -0,0 +1,37 @@ +# Generated by Django 4.2.5 on 2024-01-10 16:50 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0012_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ("openai-gpt-4-turbo-prod-ca-1", "gpt-4-1106-preview"), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0014_alter_providerpricing_product.py b/costs/migrations/0014_alter_providerpricing_product.py new file mode 100644 index 000000000..22b5439ed --- /dev/null +++ b/costs/migrations/0014_alter_providerpricing_product.py @@ -0,0 +1,40 @@ +# Generated by Django 4.2.5 on 2024-01-10 16:58 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0013_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ( + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + ), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0015_alter_providerpricing_product.py b/costs/migrations/0015_alter_providerpricing_product.py new file mode 100644 index 000000000..2658de04f --- /dev/null +++ b/costs/migrations/0015_alter_providerpricing_product.py @@ -0,0 +1,43 @@ +# Generated by Django 4.2.5 on 2024-01-10 17:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0014_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ( + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + ), + ( + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + ), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/models.py b/costs/models.py index ece1500f8..18700537f 100644 --- a/costs/models.py +++ b/costs/models.py @@ -2,17 +2,64 @@ from daras_ai_v2.language_model import LLMApis -# class Provider(models.TextChoices): -# vertex_ai = "vertex_ai", "Vertex AI" -# openai = "openai", "OpenAI" -# together = "together", "Together" - - class Product(models.TextChoices): gpt_4_vision = ( "gpt-4-vision-preview", "gpt-4-vision-preview", ) + gpt_4_turbo = ( + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + ) + + gpt_4_32k = ( + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + ) + gpt_3_5_turbo = ( + "gpt-3.5-turbo", + "gpt-3.5-turbo", + ) + gpt_3_5_turbo_16k = ( + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-16k", + ) + text_davinci_003 = ( + "text-davinci-003", + "text-davinci-003", + ) + text_davinci_002 = ( + "text-davinci-002", + "text-davinci-002", + ) + code_davinci_002 = ( + "code-davinci-002", + "code-davinci-002", + ) + text_curie_001 = ( + "text-curie-001", + "text-curie-001", + ) + text_babbage_001 = ( + "text-babbage-001", + "text-babbage-001", + ) + text_ada_001 = ( + "text-ada-001", + "text-ada-001", + ) + palm2_text = ( + "text-bison", + "text-bison", + ) + palm2_chat = ( + "chat-bison", + "chat-bison", + ) + llama2_70b_chat = ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ) class UsageCost(models.Model): @@ -43,14 +90,14 @@ class UsageCost(models.Model): class ProviderPricing(models.Model): - class Type(models.TextChoices): + class Group(models.TextChoices): # change to different name than type LLM = "LLM", "LLM" class Param(models.TextChoices): input = "Input", "input" output = "Output", "output" - type = models.TextField(choices=Type.choices) + type = models.TextField(choices=Group.choices) provider = models.TextField(choices=LLMApis.choices()) product = models.TextField(choices=Product.choices) param = models.TextField(choices=Param.choices) From 6866258f3e0a6c0217e66cb922bd6028c9b10610 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 25 Jan 2024 19:10:19 +0530 Subject: [PATCH 11/85] Add different meta title and h1 title for Examples tab --- daras_ai_v2/base.py | 4 ++- daras_ai_v2/breadcrumbs.py | 70 ++++++++++++++++++++++++------------- daras_ai_v2/meta_content.py | 38 +++++++++++--------- 3 files changed, 70 insertions(+), 42 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index a3d3fd50e..ac632d4ca 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -232,7 +232,9 @@ def _render_header(self): and published_run.is_root() and published_run.saved_run == current_run ) - tbreadcrumbs = get_title_breadcrumbs(self, current_run, published_run) + tbreadcrumbs = get_title_breadcrumbs( + self, current_run, published_run, tab=self.tab + ) with st.div(className="d-flex justify-content-between mt-4"): with st.div(className="d-lg-flex d-block align-items-center"): diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py index 61256d6d3..60a940654 100644 --- a/daras_ai_v2/breadcrumbs.py +++ b/daras_ai_v2/breadcrumbs.py @@ -6,6 +6,7 @@ PublishedRun, ) from daras_ai.image_input import truncate_text_words +from daras_ai_v2.tabs_widget import MenuTabs if typing.TYPE_CHECKING: from daras_ai_v2.base import BasePage @@ -17,6 +18,11 @@ class TitleUrl(typing.NamedTuple): class TitleBreadCrumbs(typing.NamedTuple): + """ + Breadcrumbs: root_title / published_title + Title: h1_title + """ + h1_title: str root_title: TitleUrl | None published_title: TitleUrl | None @@ -67,36 +73,50 @@ def get_title_breadcrumbs( page_cls: typing.Union["BasePage", typing.Type["BasePage"]], sr: SavedRun, pr: PublishedRun | None, + tab: str = MenuTabs.run, ) -> TitleBreadCrumbs: - if pr and sr == pr.saved_run and not pr.published_run_id: - # when published_run.published_run_id is blank, the run is the root example - return TitleBreadCrumbs(page_cls.get_recipe_title(), None, None) + is_root = pr and pr.saved_run == sr and pr.is_root() + is_example = not is_root and pr and pr.saved_run == sr + is_run = not is_root and not is_example - # the title on the saved root / the hardcoded title - recipe_title = page_cls.get_root_published_run().title or page_cls.title + recipe_title = page_cls.get_recipe_title() prompt_title = truncate_text_words( page_cls.preview_input(sr.to_dict()) or "", maxlen=60, ).replace("\n", " ") - root_title = TitleUrl(recipe_title, page_cls.app_url()) - - if pr and sr == pr.saved_run: - # published run root - return TitleBreadCrumbs( - pr.title or prompt_title or recipe_title, - root_title, - None, - ) - - if not pr or not pr.published_run_id: - # run created directly from recipe root - h1_title = prompt_title or f"Run: {recipe_title}" - return TitleBreadCrumbs(h1_title, root_title, None) - - # run created from a published run - h1_title = prompt_title or f"Run: {pr.title or recipe_title}" - published_title = TitleUrl( - pr.title or f"Fork {pr.published_run_id}", pr.get_app_url() + metadata = page_cls.workflow.get_or_create_metadata() + root_breadcrumb = TitleUrl( + metadata.short_title, + page_cls.app_url(), ) - return TitleBreadCrumbs(h1_title, root_title, published_title) + + match tab: + case MenuTabs.examples: + return TitleBreadCrumbs( + f"Examples: {metadata.short_title}", + root_title=root_breadcrumb, + published_title=None, + ) + case _ if is_root: + return TitleBreadCrumbs(page_cls.get_recipe_title(), None, None) + case _ if is_example: + assert pr is not None + return TitleBreadCrumbs( + pr.title or prompt_title or recipe_title, + root_title=root_breadcrumb, + published_title=None, + ) + case _ if is_run: + return TitleBreadCrumbs( + prompt_title or f"Run: {recipe_title}", + root_title=root_breadcrumb, + published_title=TitleUrl( + pr.title or f"Fork: {pr.published_run_id}", + pr.get_app_url(), + ) + if pr and not pr.is_root() + else None, + ) + case _: + raise AssertionError("Invalid tab or run") diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index 12076d26f..f0c0f77e4 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -109,23 +109,29 @@ def meta_title_for_page( sr: SavedRun, pr: PublishedRun | None, ) -> str: - tbreadcrumbs = get_title_breadcrumbs(page, sr, pr) - parts = [] - if tbreadcrumbs.published_title or tbreadcrumbs.root_title: - parts.append(tbreadcrumbs.h1_title) - # use the short title for non-root examples - part = metadata.short_title - if tbreadcrumbs.published_title: - part = f"{pr.title} {part}" - # add the creator's name - user = sr.get_creator() - if user and user.display_name: - part += f" by {user.display_name}" - parts.append(part) - else: - # for root recipe, a longer, SEO-friendly title - parts.append(metadata.meta_title) + + match page.tab: + case MenuTabs.examples: + parts.append(f"Examples: {metadata.meta_title}") + case _ if pr and pr.saved_run == sr and pr.is_root(): + # for root page + parts.append(metadata.meta_title) + case _: + # non-root runs and examples + tbreadcrumbs = get_title_breadcrumbs(page, sr, pr) + + parts.append(tbreadcrumbs.h1_title) + + # use the short title for non-root examples + part = metadata.short_title + if tbreadcrumbs.published_title: + part = f"{pr.title} {part}" + # add the creator's name + user = sr.get_creator() + if user and user.display_name: + part += f" by {user.display_name}" + parts.append(part) parts.append("Gooey.AI") return sep.join(parts) From 056891b316fad9404c70e8cbf9eb41eff4c6af6b Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 25 Jan 2024 20:21:05 +0530 Subject: [PATCH 12/85] refactor: meta_title_for_page function --- daras_ai_v2/meta_content.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index f0c0f77e4..871606639 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -28,6 +28,7 @@ def build_meta_tags( metadata=metadata, sr=sr, pr=pr, + tab=page.tab, ) description = meta_description_for_page( metadata=metadata, @@ -108,19 +109,21 @@ def meta_title_for_page( metadata: WorkflowMetadata, sr: SavedRun, pr: PublishedRun | None, + tab: str, ) -> str: - parts = [] + suffix = f" {sep} Gooey.AI" - match page.tab: + match tab: case MenuTabs.examples: - parts.append(f"Examples: {metadata.meta_title}") + return f"Examples: {metadata.meta_title}" + suffix case _ if pr and pr.saved_run == sr and pr.is_root(): # for root page - parts.append(metadata.meta_title) + return metadata.meta_title + suffix case _: # non-root runs and examples - tbreadcrumbs = get_title_breadcrumbs(page, sr, pr) + parts = [] + tbreadcrumbs = get_title_breadcrumbs(page, sr, pr) parts.append(tbreadcrumbs.h1_title) # use the short title for non-root examples @@ -133,8 +136,9 @@ def meta_title_for_page( part += f" by {user.display_name}" parts.append(part) - parts.append("Gooey.AI") - return sep.join(parts) + return sep.join(parts) + suffix + case _: + raise ValueError(f"Unknown tab: {tab}") def meta_description_for_page( From 26dd32e0232f76674fd2cc57b7f38aaa1cfbf8af Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 25 Jan 2024 20:22:20 +0530 Subject: [PATCH 13/85] Add h1 title and meta title for all tabs --- daras_ai_v2/breadcrumbs.py | 50 +++++++++++++++++++++++++++++-------- daras_ai_v2/meta_content.py | 16 ++++++++++-- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py index 60a940654..2623a823a 100644 --- a/daras_ai_v2/breadcrumbs.py +++ b/daras_ai_v2/breadcrumbs.py @@ -92,22 +92,16 @@ def get_title_breadcrumbs( ) match tab: - case MenuTabs.examples: - return TitleBreadCrumbs( - f"Examples: {metadata.short_title}", - root_title=root_breadcrumb, - published_title=None, - ) - case _ if is_root: + case MenuTabs.run if is_root: return TitleBreadCrumbs(page_cls.get_recipe_title(), None, None) - case _ if is_example: + case MenuTabs.run if is_example: assert pr is not None return TitleBreadCrumbs( pr.title or prompt_title or recipe_title, root_title=root_breadcrumb, published_title=None, ) - case _ if is_run: + case MenuTabs.run if is_run: return TitleBreadCrumbs( prompt_title or f"Run: {recipe_title}", root_title=root_breadcrumb, @@ -118,5 +112,41 @@ def get_title_breadcrumbs( if pr and not pr.is_root() else None, ) + case MenuTabs.examples: + return TitleBreadCrumbs( + f"Examples: {metadata.short_title}", + root_title=root_breadcrumb, + published_title=None, + ) + case MenuTabs.run_as_api: + tbreadcrumbs_on_run = get_title_breadcrumbs( + page_cls=page_cls, sr=sr, pr=pr, tab=MenuTabs.run + ) + return TitleBreadCrumbs( + f"API: {tbreadcrumbs_on_run.h1_title}", + root_title=tbreadcrumbs_on_run.root_title or root_breadcrumb, + published_title=tbreadcrumbs_on_run.published_title, + ) + case MenuTabs.integrations: + tbreadcrumbs_on_run = get_title_breadcrumbs( + page_cls=page_cls, sr=sr, pr=pr, tab=MenuTabs.run + ) + return TitleBreadCrumbs( + f"Integrations: {tbreadcrumbs_on_run.h1_title}", + root_title=tbreadcrumbs_on_run.root_title or root_breadcrumb, + published_title=tbreadcrumbs_on_run.published_title, + ) + case MenuTabs.history: + return TitleBreadCrumbs( + f"History: {metadata.short_title}", + root_title=root_breadcrumb, + published_title=None, + ) + case MenuTabs.saved: + return TitleBreadCrumbs( + f"Saved Runs: {metadata.short_title}", + root_title=root_breadcrumb, + published_title=None, + ) case _: - raise AssertionError("Invalid tab or run") + raise ValueError(f"Unknown tab: {tab}") diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index 871606639..ff946bb75 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -116,10 +116,22 @@ def meta_title_for_page( match tab: case MenuTabs.examples: return f"Examples: {metadata.meta_title}" + suffix - case _ if pr and pr.saved_run == sr and pr.is_root(): + case MenuTabs.run_as_api: + return "API: " + meta_title_for_page( + page=page, metadata=metadata, sr=sr, pr=pr, tab=MenuTabs.run + ) + case MenuTabs.integrations: + return "Integrations: " + meta_title_for_page( + page=page, metadata=metadata, sr=sr, pr=pr, tab=MenuTabs.run + ) + case MenuTabs.history: + return f"History for {metadata.short_title}" + suffix + case MenuTabs.saved: + return f"Saved Runs for {metadata.short_title}" + suffix + case MenuTabs.run if pr and pr.saved_run == sr and pr.is_root(): # for root page return metadata.meta_title + suffix - case _: + case MenuTabs.run: # non-root runs and examples parts = [] From ff3c86f98cd70f7dd44174cfb71013d0406014b5 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 25 Jan 2024 12:02:58 -0800 Subject: [PATCH 14/85] fixes keyerror with sorting options --- bots/models.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/bots/models.py b/bots/models.py index 3ad94cd31..30c294f3e 100644 --- a/bots/models.py +++ b/bots/models.py @@ -704,7 +704,26 @@ def to_df_format( "Bot": str(convo.bot_integration), } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Messages", + "Correct Answers", + "Thumbs up", + "Thumbs down", + "Last Sent", + "First Sent", + "A7", + "A30", + "R1", + "R7", + "R30", + "Delta Hours", + "Created At", + "Bot", + ], + ) return df @@ -900,7 +919,17 @@ def to_df_format( "Analysis JSON": message.analysis_result, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Role", + "Message (EN)", + "Sent", + "Feedback", + "Analysis JSON", + ], + ) return df def to_df_analysis_format( @@ -922,7 +951,10 @@ def to_df_analysis_format( "Analysis JSON": message.analysis_result, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=["Name", "Question (EN)", "Answer (EN)", "Sent", "Analysis JSON"], + ) return df def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]: @@ -1107,7 +1139,20 @@ def to_df_format( "Question Answered": feedback.message.question_answered, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Question (EN)", + "Question Sent", + "Answer (EN)", + "Answer Sent", + "Rating", + "Feedback (EN)", + "Feedback Sent", + "Question Answered", + ], + ) return df From 4733be17e0d0cd9620cc45a043bea471841699ac Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 25 Jan 2024 12:07:43 -0800 Subject: [PATCH 15/85] make 100% sure sortby is valid --- recipes/VideoBotsStats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c94b847f2..b57c9f2ac 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -667,7 +667,7 @@ def get_tabular_data( df["Run URL"] = run_url df["Bot"] = bi.name - if sort_by: + if sort_by and sort_by in df.columns: df.sort_values(by=[sort_by], ascending=False, inplace=True) return df From dcc057a0e545306a1afc5bc7089e0baf125f416f Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 25 Jan 2024 12:45:19 -0800 Subject: [PATCH 16/85] added regression test --- bots/models.py | 6 ++--- bots/tests.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/bots/models.py b/bots/models.py index 30c294f3e..457f3bb6a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -937,14 +937,14 @@ def to_df_analysis_format( ) -> "pd.DataFrame": import pandas as pd - qs = self.filter(role=CHATML_ROLE_USER).prefetch_related("feedbacks") + qs = self.filter(role=CHATML_ROLE_ASSISSTANT).prefetch_related("feedbacks") rows = [] for message in qs[:row_limit]: message: Message row = { "Name": message.conversation.get_display_name(), - "Question (EN)": message.content, - "Answer (EN)": message.get_next_by_created_at().content, + "Question (EN)": message.get_previous_by_created_at().content, + "Answer (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) .strftime("%b %d, %Y %I:%M %p"), diff --git a/bots/tests.py b/bots/tests.py index 21eef78f8..30105e648 100644 --- a/bots/tests.py +++ b/bots/tests.py @@ -92,3 +92,66 @@ def test_create_bot_integration_conversation_message(transactional_db): assert message_b.role == CHATML_ROLE_ASSISSTANT assert message_b.content == "Red, green, and yellow grow the best." assert message_b.display_content == "Red, green, and yellow grow the best." + + +def test_stats_get_tabular_data_invalid_sorting_options(transactional_db): + from recipes.VideoBotsStats import VideoBotsStatsPage + + page = VideoBotsStatsPage() + + # setup + run_url = "https://my_run_url" + bi = BotIntegration.objects.create( + name="My Bot Integration", + saved_run=None, + billing_account_uid="fdnacsFSBQNKVW8z6tzhBLHKpAm1", # digital green's account id + user_language="en", + show_feedback_buttons=True, + platform=Platform.WHATSAPP, + wa_phone_number="my_whatsapp_number", + wa_phone_number_id="my_whatsapp_number_id", + ) + convos = Conversation.objects.filter(bot_integration=bi) + msgs = Message.objects.filter(conversation__in=convos) + + # valid option but no data + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Name" + ) + assert df.shape[0] == 0 + assert "Name" in df.columns + + # valid option and data + convo = Conversation.objects.create( + bot_integration=bi, + state=ConvoState.INITIAL, + wa_phone_number="+919876543210", + ) + Message.objects.create( + conversation=convo, + role=CHATML_ROLE_USER, + content="What types of chilies can be grown in Mumbai?", + display_content="What types of chilies can be grown in Mumbai?", + ) + Message.objects.create( + conversation=convo, + role=CHATML_ROLE_ASSISSTANT, + content="Red, green, and yellow grow the best.", + display_content="Red, green, and yellow grow the best.", + analysis_result={"Answered": True}, + ) + convos = Conversation.objects.filter(bot_integration=bi) + msgs = Message.objects.filter(conversation__in=convos) + assert msgs.count() == 2 + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Name" + ) + assert df.shape[0] == 1 + assert "Name" in df.columns + + # invalid sort option should be ignored + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Invalid" + ) + assert df.shape[0] == 1 + assert "Name" in df.columns From 08ed3deda4ce9540c928c2bdf4a6df6d7efaba7d Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Fri, 26 Jan 2024 00:34:29 -0800 Subject: [PATCH 17/85] uploading google folders to the document uploader now expands them --- daras_ai_v2/doc_search_settings_widgets.py | 14 +++++++++- daras_ai_v2/gdrive_downloader.py | 30 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index b240b482f..39cc152db 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -75,7 +75,19 @@ def document_uploader( accept=accept, accept_multiple_files=accept_multiple_files, ) - return st.session_state.get(key, []) + documents = st.session_state.get(key, []) + for document in documents: + if not document.startswith("https://drive.google.com/drive/folders"): + continue + from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder + from furl import furl + + folder_content_urls = gdrive_list_urls_of_files_in_folder(furl(document)) + documents.remove(document) + documents.extend(folder_content_urls) + st.session_state[key] = documents + st.session_state[custom_key] = "\n".join(documents) + return documents def doc_search_settings( diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 89bcadb5b..c8d3bb7ba 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -25,6 +25,36 @@ def url_to_gdrive_file_id(f: furl) -> str: return file_id +def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: + # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) + folder_id = f.path.segments[-1] + # get metadata + service = discovery.build("drive", "v3") + # get files in drive directly + if f.host == "drive.google.com": + request = service.files().list( + supportsAllDrives=True, + includeItemsFromAllDrives=True, + q=f"'{folder_id}' in parents", + fields="files(mimeType,webViewLink)", + ) + # export google docs to appropriate type + else: + raise ValueError(f"Can't list files non google folder url: {str(f)!r}") + # download + response = request.execute() + files = response.get("files", []) + urls = [] + for file in files: + mime_type = file.get("mimeType") + url = file.get("webViewLink") + if mime_type == "application/vnd.google-apps.folder": + continue + elif url: + urls.append(url) + return urls + + def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]: # get drive file id file_id = url_to_gdrive_file_id(f) From 12a60c27ad7c661b5ad005980691fcd9d56396ee Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Fri, 26 Jan 2024 00:41:19 -0800 Subject: [PATCH 18/85] remove unessecary comments --- daras_ai_v2/gdrive_downloader.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index c8d3bb7ba..339a634a3 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -28,9 +28,7 @@ def url_to_gdrive_file_id(f: furl) -> str: def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] - # get metadata service = discovery.build("drive", "v3") - # get files in drive directly if f.host == "drive.google.com": request = service.files().list( supportsAllDrives=True, @@ -38,10 +36,8 @@ def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: q=f"'{folder_id}' in parents", fields="files(mimeType,webViewLink)", ) - # export google docs to appropriate type else: raise ValueError(f"Can't list files non google folder url: {str(f)!r}") - # download response = request.execute() files = response.get("files", []) urls = [] From dd525156920c7a4511207fb4a9fd6270be1b95a3 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Fri, 26 Jan 2024 01:03:52 -0800 Subject: [PATCH 19/85] linting fix --- bots/models.py | 8 +++-- daras_ai_v2/bot_integration_widgets.py | 6 ++-- daras_ai_v2/bots.py | 12 ++++--- daras_ai_v2/language_model.py | 28 +++++++++------ .../language_model_settings_widgets.py | 6 ++-- daras_ai_v2/stable_diffusion.py | 12 +++---- gooeysite/urls.py | 1 + pyproject.toml | 3 ++ recipes/CompareLLM.py | 6 ++-- recipes/CompareText2Img.py | 6 ++-- recipes/CompareUpscaler.py | 6 ++-- recipes/DocExtract.py | 6 ++-- recipes/DocSearch.py | 6 ++-- recipes/DocSummary.py | 6 ++-- recipes/GoogleGPT.py | 6 ++-- recipes/ImageSegmentation.py | 6 ++-- recipes/Img2Img.py | 8 +++-- recipes/QRCodeGenerator.py | 12 +++---- recipes/SEOSummary.py | 6 ++-- recipes/SmartGPT.py | 6 ++-- recipes/SocialLookupEmail.py | 6 ++-- recipes/Text2Audio.py | 12 +++---- recipes/TextToSpeech.py | 6 ++-- recipes/VideoBots.py | 34 ++++++++++--------- recipes/VideoBotsStats.py | 10 +++--- 25 files changed, 121 insertions(+), 103 deletions(-) diff --git a/bots/models.py b/bots/models.py index 3ad94cd31..c464156e4 100644 --- a/bots/models.py +++ b/bots/models.py @@ -894,9 +894,11 @@ def to_df_format( "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) .strftime("%b %d, %Y %I:%M %p"), - "Feedback": message.feedbacks.first().get_display_text() - if message.feedbacks.first() - else None, # only show first feedback as per Sean's request + "Feedback": ( + message.feedbacks.first().get_display_text() + if message.feedbacks.first() + else None + ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, } rows.append(row) diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 3880f9fe7..2686b2298 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -19,9 +19,9 @@ def general_integration_settings(bi: BotIntegration): st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( "user_language" ).default - st.session_state[ - f"_bi_show_feedback_buttons_{bi.id}" - ] = BotIntegration._meta.get_field("show_feedback_buttons").default + st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( + BotIntegration._meta.get_field("show_feedback_buttons").default + ) st.session_state[f"_bi_analysis_url_{bi.id}"] = None bi.show_feedback_buttons = st.checkbox( diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index dfc00c794..8cc1d69e1 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -389,11 +389,13 @@ def _save_msgs( role=CHATML_ROLE_USER, content=response.raw_input_text, display_content=input_text, - saved_run=SavedRun.objects.get_or_create( - workflow=Workflow.ASR, **furl(speech_run).query.params - )[0] - if speech_run - else None, + saved_run=( + SavedRun.objects.get_or_create( + workflow=Workflow.ASR, **furl(speech_run).query.params + )[0] + if speech_run + else None + ), ) attachments = [] for f_url in (input_images or []) + (input_documents or []): diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 80136cbe8..070aa1c22 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -198,12 +198,14 @@ def calc_gpt_tokens( for entry in messages if ( content := ( - format_chatml_message(entry) + "\n" - if is_chat_model - else entry.get("content", "") + ( + format_chatml_message(entry) + "\n" + if is_chat_model + else entry.get("content", "") + ) + if isinstance(entry, dict) + else str(entry) ) - if isinstance(entry, dict) - else str(entry) ) ) return default_length_function(combined) @@ -364,9 +366,11 @@ def run_language_model( else: out_content = [ # return messages back as either chatml or json messages - format_chatml_message(entry) - if is_chatml - else (entry.get("content") or "").strip() + ( + format_chatml_message(entry) + if is_chatml + else (entry.get("content") or "").strip() + ) for entry in result ] if tools: @@ -514,9 +518,11 @@ def _run_openai_chat( frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, tools=[tool.spec for tool in tools] if tools else NOT_GIVEN, - response_format={"type": response_format_type} - if response_format_type - else NOT_GIVEN, + response_format=( + {"type": response_format_type} + if response_format_type + else NOT_GIVEN + ), ) for model_str in model ], diff --git a/daras_ai_v2/language_model_settings_widgets.py b/daras_ai_v2/language_model_settings_widgets.py index e5ab27a59..80f785bd4 100644 --- a/daras_ai_v2/language_model_settings_widgets.py +++ b/daras_ai_v2/language_model_settings_widgets.py @@ -24,9 +24,9 @@ def language_model_settings(show_selector=True, show_document_model=False): f"###### {field_title_desc(VideoBotsPage.RequestModel, 'document_model')}", key="document_model", options=[None, *doc_model_descriptions], - format_func=lambda x: f"{doc_model_descriptions[x]} ({x})" - if x - else "———", + format_func=lambda x: ( + f"{doc_model_descriptions[x]} ({x})" if x else "———" + ), ) st.checkbox("Avoid Repetition", key="avoid_repetition") diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index 22c7adc8e..ce6333068 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -247,9 +247,9 @@ def instruct_pix2pix( }, inputs={ "prompt": [prompt] * len(images), - "negative_prompt": [negative_prompt] * len(images) - if negative_prompt - else None, + "negative_prompt": ( + [negative_prompt] * len(images) if negative_prompt else None + ), "num_images_per_prompt": num_outputs, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, @@ -440,9 +440,9 @@ def controlnet( pipeline={ "model_id": text2img_model_ids[Text2ImgModels[selected_model]], "seed": seed, - "scheduler": Schedulers[scheduler].label - if scheduler - else "UniPCMultistepScheduler", + "scheduler": ( + Schedulers[scheduler].label if scheduler else "UniPCMultistepScheduler" + ), "disable_safety_checker": True, "controlnet_model_id": [ controlnet_model_ids[ControlNetModels[model]] diff --git a/gooeysite/urls.py b/gooeysite/urls.py index 767660f21..6c3809436 100644 --- a/gooeysite/urls.py +++ b/gooeysite/urls.py @@ -14,6 +14,7 @@ 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.contrib import admin from django.urls import path diff --git a/pyproject.toml b/pyproject.toml index c6a699024..6726c1fce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,3 +93,6 @@ pre-commit = "^3.5.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.black] +--force-exclude = "migrations|node_modules|\\.git|\\.venv|\\.env|\\.pytest_cache|\\.vscode|\\.github|\\.to" diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 583317ddc..f1be5a937 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -36,9 +36,9 @@ class CompareLLMPage(BasePage): class RequestModel(BaseModel): input_prompt: str | None - selected_models: list[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in LargeLanguageModels)]] | None + ) avoid_repetition: bool | None num_outputs: int | None diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index e79ef5d54..5f44ca34b 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -64,9 +64,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2ImgModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2ImgModels)]] | None + ) scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None edit_instruction: str | None diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index f6f4e988b..e4ed8865f 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -24,9 +24,9 @@ class RequestModel(BaseModel): scale: int - selected_models: list[ - typing.Literal[tuple(e.name for e in UpscalerModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in UpscalerModels)]] | None + ) class ResponseModel(BaseModel): output_images: dict[typing.Literal[tuple(e.name for e in UpscalerModels)], str] diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 6c9489fff..2672fffdb 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -77,9 +77,9 @@ class RequestModel(BaseModel): task_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 2a509d8fd..d18b5a54c 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -60,9 +60,9 @@ class RequestModel(DocSearchRequest): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 4b9283cde..498f2d405 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -60,9 +60,9 @@ class RequestModel(BaseModel): task_instructions: str | None merge_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 57b1078ba..7a2d2591f 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -79,9 +79,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 543664065..751a0936b 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -49,9 +49,9 @@ class ImageSegmentationPage(BasePage): class RequestModel(BaseModel): input_image: str - selected_model: typing.Literal[ - tuple(e.name for e in ImageSegmentationModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in ImageSegmentationModels)] | None + ) mask_threshold: float | None rect_persepective_transform: bool | None diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index a97bc7c62..68e696c08 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -46,9 +46,11 @@ class RequestModel(BaseModel): text_prompt: str | None selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)] - ] | typing.Literal[tuple(e.name for e in ControlNetModels)] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)]] + | typing.Literal[tuple(e.name for e in ControlNetModels)] + | None + ) negative_prompt: str | None num_outputs: int | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 6b7260368..34b615225 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -88,18 +88,18 @@ class RequestModel(BaseModel): text_prompt: str negative_prompt: str | None image_prompt: str | None - image_prompt_controlnet_models: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + image_prompt_controlnet_models: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) image_prompt_strength: float | None image_prompt_scale: float | None image_prompt_pos_x: float | None image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) output_width: int | None output_height: int | None diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index a82424ae1..78328d822 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -98,9 +98,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): enable_html: bool | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None num_outputs: int | None diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 554a74940..0aabcfefb 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -34,9 +34,9 @@ class RequestModel(BaseModel): reflexion_prompt: str | None dera_prompt: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 082fdcdaa..1094cea53 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -40,9 +40,9 @@ class RequestModel(BaseModel): domain: str | None key_words: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 77776ddf5..18270cc8c 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -49,9 +49,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2AudioModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2AudioModels)]] | None + ) class ResponseModel(BaseModel): output_audios: dict[ @@ -114,9 +114,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ), inputs=dict( prompt=[request.text_prompt], - negative_prompt=[request.negative_prompt] - if request.negative_prompt - else None, + negative_prompt=( + [request.negative_prompt] if request.negative_prompt else None + ), num_waveforms_per_prompt=request.num_outputs, num_inference_steps=request.quality, guidance_scale=request.guidance_scale, diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index c37a8eb5f..721b2cb7a 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -53,9 +53,9 @@ class TextToSpeechPage(BasePage): class RequestModel(BaseModel): text_prompt: str - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 149931ffb..faa88d533 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -157,9 +157,9 @@ class RequestModel(BaseModel): messages: list[ConversationEntry] | None # tts settings - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None google_voice_name: str | None @@ -174,9 +174,9 @@ class RequestModel(BaseModel): elevenlabs_similarity_boost: float | None # llm settings - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) document_model: str | None = Field( title="🩻 Photo / Document Intelligence", description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " @@ -1028,10 +1028,12 @@ def messenger_bot_integration(self): favicon = Platform(bi.platform).get_favicon() with st.div(className="mt-2"): st.markdown( - f'  ' - f'{bi}' - if bi.saved_run - else f"{bi}", + ( + f'  ' + f'{bi}' + if bi.saved_run + else f"{bi}" + ), unsafe_allow_html=True, ) with col2: @@ -1082,9 +1084,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[ - f"_bi_slack_read_receipt_msg_{bi.id}" - ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default + st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + BotIntegration._meta.get_field("slack_read_receipt_msg").default + ) bi.slack_read_receipt_msg = st.text_input( """ @@ -1283,9 +1285,9 @@ def msg_container_widget(role: str): return st.div( className="px-3 py-1 pt-2", style=dict( - background="rgba(239, 239, 239, 0.6)" - if role == CHATML_ROLE_USER - else "#fff", + background=( + "rgba(239, 239, 239, 0.6)" if role == CHATML_ROLE_USER else "#fff" + ), ), ) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c94b847f2..8d58dc272 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -310,11 +310,11 @@ def parse_run_info(self, bi): run_title = ( bi.published_run.title if bi.published_run - else saved_run.page_title - if saved_run and saved_run.page_title - else "This Copilot Run" - if saved_run - else "No Run Connected" + else ( + saved_run.page_title + if saved_run and saved_run.page_title + else "This Copilot Run" if saved_run else "No Run Connected" + ) ) run_url = furl(saved_run.get_app_url()).tostr() if saved_run else "" return run_title, run_url From ffb0a619efb118c77335d1bfad6a99fc9e6a275b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 01:49:09 -0800 Subject: [PATCH 20/85] rename tables to remove "all" --- recipes/VideoBotsStats.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index b57c9f2ac..e625fc977 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -154,8 +154,8 @@ def render(self): details = st.horizontal_radio( "### Details", options=[ - "All Conversations", - "All Messages", + "Conversations", + "Messages", "Feedback Positive", "Feedback Negative", "Answered Successfully", @@ -164,7 +164,7 @@ def render(self): key="details", ) - if details == "All Conversations": + if details == "Conversations": options = [ "Messages", "Correct Answers", @@ -310,11 +310,11 @@ def parse_run_info(self, bi): run_title = ( bi.published_run.title if bi.published_run - else saved_run.page_title - if saved_run and saved_run.page_title - else "This Copilot Run" - if saved_run - else "No Run Connected" + else ( + saved_run.page_title + if saved_run and saved_run.page_title + else "This Copilot Run" if saved_run else "No Run Connected" + ) ) run_url = furl(saved_run.get_app_url()).tostr() if saved_run else "" return run_title, run_url @@ -626,9 +626,9 @@ def get_tabular_data( self, bi, run_url, conversations, messages, details, sort_by, rows=10000 ): df = pd.DataFrame() - if details == "All Conversations": + if details == "Conversations": df = conversations.to_df_format(row_limit=rows) - elif details == "All Messages": + elif details == "Messages": df = messages.order_by("-created_at", "conversation__id").to_df_format( row_limit=rows ) From d0b28f25283c590d5ea6d4c1c27c6ab94871b15b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 03:42:21 -0800 Subject: [PATCH 21/85] average response time --- recipes/VideoBotsStats.py | 45 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index e625fc977..4939944cd 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -31,7 +31,8 @@ TruncYear, Concat, ) -from django.db.models import Count +from django.db.models import Count, F, Window +from django.db.models.functions import Lag ID_COLUMNS = [ "conversation__fb_page_id", @@ -388,6 +389,33 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): def calculate_stats_binned_by_time( self, bi, start_date, end_date, factor, trunc_fn ): + average_response_time = ( + Message.objects.filter( + created_at__date__gte=start_date, + created_at__date__lte=end_date, + conversation__bot_integration=bi, + ) + .values("conversation_id") + .order_by("created_at") + .annotate( + response_time=F("created_at") - Window(expression=Lag("created_at")), + ) + .annotate(date=trunc_fn("created_at")) + .values("date", "response_time", "role") + ) + average_response_time = ( + pd.DataFrame( + average_response_time, + columns=["date", "response_time", "role"], + ) + .loc[lambda df: df["role"] == CHATML_ROLE_ASSISTANT] + .groupby("date") + .agg({"response_time": "median"}) + .apply(lambda x: x.clip(lower=timedelta(0))) + .rename(columns={"response_time": "Average_response_time"}) + .reset_index() + ) + messages_received = ( Message.objects.filter( created_at__date__gte=start_date, @@ -468,6 +496,12 @@ def calculate_stats_binned_by_time( left_on="date", right_on="date", ) + df = df.merge( + average_response_time, + how="outer", + left_on="date", + right_on="date", + ) df["Messages_Sent"] = df["Messages_Sent"] * factor df["Convos"] = df["Convos"] * factor df["Senders"] = df["Senders"] * factor @@ -476,6 +510,7 @@ def calculate_stats_binned_by_time( df["Neg_feedback"] = df["Neg_feedback"] * factor df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] + df["Average_response_time"] = df["Average_response_time"] * factor df.fillna(0, inplace=True) df = df.round(0).astype("int32", errors="ignore") return df @@ -576,6 +611,14 @@ def plot_graphs(self, view, df): text=list(df["Msgs_per_user"]), hovertemplate="Messages per User: %{y:.0f}", ), + go.Scatter( + name="Average Response Time", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Average_response_time"]), + text=list(df["Average_response_time"]), + hovertemplate="Average Response Time: %{y:.0f}", + ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), From b359f65adf93f73ac359c7e4683ccea1a36426f8 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 03:42:26 -0800 Subject: [PATCH 22/85] linting --- bots/models.py | 12 +++++-- daras_ai_v2/bot_integration_widgets.py | 6 ++-- daras_ai_v2/bots.py | 12 ++++--- daras_ai_v2/language_model.py | 28 +++++++++------ .../language_model_settings_widgets.py | 6 ++-- daras_ai_v2/stable_diffusion.py | 12 +++---- gooeysite/urls.py | 1 + pyproject.toml | 3 ++ recipes/CompareLLM.py | 6 ++-- recipes/CompareText2Img.py | 6 ++-- recipes/CompareUpscaler.py | 6 ++-- recipes/DocExtract.py | 6 ++-- recipes/DocSearch.py | 6 ++-- recipes/DocSummary.py | 6 ++-- recipes/GoogleGPT.py | 6 ++-- recipes/ImageSegmentation.py | 6 ++-- recipes/Img2Img.py | 8 +++-- recipes/QRCodeGenerator.py | 12 +++---- recipes/SEOSummary.py | 6 ++-- recipes/SmartGPT.py | 6 ++-- recipes/SocialLookupEmail.py | 6 ++-- recipes/Text2Audio.py | 12 +++---- recipes/TextToSpeech.py | 6 ++-- recipes/VideoBots.py | 34 ++++++++++--------- 24 files changed, 120 insertions(+), 98 deletions(-) diff --git a/bots/models.py b/bots/models.py index 457f3bb6a..5148a531e 100644 --- a/bots/models.py +++ b/bots/models.py @@ -913,9 +913,11 @@ def to_df_format( "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) .strftime("%b %d, %Y %I:%M %p"), - "Feedback": message.feedbacks.first().get_display_text() - if message.feedbacks.first() - else None, # only show first feedback as per Sean's request + "Feedback": ( + message.feedbacks.first().get_display_text() + if message.feedbacks.first() + else None + ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, } rows.append(row) @@ -1051,6 +1053,10 @@ def __str__(self): def local_lang(self): return Truncator(self.display_content).words(30) + @property + def response_time(self): + return self.created_at - self.get_previous_by_created_at().created_at + class MessageAttachment(models.Model): message = models.ForeignKey( diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 3880f9fe7..2686b2298 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -19,9 +19,9 @@ def general_integration_settings(bi: BotIntegration): st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( "user_language" ).default - st.session_state[ - f"_bi_show_feedback_buttons_{bi.id}" - ] = BotIntegration._meta.get_field("show_feedback_buttons").default + st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( + BotIntegration._meta.get_field("show_feedback_buttons").default + ) st.session_state[f"_bi_analysis_url_{bi.id}"] = None bi.show_feedback_buttons = st.checkbox( diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index dfc00c794..8cc1d69e1 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -389,11 +389,13 @@ def _save_msgs( role=CHATML_ROLE_USER, content=response.raw_input_text, display_content=input_text, - saved_run=SavedRun.objects.get_or_create( - workflow=Workflow.ASR, **furl(speech_run).query.params - )[0] - if speech_run - else None, + saved_run=( + SavedRun.objects.get_or_create( + workflow=Workflow.ASR, **furl(speech_run).query.params + )[0] + if speech_run + else None + ), ) attachments = [] for f_url in (input_images or []) + (input_documents or []): diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 80136cbe8..070aa1c22 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -198,12 +198,14 @@ def calc_gpt_tokens( for entry in messages if ( content := ( - format_chatml_message(entry) + "\n" - if is_chat_model - else entry.get("content", "") + ( + format_chatml_message(entry) + "\n" + if is_chat_model + else entry.get("content", "") + ) + if isinstance(entry, dict) + else str(entry) ) - if isinstance(entry, dict) - else str(entry) ) ) return default_length_function(combined) @@ -364,9 +366,11 @@ def run_language_model( else: out_content = [ # return messages back as either chatml or json messages - format_chatml_message(entry) - if is_chatml - else (entry.get("content") or "").strip() + ( + format_chatml_message(entry) + if is_chatml + else (entry.get("content") or "").strip() + ) for entry in result ] if tools: @@ -514,9 +518,11 @@ def _run_openai_chat( frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, tools=[tool.spec for tool in tools] if tools else NOT_GIVEN, - response_format={"type": response_format_type} - if response_format_type - else NOT_GIVEN, + response_format=( + {"type": response_format_type} + if response_format_type + else NOT_GIVEN + ), ) for model_str in model ], diff --git a/daras_ai_v2/language_model_settings_widgets.py b/daras_ai_v2/language_model_settings_widgets.py index e5ab27a59..80f785bd4 100644 --- a/daras_ai_v2/language_model_settings_widgets.py +++ b/daras_ai_v2/language_model_settings_widgets.py @@ -24,9 +24,9 @@ def language_model_settings(show_selector=True, show_document_model=False): f"###### {field_title_desc(VideoBotsPage.RequestModel, 'document_model')}", key="document_model", options=[None, *doc_model_descriptions], - format_func=lambda x: f"{doc_model_descriptions[x]} ({x})" - if x - else "———", + format_func=lambda x: ( + f"{doc_model_descriptions[x]} ({x})" if x else "———" + ), ) st.checkbox("Avoid Repetition", key="avoid_repetition") diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index 22c7adc8e..ce6333068 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -247,9 +247,9 @@ def instruct_pix2pix( }, inputs={ "prompt": [prompt] * len(images), - "negative_prompt": [negative_prompt] * len(images) - if negative_prompt - else None, + "negative_prompt": ( + [negative_prompt] * len(images) if negative_prompt else None + ), "num_images_per_prompt": num_outputs, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, @@ -440,9 +440,9 @@ def controlnet( pipeline={ "model_id": text2img_model_ids[Text2ImgModels[selected_model]], "seed": seed, - "scheduler": Schedulers[scheduler].label - if scheduler - else "UniPCMultistepScheduler", + "scheduler": ( + Schedulers[scheduler].label if scheduler else "UniPCMultistepScheduler" + ), "disable_safety_checker": True, "controlnet_model_id": [ controlnet_model_ids[ControlNetModels[model]] diff --git a/gooeysite/urls.py b/gooeysite/urls.py index 767660f21..6c3809436 100644 --- a/gooeysite/urls.py +++ b/gooeysite/urls.py @@ -14,6 +14,7 @@ 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.contrib import admin from django.urls import path diff --git a/pyproject.toml b/pyproject.toml index c6a699024..01ccdd04e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,3 +93,6 @@ pre-commit = "^3.5.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.black] +force-exclude = "migrations" diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 583317ddc..f1be5a937 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -36,9 +36,9 @@ class CompareLLMPage(BasePage): class RequestModel(BaseModel): input_prompt: str | None - selected_models: list[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in LargeLanguageModels)]] | None + ) avoid_repetition: bool | None num_outputs: int | None diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index e79ef5d54..5f44ca34b 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -64,9 +64,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2ImgModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2ImgModels)]] | None + ) scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None edit_instruction: str | None diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index f6f4e988b..e4ed8865f 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -24,9 +24,9 @@ class RequestModel(BaseModel): scale: int - selected_models: list[ - typing.Literal[tuple(e.name for e in UpscalerModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in UpscalerModels)]] | None + ) class ResponseModel(BaseModel): output_images: dict[typing.Literal[tuple(e.name for e in UpscalerModels)], str] diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 6c9489fff..2672fffdb 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -77,9 +77,9 @@ class RequestModel(BaseModel): task_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 2a509d8fd..d18b5a54c 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -60,9 +60,9 @@ class RequestModel(DocSearchRequest): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 4b9283cde..498f2d405 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -60,9 +60,9 @@ class RequestModel(BaseModel): task_instructions: str | None merge_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 57b1078ba..7a2d2591f 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -79,9 +79,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 543664065..751a0936b 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -49,9 +49,9 @@ class ImageSegmentationPage(BasePage): class RequestModel(BaseModel): input_image: str - selected_model: typing.Literal[ - tuple(e.name for e in ImageSegmentationModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in ImageSegmentationModels)] | None + ) mask_threshold: float | None rect_persepective_transform: bool | None diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index a97bc7c62..68e696c08 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -46,9 +46,11 @@ class RequestModel(BaseModel): text_prompt: str | None selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)] - ] | typing.Literal[tuple(e.name for e in ControlNetModels)] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)]] + | typing.Literal[tuple(e.name for e in ControlNetModels)] + | None + ) negative_prompt: str | None num_outputs: int | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 6b7260368..34b615225 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -88,18 +88,18 @@ class RequestModel(BaseModel): text_prompt: str negative_prompt: str | None image_prompt: str | None - image_prompt_controlnet_models: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + image_prompt_controlnet_models: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) image_prompt_strength: float | None image_prompt_scale: float | None image_prompt_pos_x: float | None image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) output_width: int | None output_height: int | None diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index a82424ae1..78328d822 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -98,9 +98,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): enable_html: bool | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None num_outputs: int | None diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 554a74940..0aabcfefb 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -34,9 +34,9 @@ class RequestModel(BaseModel): reflexion_prompt: str | None dera_prompt: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 082fdcdaa..1094cea53 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -40,9 +40,9 @@ class RequestModel(BaseModel): domain: str | None key_words: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 77776ddf5..18270cc8c 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -49,9 +49,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2AudioModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2AudioModels)]] | None + ) class ResponseModel(BaseModel): output_audios: dict[ @@ -114,9 +114,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ), inputs=dict( prompt=[request.text_prompt], - negative_prompt=[request.negative_prompt] - if request.negative_prompt - else None, + negative_prompt=( + [request.negative_prompt] if request.negative_prompt else None + ), num_waveforms_per_prompt=request.num_outputs, num_inference_steps=request.quality, guidance_scale=request.guidance_scale, diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index c37a8eb5f..721b2cb7a 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -53,9 +53,9 @@ class TextToSpeechPage(BasePage): class RequestModel(BaseModel): text_prompt: str - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 149931ffb..faa88d533 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -157,9 +157,9 @@ class RequestModel(BaseModel): messages: list[ConversationEntry] | None # tts settings - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None google_voice_name: str | None @@ -174,9 +174,9 @@ class RequestModel(BaseModel): elevenlabs_similarity_boost: float | None # llm settings - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) document_model: str | None = Field( title="🩻 Photo / Document Intelligence", description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " @@ -1028,10 +1028,12 @@ def messenger_bot_integration(self): favicon = Platform(bi.platform).get_favicon() with st.div(className="mt-2"): st.markdown( - f'  ' - f'{bi}' - if bi.saved_run - else f"{bi}", + ( + f'  ' + f'{bi}' + if bi.saved_run + else f"{bi}" + ), unsafe_allow_html=True, ) with col2: @@ -1082,9 +1084,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[ - f"_bi_slack_read_receipt_msg_{bi.id}" - ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default + st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + BotIntegration._meta.get_field("slack_read_receipt_msg").default + ) bi.slack_read_receipt_msg = st.text_input( """ @@ -1283,9 +1285,9 @@ def msg_container_widget(role: str): return st.div( className="px-3 py-1 pt-2", style=dict( - background="rgba(239, 239, 239, 0.6)" - if role == CHATML_ROLE_USER - else "#fff", + background=( + "rgba(239, 239, 239, 0.6)" if role == CHATML_ROLE_USER else "#fff" + ), ), ) From e00411135e7aeaaf5b9897c3ef1e614a1b174dd4 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 03:51:42 -0800 Subject: [PATCH 23/85] add response time to table --- bots/models.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/bots/models.py b/bots/models.py index 5148a531e..61c56abc3 100644 --- a/bots/models.py +++ b/bots/models.py @@ -919,6 +919,7 @@ def to_df_format( else None ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, + "Response Time": message.response_time.total_seconds(), } rows.append(row) df = pd.DataFrame.from_records( @@ -930,6 +931,7 @@ def to_df_format( "Sent", "Feedback", "Analysis JSON", + "Response Time", ], ) return df @@ -1055,7 +1057,20 @@ def local_lang(self): @property def response_time(self): - return self.created_at - self.get_previous_by_created_at().created_at + import pandas as pd + + if self.role == CHATML_ROLE_USER: + return pd.NaT + return ( + self.created_at + - Message.objects.filter( + conversation=self.conversation, + role=CHATML_ROLE_USER, + created_at__lt=self.created_at, + ) + .latest() + .created_at + ) class MessageAttachment(models.Model): From c4603417fa3f79b8c771894ecc19a598c5bf3364 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 29 Jan 2024 19:27:35 +0530 Subject: [PATCH 24/85] cache bm25 embeds for copilot search --- daras_ai_v2/functional.py | 20 ++--- daras_ai_v2/vector_search.py | 139 +++++++++++++++++++++-------------- gooey_ui/pubsub.py | 6 +- 3 files changed, 98 insertions(+), 67 deletions(-) diff --git a/daras_ai_v2/functional.py b/daras_ai_v2/functional.py index 625e9c026..f8415ec87 100644 --- a/daras_ai_v2/functional.py +++ b/daras_ai_v2/functional.py @@ -8,8 +8,8 @@ def flatapply_parallel( - fn: typing.Callable[[T], list[R]], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., list[R]], + *iterables, max_workers: int = None, message: str = "", ) -> typing.Generator[str, None, list[R]]: @@ -20,8 +20,8 @@ def flatapply_parallel( def apply_parallel( - fn: typing.Callable[[T], R], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., R], + *iterables, max_workers: int = None, message: str = "", ) -> typing.Generator[str, None, list[R]]: @@ -42,8 +42,8 @@ def apply_parallel( def fetch_parallel( - fn: typing.Callable[[T], R], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., R], + *iterables, max_workers: int = None, ) -> typing.Generator[R, None, None]: assert iterables, "fetch_parallel() requires at least one iterable" @@ -57,16 +57,16 @@ def fetch_parallel( def flatmap_parallel( - fn: typing.Callable[[T], list[R]], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., list[R]], + *iterables, max_workers: int = None, ) -> list[R]: return flatten(map_parallel(fn, *iterables, max_workers=max_workers)) def map_parallel( - fn: typing.Callable[[T], R], - *iterables: typing.Sequence[T], + fn: typing.Callable[..., R], + *iterables, max_workers: int = None, ) -> list[R]: assert iterables, "map_parallel() requires at least one iterable" diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 5cf08c721..f3da60727 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -32,7 +32,7 @@ ) from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS -from daras_ai_v2.functional import flatmap_parallel +from daras_ai_v2.functional import flatmap_parallel, map_parallel from daras_ai_v2.gdrive_downloader import ( gdrive_download, is_gdrive_url, @@ -87,20 +87,23 @@ def get_top_k_references( Returns: the top k documents """ - yield "Getting embeddings..." + yield "Checking docs..." input_docs = request.documents or [] + doc_metas = map_parallel(doc_url_to_metadata, input_docs) + + yield "Getting embeddings..." embeds: list[tuple[SearchReference, np.ndarray]] = flatmap_parallel( - lambda f_url: doc_url_to_embeds( + lambda f_url, doc_meta: get_embeds_for_doc( f_url=f_url, + doc_meta=doc_meta, max_context_words=request.max_context_words, scroll_jump=request.scroll_jump, selected_asr_model=request.selected_asr_model, google_translate_target=request.google_translate_target, ), input_docs, - max_workers=10, + doc_metas, ) - dense_query_embeds = openai_embedding_create([request.search_query])[0] yield "Searching documents..." @@ -129,12 +132,21 @@ def get_top_k_references( dense_ranks = np.zeros(len(embeds)) if sparse_weight: + yield "Getting sparse scores..." # get sparse scores - tokenized_corpus = [ - bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"]) - for ref, _ in embeds - ] - bm25 = BM25Okapi(tokenized_corpus, k1=2, b=0.3) + bm25_corpus = flatmap_parallel( + lambda f_url, doc_meta: get_bm25_embeds_for_doc( + f_url=f_url, + doc_meta=doc_meta, + max_context_words=request.max_context_words, + scroll_jump=request.scroll_jump, + selected_asr_model=request.selected_asr_model, + google_translate_target=request.google_translate_target, + ), + input_docs, + doc_metas, + ) + bm25 = BM25Okapi(bm25_corpus, k1=2, b=0.3) if request.keyword_query and isinstance(request.keyword_query, list): sparse_query_tokenized = [item.lower() for item in request.keyword_query] else: @@ -211,38 +223,6 @@ def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str: ) -def doc_url_to_embeds( - *, - f_url: str, - max_context_words: int, - scroll_jump: int, - selected_asr_model: str = None, - google_translate_target: str = None, -) -> list[tuple[SearchReference, np.ndarray]]: - """ - Get document embeddings for a given document url. - - Args: - f_url: document url - max_context_words: max number of words to include in each chunk - scroll_jump: number of words to scroll by - google_translate_target: target language for google translate - selected_asr_model: selected ASR model (used for audio files) - - Returns: - list of (SearchReference, embeddings vector) tuples - """ - doc_meta = doc_url_to_metadata(f_url) - return get_embeds_for_doc( - f_url=f_url, - doc_meta=doc_meta, - max_context_words=max_context_words, - scroll_jump=scroll_jump, - selected_asr_model=selected_asr_model, - google_translate_target=google_translate_target, - ) - - class DocMetadata(typing.NamedTuple): name: str etag: str | None @@ -321,6 +301,35 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: ) +@redis_cache_decorator +def get_bm25_embeds_for_doc( + *, + f_url: str, + doc_meta: DocMetadata, + max_context_words: int, + scroll_jump: int, + google_translate_target: str = None, + selected_asr_model: str = None, +): + pages = doc_url_to_text_pages( + f_url=f_url, + doc_meta=doc_meta, + selected_asr_model=selected_asr_model, + google_translate_target=google_translate_target, + ) + refs = pages_to_split_refs( + pages=pages, + f_url=f_url, + doc_meta=doc_meta, + max_context_words=max_context_words, + scroll_jump=scroll_jump, + ) + tokenized_corpus = [ + bm25_tokenizer(ref["title"]) + bm25_tokenizer(ref["snippet"]) for ref in refs + ] + return tokenized_corpus + + @redis_cache_decorator def get_embeds_for_doc( *, @@ -345,18 +354,44 @@ def get_embeds_for_doc( Returns: list of (metadata, embeddings) tuples """ - import pandas as pd - pages = doc_url_to_text_pages( f_url=f_url, doc_meta=doc_meta, selected_asr_model=selected_asr_model, google_translate_target=google_translate_target, ) + refs = pages_to_split_refs( + pages=pages, + f_url=f_url, + doc_meta=doc_meta, + max_context_words=max_context_words, + scroll_jump=scroll_jump, + ) + texts = [m["title"] + " | " + m["snippet"] for m in refs] + # get doc embeds in batches + batch_size = 16 # azure openai limits + embeds = flatmap_parallel( + openai_embedding_create, + [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)], + max_workers=2, + ) + return list(zip(refs, embeds)) + + +def pages_to_split_refs( + *, + pages, + f_url: str, + doc_meta: DocMetadata, + max_context_words: int, + scroll_jump: int, +) -> list[SearchReference]: + import pandas as pd + chunk_size = int(max_context_words * 2) chunk_overlap = int(max_context_words * 2 / scroll_jump) if isinstance(pages, pd.DataFrame): - metas = [] + refs = [] # treat each row as a separate document for idx, row in pages.iterrows(): row = dict(row) @@ -372,7 +407,7 @@ def get_embeds_for_doc( ) else: continue - metas += [ + refs += [ { "title": doc_meta.name, "url": f_url, @@ -385,7 +420,7 @@ def get_embeds_for_doc( ] else: # split the text into chunks - metas = [ + refs = [ { "title": ( doc_meta.name + (f", page {doc.end + 1}" if len(pages) > 1 else "") @@ -403,15 +438,7 @@ def get_embeds_for_doc( pages, chunk_size=chunk_size, chunk_overlap=chunk_overlap ) ] - # get doc embeds in batches - batch_size = 16 # azure openai limits - texts = [m["title"] + " | " + m["snippet"] for m in metas] - embeds = flatmap_parallel( - openai_embedding_create, - [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)], - max_workers=5, - ) - return list(zip(metas, embeds)) + return refs sections_re = re.compile(r"(\s*[\r\n\f\v]|^)(\w+)\=", re.MULTILINE) diff --git a/gooey_ui/pubsub.py b/gooey_ui/pubsub.py index ca3dfcb1a..5e210c956 100644 --- a/gooey_ui/pubsub.py +++ b/gooey_ui/pubsub.py @@ -42,7 +42,11 @@ def realtime_push(channel: str, value: typing.Any = "ping"): msg = json.dumps(jsonable_encoder(value)) r.set(channel, msg) r.publish(channel, json.dumps(time())) - logger.info(f"publish {channel=}") + if isinstance(value, dict): + run_status = value.get("__run_status") + logger.info(f"publish {channel=} {run_status=}") + else: + logger.info(f"publish {channel=}") # def use_state( From dab3f576e036e5a6ffe4b2df888b72d5496e453e Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 29 Jan 2024 19:53:14 +0530 Subject: [PATCH 25/85] handle case where sparse and dense ranks are of different length --- daras_ai_v2/vector_search.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index f3da60727..394d55b55 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -12,6 +12,7 @@ import requests from furl import furl from googleapiclient.errors import HttpError +from loguru import logger from pydantic import BaseModel, Field from rank_bm25 import BM25Okapi @@ -162,6 +163,13 @@ def get_top_k_references( else: sparse_ranks = np.zeros(len(embeds)) + # just in case sparse and dense ranks are different lengths, truncate to the shorter one + if len(sparse_ranks) != len(dense_ranks): + logger.warning( + f"sparse and dense ranks are different lengths, truncating... {len(sparse_ranks)=} {len(dense_ranks)=} {len(embeds)=}" + ) + sparse_ranks = sparse_ranks[: len(dense_ranks)] + dense_ranks = dense_ranks[: len(sparse_ranks)] # RRF formula: 1 / (k + rank) k = 60 rrf_scores = ( @@ -171,11 +179,7 @@ def get_top_k_references( # Final ranking max_references = min(request.max_references, len(rrf_scores)) top_k = np.argpartition(rrf_scores, -max_references)[-max_references:] - final_ranks = sorted( - top_k, - key=lambda idx: rrf_scores[idx], - reverse=True, - ) + final_ranks = sorted(top_k, key=rrf_scores.__getitem__, reverse=True) references = [embeds[idx][0] | {"score": rrf_scores[idx]} for idx in final_ranks] From 950595b7b2ce858e12f0dbe4b9a3e3feba1281c6 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 29 Jan 2024 20:38:26 +0530 Subject: [PATCH 26/85] disable click analytics because iplist.cc is down --- url_shortener/routers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/url_shortener/routers.py b/url_shortener/routers.py index fba39ffa9..1b3991cf8 100644 --- a/url_shortener/routers.py +++ b/url_shortener/routers.py @@ -21,10 +21,11 @@ def url_shortener(hashid: str, request: Request): return Response(status_code=410, content="This link has expired") # increment the click count ShortenedURL.objects.filter(id=surl.id).update(clicks=F("clicks") + 1) - if surl.enable_analytics: - save_click_info.delay( - surl.id, request.client.host, request.headers.get("user-agent", "") - ) + # disable because iplist.cc is down + # if surl.enable_analytics: + # save_click_info.delay( + # surl.id, request.client.host, request.headers.get("user-agent", "") + # ) if surl.url: return RedirectResponse( url=surl.url, status_code=303 # because youtu.be redirects are 303 From 075e35d469f4859d23c2d9d84cd7d8dd0244f6ac Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 29 Jan 2024 20:39:10 +0530 Subject: [PATCH 27/85] search published run by title and notes --- bots/admin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bots/admin.py b/bots/admin.py index 94eb48997..e659865a8 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -223,7 +223,7 @@ class PublishedRunAdmin(admin.ModelAdmin): "updated_at", ] list_filter = ["workflow", "visibility", "created_by__is_paying"] - search_fields = ["workflow", "published_run_id"] + search_fields = ["workflow", "published_run_id", "title", "notes"] autocomplete_fields = ["saved_run", "created_by", "last_edited_by"] readonly_fields = [ "open_in_gooey", From 8b63ed73155d541f6489c6f1c34ca73d3a86fd03 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 25 Jan 2024 17:38:08 -0800 Subject: [PATCH 28/85] show the created at field on the edit view of visitor infos --- url_shortener/admin.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/url_shortener/admin.py b/url_shortener/admin.py index f4d6bb4c0..e7f102a7a 100644 --- a/url_shortener/admin.py +++ b/url_shortener/admin.py @@ -137,3 +137,6 @@ class VisitorClickInfoAdmin(admin.ModelAdmin): ordering = ["-created_at"] autocomplete_fields = ["shortened_url"] actions = [export_to_csv, export_to_excel] + readonly_fields = [ + "created_at", + ] From 38badb99343156997dcff01725fd2e6fde16df38 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 10:36:33 -0800 Subject: [PATCH 29/85] recursive folder expansion --- daras_ai_v2/gdrive_downloader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 339a634a3..ee1f7453c 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -25,7 +25,7 @@ def url_to_gdrive_file_id(f: furl) -> str: return file_id -def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: +def gdrive_list_urls_of_files_in_folder(f: furl, max_depth=10) -> list[str]: # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] service = discovery.build("drive", "v3") @@ -37,7 +37,7 @@ def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: fields="files(mimeType,webViewLink)", ) else: - raise ValueError(f"Can't list files non google folder url: {str(f)!r}") + raise ValueError(f"Can't list files from non google folder url: {str(f)!r}") response = request.execute() files = response.get("files", []) urls = [] @@ -45,7 +45,9 @@ def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: mime_type = file.get("mimeType") url = file.get("webViewLink") if mime_type == "application/vnd.google-apps.folder": - continue + urls += gdrive_list_urls_of_files_in_folder( + furl(url), max_depth=max_depth - 1 + ) elif url: urls.append(url) return urls From 1c8179cc09247a3437b692127f9103766d5d1b26 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 10:37:10 -0800 Subject: [PATCH 30/85] stop recursing at max depth lol --- daras_ai_v2/gdrive_downloader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index ee1f7453c..345099aba 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -26,6 +26,8 @@ def url_to_gdrive_file_id(f: furl) -> str: def gdrive_list_urls_of_files_in_folder(f: furl, max_depth=10) -> list[str]: + if max_depth <= 0: + return [] # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] service = discovery.build("drive", "v3") From d79ab396300c596efbfaca82a93a146497c54548 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 31 Jan 2024 13:20:51 +0530 Subject: [PATCH 31/85] support workflow enum value in urls for metabase --- daras_ai_v2/all_pages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daras_ai_v2/all_pages.py b/daras_ai_v2/all_pages.py index f4a317416..096150662 100644 --- a/daras_ai_v2/all_pages.py +++ b/daras_ai_v2/all_pages.py @@ -108,7 +108,7 @@ def normalize_slug(page_slug): normalize_slug(slug): page for page in (all_api_pages + all_hidden_pages) for slug in page.slug_versions -} +} | {str(page.workflow.value): page for page in (all_api_pages + all_hidden_pages)} workflow_map: dict[Workflow, typing.Type[BasePage]] = { page.workflow: page for page in all_api_pages From 691bce583a425aaa701682ea60a1427ab87b8f93 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 30 Jan 2024 22:00:17 +0530 Subject: [PATCH 32/85] compare llm support for streaming --- daras_ai_v2/language_model.py | 151 ++++++++++++++++++++++++++++------ recipes/CompareLLM.py | 13 +-- 2 files changed, 134 insertions(+), 30 deletions(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 80136cbe8..451d7b1d4 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -18,7 +18,11 @@ from django.conf import settings from jinja2.lexer import whitespace_re from loguru import logger -from openai.types.chat import ChatCompletionContentPartParam +from openai import Stream +from openai.types.chat import ( + ChatCompletionContentPartParam, + ChatCompletionChunk, +) from daras_ai_v2.asr import get_google_auth_session from daras_ai_v2.exceptions import raise_for_status @@ -27,7 +31,10 @@ from daras_ai_v2.redis_cache import ( get_redis_cache, ) -from daras_ai_v2.text_splitter import default_length_function +from daras_ai_v2.text_splitter import ( + default_length_function, + default_separators, +) DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible." @@ -325,8 +332,13 @@ def run_language_model( stop: list[str] = None, avoid_repetition: bool = False, tools: list[LLMTools] = None, + stream: bool = False, response_format_type: typing.Literal["text", "json_object"] = None, -) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]: +) -> ( + list[str] + | tuple[list[str], list[list[dict]]] + | typing.Generator[list[str], None, None] +): assert bool(prompt) != bool( messages ), "Pleave provide exactly one of { prompt, messages }" @@ -334,10 +346,9 @@ def run_language_model( model: LargeLanguageModels = LargeLanguageModels[str(model)] api = llm_api[model] model_name = llm_model_names[model] + is_chatml = False if model.is_chat_model(): - if messages: - is_chatml = False - else: + if not messages: # if input is chatml, convert it into json messages is_chatml, messages = parse_chatml(prompt) # type: ignore messages = messages or [] @@ -347,7 +358,7 @@ def run_language_model( format_chat_entry(role=entry["role"], content=get_entry_text(entry)) for entry in messages ] - result = _run_chat_model( + entries = _run_chat_model( api=api, model=model_name, messages=messages, # type: ignore @@ -358,26 +369,18 @@ def run_language_model( avoid_repetition=avoid_repetition, tools=tools, response_format_type=response_format_type, + # we can't stream with tools or json yet + stream=stream and not (tools or response_format_type), ) - if response_format_type == "json_object": - out_content = [json.loads(entry["content"]) for entry in result] + if stream: + return _stream_llm_outputs(entries, is_chatml, response_format_type, tools) else: - out_content = [ - # return messages back as either chatml or json messages - format_chatml_message(entry) - if is_chatml - else (entry.get("content") or "").strip() - for entry in result - ] - if tools: - return out_content, [(entry.get("tool_calls") or []) for entry in result] - else: - return out_content + return _parse_entries(entries, is_chatml, response_format_type, tools) else: if tools: raise ValueError("Only OpenAI chat models support Tools") logger.info(f"{model_name=}, {len(prompt)=}, {max_tokens=}, {temperature=}") - result = _run_text_model( + msgs = _run_text_model( api=api, model=model_name, prompt=prompt, @@ -388,7 +391,41 @@ def run_language_model( avoid_repetition=avoid_repetition, quality=quality, ) - return [msg.strip() for msg in result] + ret = [msg.strip() for msg in msgs] + if stream: + ret = [ret] + return ret + + +def _stream_llm_outputs(result, is_chatml, response_format_type, tools): + if isinstance(result, list): # compatibility with non-streaming apis + result = [result] + for entries in result: + yield _parse_entries(entries, is_chatml, response_format_type, tools) + + +def _parse_entries( + entries: list[dict], + is_chatml: bool, + response_format_type: typing.Literal["text", "json_object"] | None, + tools: list[dict] | None, +): + if response_format_type == "json_object": + ret = [json.loads(entry["content"]) for entry in entries] + else: + ret = [ + # return messages back as either chatml or json messages + ( + format_chatml_message(entry) + if is_chatml + else (entry.get("content") or "").strip() + ) + for entry in entries + ] + if tools: + return ret, [(entry.get("tool_calls") or []) for entry in entries] + else: + return ret def _run_text_model( @@ -439,7 +476,8 @@ def _run_chat_model( avoid_repetition: bool, tools: list[LLMTools] | None, response_format_type: typing.Literal["text", "json_object"] | None, -) -> list[ConversationEntry]: + stream: bool = False, +) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]: match api: case LLMApis.openai: return _run_openai_chat( @@ -452,6 +490,7 @@ def _run_chat_model( temperature=temperature, tools=tools, response_format_type=response_format_type, + stream=stream, ) case LLMApis.vertex_ai: if tools: @@ -490,7 +529,8 @@ def _run_openai_chat( avoid_repetition: bool, tools: list[LLMTools] | None, response_format_type: typing.Literal["text", "json_object"] | None, -) -> list[ConversationEntry]: + stream: bool = False, +) -> list[ConversationEntry] | typing.Generator[list[ConversationEntry], None, None]: from openai._types import NOT_GIVEN if avoid_repetition: @@ -517,11 +557,72 @@ def _run_openai_chat( response_format={"type": response_format_type} if response_format_type else NOT_GIVEN, + stream=stream, ) for model_str in model ], ) - return [choice.message.dict() for choice in r.choices] + if stream: + return _stream_openai_chunked(r) + else: + return [choice.message.dict() for choice in r.choices] + + +def _stream_openai_chunked( + r: Stream[ChatCompletionChunk], + start_chunk_size: int = 50, + stop_chunk_size: int = 400, + step_chunk_size: int = 150, +): + ret = [] + chunk_size = start_chunk_size + + for completion_chunk in r: + changed = False + for choice in completion_chunk.choices: + try: + entry = ret[choice.index] + except IndexError: + # initialize the entry + entry = choice.delta.dict() | {"content": "", "chunk": ""} + ret.append(entry) + + # append the delta to the current chunk + if not choice.delta.content: + continue + entry["chunk"] += choice.delta.content + # if the chunk is too small, we need to wait for more data + chunk = entry["chunk"] + if len(chunk) < chunk_size: + continue + + # iterate through the separators and find the best one that matches + for sep in default_separators[:-1]: + # find the last occurrence of the separator + match = None + for match in re.finditer(sep, chunk): + pass + if not match: + continue # no match, try the next separator or wait for more data + # append text before the separator to the content + part = chunk[: match.end()] + if len(part) < chunk_size: + continue # not enough text, try the next separator or wait for more data + entry["content"] += part + # set text after the separator as the next chunk + entry["chunk"] = chunk[match.end() :] + # increase the chunk size, but don't go over the max + chunk_size = min(chunk_size + step_chunk_size, stop_chunk_size) + # we found a separator, so we can stop looking and yield the partial result + changed = True + break + if changed: + yield ret + + # add the leftover chunks + for entry in ret: + entry["content"] += entry["chunk"] + yield ret @retry_if(openai_should_retry) diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 583317ddc..7b309866f 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -1,10 +1,9 @@ import random import typing - -import gooey_ui as st from pydantic import BaseModel +import gooey_ui as st from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_multiselect @@ -94,9 +93,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["output_text"] = output_text = {} for selected_model in request.selected_models: - yield f"Running {LargeLanguageModels[selected_model].value}..." - - output_text[selected_model] = run_language_model( + model = LargeLanguageModels[selected_model] + yield f"Running {model.value}..." + ret = run_language_model( model=selected_model, quality=request.quality, num_outputs=request.num_outputs, @@ -104,7 +103,11 @@ def run(self, state: dict) -> typing.Iterator[str | None]: prompt=prompt, max_tokens=request.max_tokens, avoid_repetition=request.avoid_repetition, + stream=True, ) + for i, item in enumerate(ret): + output_text[selected_model] = item + yield f"Streaming {model.value}... {i + 1}" def render_output(self): self._render_outputs(st.session_state, 450) From 2cd686a0f78d2d1416d5eaff9e8571d93d3c0d99 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 31 Jan 2024 19:30:22 +0530 Subject: [PATCH 33/85] streaming progres in superscript --- daras_ai_v2/language_model.py | 3 +++ recipes/CompareLLM.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 451d7b1d4..99452a193 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -45,6 +45,9 @@ CHATML_ROLE_ASSISTANT = "assistant" CHATML_ROLE_USER = "user" +# nice for showing streaming progress +SUPERSCRIPT = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹") + class LLMApis(Enum): vertex_ai = "Vertex AI" diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 7b309866f..c07bcffa6 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -11,6 +11,7 @@ run_language_model, LargeLanguageModels, llm_price, + SUPERSCRIPT, ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.loom_video_widget import youtube_video @@ -107,7 +108,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) for i, item in enumerate(ret): output_text[selected_model] = item - yield f"Streaming {model.value}... {i + 1}" + yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." def render_output(self): self._render_outputs(st.session_state, 450) From 975dae303516cac0600a7aef225e8ce63bc1518c Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 31 Jan 2024 20:16:54 +0530 Subject: [PATCH 34/85] streaming support for videobots --- daras_ai_v2/search_ref.py | 272 +++++++++++++++++++++----------------- recipes/DocSearch.py | 8 +- recipes/VideoBots.py | 83 +++++++----- 3 files changed, 205 insertions(+), 158 deletions(-) diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py index 1baf64b47..7cc810b1a 100644 --- a/daras_ai_v2/search_ref.py +++ b/daras_ai_v2/search_ref.py @@ -57,150 +57,178 @@ def render_text_with_refs(text: str, references: list[SearchReference]): return html -def apply_response_template( +def apply_response_formattings_prefix( output_text: list[str], references: list[SearchReference], citation_style: CitationStyles | None = CitationStyles.number, -): +) -> list[dict[int, SearchReference]]: + all_refs_list = [{}] * len(output_text) for i, text in enumerate(output_text): - formatted = "" - all_refs = {} - - for snippet, ref_map in parse_refs(text, references): - match citation_style: - case CitationStyles.number | CitationStyles.number_plaintext: - cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys()) - case CitationStyles.title: - cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) - case CitationStyles.url: - cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) - case CitationStyles.symbol | CitationStyles.symbol_plaintext: - cites = " ".join( - f"[{generate_footnote_symbol(ref_num - 1)}]" - for ref_num in ref_map.keys() - ) + all_refs_list[i], output_text[i] = format_citations( + text, references, citation_style + ) + return all_refs_list - case CitationStyles.markdown: - cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) - case CitationStyles.html: - cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) - case CitationStyles.slack_mrkdwn: - cites = " ".join( - ref_to_slack_mrkdwn(ref) for ref in ref_map.values() - ) - case CitationStyles.plaintext: - cites = " ".join( - f'[{ref["title"]} {ref["url"]}]' - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_markdown: - cites = " ".join( - markdown_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_html: - cites = " ".join( - html_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_slack_mrkdwn: - cites = " ".join( - slack_mrkdwn_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) +def apply_response_formattings_suffix( + all_refs_list: list[dict[int, SearchReference]], + output_text: list[str], + citation_style: CitationStyles | None = CitationStyles.number, +): + for i, text in enumerate(output_text): + output_text[i] = format_jinja_response_template( + all_refs_list[i], + format_footnotes(all_refs_list[i], text, citation_style), + ) - case CitationStyles.symbol_markdown: - cites = " ".join( - markdown_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.symbol_html: - cites = " ".join( - html_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.symbol_slack_mrkdwn: - cites = " ".join( - slack_mrkdwn_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case None: - cites = "" - case _: - raise ValueError(f"Unknown citation style: {citation_style}") - formatted += snippet + " " + cites + " " - all_refs.update(ref_map) +def format_citations( + text: str, + references: list[SearchReference], + citation_style: CitationStyles | None = CitationStyles.number, +) -> tuple[dict[int, SearchReference], str]: + all_refs = {} + formatted = "" + for snippet, ref_map in parse_refs(text, references): match citation_style: + case CitationStyles.number | CitationStyles.number_plaintext: + cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys()) + case CitationStyles.title: + cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) + case CitationStyles.url: + cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) + case CitationStyles.symbol | CitationStyles.symbol_plaintext: + cites = " ".join( + f"[{generate_footnote_symbol(ref_num - 1)}]" + for ref_num in ref_map.keys() + ) + + case CitationStyles.markdown: + cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) + case CitationStyles.html: + cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) + case CitationStyles.slack_mrkdwn: + cites = " ".join(ref_to_slack_mrkdwn(ref) for ref in ref_map.values()) + case CitationStyles.plaintext: + cites = " ".join( + f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items() + ) + case CitationStyles.number_markdown: - formatted += "\n\n" - formatted += "\n".join( - f"[{ref_num}] {ref_to_markdown(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + markdown_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.number_html: - formatted += "

" - formatted += "
".join( - f"[{ref_num}] {ref_to_html(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + html_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.number_slack_mrkdwn: - formatted += "\n\n" - formatted += "\n".join( - f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}" - for ref_num, ref in sorted(all_refs.items()) - ) - case CitationStyles.number_plaintext: - formatted += "\n\n" - formatted += "\n".join( - f'{ref_num}. {ref["title"]} {ref["url"]}' - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + slack_mrkdwn_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_markdown: - formatted += "\n\n" - formatted += "\n".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + markdown_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_html: - formatted += "

" - formatted += "
".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + html_link(f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_slack_mrkdwn: - formatted += "\n\n" - formatted += "\n".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}" - for ref_num, ref in sorted(all_refs.items()) - ) - case CitationStyles.symbol_plaintext: - formatted += "\n\n" - formatted += "\n".join( - f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}' - for ref_num, ref in sorted(all_refs.items()) - ) - - for ref_num, ref in all_refs.items(): - try: - template = ref["response_template"] - except KeyError: - pass - else: - formatted = jinja2.Template(template).render( - **ref, - output_text=formatted, - ref_num=ref_num, + cites = " ".join( + slack_mrkdwn_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() ) - output_text[i] = formatted + case None: + cites = "" + case _: + raise ValueError(f"Unknown citation style: {citation_style}") + formatted += " ".join(filter(None, [snippet, cites])) + all_refs.update(ref_map) + return all_refs, formatted + + +def format_footnotes( + all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles +) -> str: + match citation_style: + case CitationStyles.number_markdown: + formatted += "\n\n" + formatted += "\n".join( + f"[{ref_num}] {ref_to_markdown(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_html: + formatted += "

" + formatted += "
".join( + f"[{ref_num}] {ref_to_html(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_slack_mrkdwn: + formatted += "\n\n" + formatted += "\n".join( + f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_plaintext: + formatted += "\n\n" + formatted += "\n".join( + f'{ref_num}. {ref["title"]} {ref["url"]}' + for ref_num, ref in sorted(all_refs.items()) + ) + + case CitationStyles.symbol_markdown: + formatted += "\n\n" + formatted += "\n".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_html: + formatted += "

" + formatted += "
".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_slack_mrkdwn: + formatted += "\n\n" + formatted += "\n".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_plaintext: + formatted += "\n\n" + formatted += "\n".join( + f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}' + for ref_num, ref in sorted(all_refs.items()) + ) + return formatted + + +def format_jinja_response_template( + all_refs: dict[int, SearchReference], formatted: str +) -> str: + for ref_num, ref in all_refs.items(): + try: + template = ref["response_template"] + except KeyError: + pass + else: + formatted = jinja2.Template(template).render( + **ref, + output_text=formatted, + ref_num=ref_num, + ) + return formatted search_ref_pat = re.compile(r"\[" r"[\d\s\.\,\[\]\$\{\}]+" r"\]") diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 2a509d8fd..15cb6b063 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -23,8 +23,9 @@ from daras_ai_v2.search_ref import ( SearchReference, render_output_with_refs, - apply_response_template, CitationStyles, + apply_response_formattings_prefix, + apply_response_formattings_suffix, ) from daras_ai_v2.vector_search import ( DocSearchRequest, @@ -194,9 +195,12 @@ def run_v2( citation_style = ( request.citation_style and CitationStyles[request.citation_style] ) or None - apply_response_template( + all_refs_list = apply_response_formattings_prefix( response.output_text, response.references, citation_style ) + apply_response_formattings_suffix( + all_refs_list, response.output_text, citation_style + ) def get_raw_price(self, state: dict) -> float: name = state.get("selected_model") diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 149931ffb..89ea87973 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -50,6 +50,7 @@ get_entry_images, get_entry_text, format_chat_entry, + SUPERSCRIPT, ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.lipsync_settings_widgets import lipsync_settings @@ -58,7 +59,12 @@ from daras_ai_v2.query_generator import generate_final_search_query from daras_ai_v2.query_params import gooey_get_query_params from daras_ai_v2.query_params_util import extract_query_params -from daras_ai_v2.search_ref import apply_response_template, parse_refs, CitationStyles +from daras_ai_v2.search_ref import ( + parse_refs, + CitationStyles, + apply_response_formattings_prefix, + apply_response_formattings_suffix, +) from daras_ai_v2.text_output_widget import text_output from daras_ai_v2.text_to_speech_settings_widgets import ( TextToSpeechProviders, @@ -805,7 +811,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield f"Running {model.value}..." if is_chat_model: - output_text = run_language_model( + chunks = run_language_model( model=request.selected_model, messages=[ {"role": s["role"], "content": s["content"]} @@ -816,12 +822,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]: temperature=request.sampling_temperature, avoid_repetition=request.avoid_repetition, tools=request.tools, + stream=True, ) else: prompt = "\n".join( format_chatml_message(entry) for entry in prompt_messages ) - output_text = run_language_model( + chunks = run_language_model( model=request.selected_model, prompt=prompt, max_tokens=max_allowed_tokens, @@ -830,43 +837,51 @@ def run(self, state: dict) -> typing.Iterator[str | None]: temperature=request.sampling_temperature, avoid_repetition=request.avoid_repetition, stop=[CHATML_END_TOKEN, CHATML_START_TOKEN], + stream=True, ) - if request.tools: - output_text, tool_call_choices = output_text - state["output_documents"] = output_documents = [] - for tool_calls in tool_call_choices: - for call in tool_calls: - result = yield from exec_tool_call(call) - output_documents.append(result) - - # save model response - state["raw_output_text"] = [ - "".join(snippet for snippet, _ in parse_refs(text, references)) - for text in output_text - ] - - # translate response text - if request.user_language and request.user_language != "en": - yield f"Translating response to {request.user_language}..." - output_text = run_google_translate( - texts=output_text, - source_language="en", - target_language=request.user_language, - glossary_url=request.output_glossary_document, - ) - state["raw_tts_text"] = [ + citation_style = ( + request.citation_style and CitationStyles[request.citation_style] + ) or None + all_refs_list = [] + for i, output_text in enumerate(chunks): + if request.tools: + output_text, tool_call_choices = output_text + state["output_documents"] = output_documents = [] + for tool_calls in tool_call_choices: + for call in tool_calls: + result = yield from exec_tool_call(call) + output_documents.append(result) + + # save model response + state["raw_output_text"] = [ "".join(snippet for snippet, _ in parse_refs(text, references)) for text in output_text ] - if references: - citation_style = ( - request.citation_style and CitationStyles[request.citation_style] - ) or None - apply_response_template(output_text, references, citation_style) - - state["output_text"] = output_text + # translate response text + if request.user_language and request.user_language != "en": + yield f"Translating response to {request.user_language}..." + output_text = run_google_translate( + texts=output_text, + source_language="en", + target_language=request.user_language, + glossary_url=request.output_glossary_document, + ) + state["raw_tts_text"] = [ + "".join(snippet for snippet, _ in parse_refs(text, references)) + for text in output_text + ] + + if references: + all_refs_list = apply_response_formattings_prefix( + output_text, references, citation_style + ) + state["output_text"] = output_text + yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." + apply_response_formattings_suffix( + all_refs_list, state["output_text"], citation_style + ) state["output_audio"] = [] state["output_video"] = [] From 06c9d1daa591fca21f37ec787472a6da7fd5d6e4 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 31 Jan 2024 21:57:25 +0530 Subject: [PATCH 35/85] admin niceties for saved runs --- bots/admin.py | 31 ++++++++++++++++++++++++------- bots/models.py | 10 +++++++++- daras_ai_v2/base.py | 6 ++---- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index e659865a8..49321ba35 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -219,6 +219,7 @@ class PublishedRunAdmin(admin.ModelAdmin): "view_user", "open_in_gooey", "linked_saved_run", + "view_runs", "created_at", "updated_at", ] @@ -227,6 +228,7 @@ class PublishedRunAdmin(admin.ModelAdmin): autocomplete_fields = ["saved_run", "created_by", "last_edited_by"] readonly_fields = [ "open_in_gooey", + "view_runs", "created_at", "updated_at", ] @@ -243,19 +245,28 @@ def linked_saved_run(self, published_run: PublishedRun): linked_saved_run.short_description = "Linked Run" + @admin.display(description="View Runs") + def view_runs(self, published_run: PublishedRun): + return list_related_html_url( + SavedRun.objects.filter(parent_version__published_run=published_run), + query_param="parent_version__published_run__id__exact", + instance_id=published_run.id, + show_add=False, + ) + @admin.register(SavedRun) class SavedRunAdmin(admin.ModelAdmin): list_display = [ "__str__", - "example_id", "run_id", "view_user", - "created_at", + "open_in_gooey", + "view_parent_published_run", "run_time", - "updated_at", "price", - "preview_input", + "created_at", + "updated_at", ] list_filter = ["workflow"] search_fields = ["workflow", "example_id", "run_id", "uid"] @@ -278,6 +289,11 @@ class SavedRunAdmin(admin.ModelAdmin): django.db.models.JSONField: {"widget": JSONEditorWidget}, } + def lookup_allowed(self, key, value): + if key in ["parent_version__published_run__id__exact"]: + return True + return super().lookup_allowed(key, value) + def view_user(self, saved_run: SavedRun): return change_obj_url( AppUser.objects.get(uid=saved_run.uid), @@ -291,9 +307,10 @@ def view_bots(self, saved_run: SavedRun): view_bots.short_description = "View Bots" - @admin.display(description="Input") - def preview_input(self, saved_run: SavedRun): - return truncate_text_words(BasePage.preview_input(saved_run.state) or "", 100) + @admin.display(description="View Published Run") + def view_parent_published_run(self, saved_run: SavedRun): + pr = saved_run.parent_published_run() + return pr and change_obj_url(pr) @admin.register(PublishedRunVersion) diff --git a/bots/models.py b/bots/models.py index 3ad94cd31..fd6071f58 100644 --- a/bots/models.py +++ b/bots/models.py @@ -280,7 +280,15 @@ class Meta: ] def __str__(self): - return self.get_app_url() + from daras_ai_v2.breadcrumbs import get_title_breadcrumbs + + title = get_title_breadcrumbs( + Workflow(self.workflow).page_cls, self, self.parent_published_run() + ).h1_title + return title or self.get_app_url() + + def parent_published_run(self) -> "PublishedRun": + return self.parent_version and self.parent_version.published_run def get_app_url(self): workflow = Workflow(self.workflow) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 519b70a28..7187178b7 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -923,7 +923,7 @@ def get_runs_from_query_params( ) -> tuple[SavedRun, PublishedRun | None]: if run_id and uid: sr = cls.run_doc_sr(run_id, uid) - pr = (sr and sr.parent_version and sr.parent_version.published_run) or None + pr = sr.parent_published_run() else: pr = cls.get_published_run(published_run_id=example_id or "") sr = pr.saved_run @@ -940,9 +940,7 @@ def get_pr_from_query_params( ) -> PublishedRun | None: if run_id and uid: sr = cls.get_sr_from_query_params(example_id, run_id, uid) - return ( - sr and sr.parent_version and sr.parent_version.published_run - ) or None + return sr.parent_published_run() elif example_id: return cls.get_published_run(published_run_id=example_id) else: From 24cf0324308049b4a01e5048c44cec84646e8943 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 31 Jan 2024 22:07:01 +0530 Subject: [PATCH 36/85] fix IndexError list index out of range --- recipes/VideoBots.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 89ea87973..44994b7ff 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -879,9 +879,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["output_text"] = output_text yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." - apply_response_formattings_suffix( - all_refs_list, state["output_text"], citation_style - ) + if all_refs_list: + apply_response_formattings_suffix( + all_refs_list, state["output_text"], citation_style + ) state["output_audio"] = [] state["output_video"] = [] From 91e7810a8803c500a152d694c5687ff4d37f8e34 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 12:07:40 -0800 Subject: [PATCH 37/85] SavedRun run_time attempt --- bots/models.py | 29 +++++++++++++++-------------- recipes/VideoBotsStats.py | 39 ++++----------------------------------- 2 files changed, 19 insertions(+), 49 deletions(-) diff --git a/bots/models.py b/bots/models.py index 0714f187b..204a625bd 100644 --- a/bots/models.py +++ b/bots/models.py @@ -1065,20 +1065,21 @@ def local_lang(self): @property def response_time(self): - import pandas as pd - - if self.role == CHATML_ROLE_USER: - return pd.NaT - return ( - self.created_at - - Message.objects.filter( - conversation=self.conversation, - role=CHATML_ROLE_USER, - created_at__lt=self.created_at, - ) - .latest() - .created_at - ) + return self.saved_run.run_time + # import pandas as pd + + # if self.role == CHATML_ROLE_USER: + # return pd.NaT + # return ( + # self.created_at + # - Message.objects.filter( + # conversation=self.conversation, + # role=CHATML_ROLE_USER, + # created_at__lt=self.created_at, + # ) + # .latest() + # .created_at + # ) class MessageAttachment(models.Model): diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 4939944cd..c82989c02 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -31,8 +31,7 @@ TruncYear, Concat, ) -from django.db.models import Count, F, Window -from django.db.models.functions import Lag +from django.db.models import Count, Avg ID_COLUMNS = [ "conversation__fb_page_id", @@ -389,33 +388,6 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): def calculate_stats_binned_by_time( self, bi, start_date, end_date, factor, trunc_fn ): - average_response_time = ( - Message.objects.filter( - created_at__date__gte=start_date, - created_at__date__lte=end_date, - conversation__bot_integration=bi, - ) - .values("conversation_id") - .order_by("created_at") - .annotate( - response_time=F("created_at") - Window(expression=Lag("created_at")), - ) - .annotate(date=trunc_fn("created_at")) - .values("date", "response_time", "role") - ) - average_response_time = ( - pd.DataFrame( - average_response_time, - columns=["date", "response_time", "role"], - ) - .loc[lambda df: df["role"] == CHATML_ROLE_ASSISTANT] - .groupby("date") - .agg({"response_time": "median"}) - .apply(lambda x: x.clip(lower=timedelta(0))) - .rename(columns={"response_time": "Average_response_time"}) - .reset_index() - ) - messages_received = ( Message.objects.filter( created_at__date__gte=start_date, @@ -436,6 +408,7 @@ def calculate_stats_binned_by_time( distinct=True, ) ) + .annotate(Average_response_time=Avg("saved_run__run_time")) .annotate(Unique_feedback_givers=Count("feedbacks", distinct=True)) .values( "date", @@ -443,6 +416,7 @@ def calculate_stats_binned_by_time( "Convos", "Senders", "Unique_feedback_givers", + "Average_response_time", ) ) @@ -482,6 +456,7 @@ def calculate_stats_binned_by_time( "Convos", "Senders", "Unique_feedback_givers", + "Average_response_time", ], ) df = df.merge( @@ -496,12 +471,6 @@ def calculate_stats_binned_by_time( left_on="date", right_on="date", ) - df = df.merge( - average_response_time, - how="outer", - left_on="date", - right_on="date", - ) df["Messages_Sent"] = df["Messages_Sent"] * factor df["Convos"] = df["Convos"] * factor df["Senders"] = df["Senders"] * factor From 344f6ba493e79a5ee999572186e9f87a12d4c0fc Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 13:01:09 -0800 Subject: [PATCH 38/85] response time collection --- bots/migrations/0056_message_response_time.py | 19 +++++++++++++++ bots/models.py | 23 ++++--------------- daras_ai_v2/bots.py | 12 ++++++++++ recipes/VideoBotsStats.py | 2 +- 4 files changed, 37 insertions(+), 19 deletions(-) create mode 100644 bots/migrations/0056_message_response_time.py diff --git a/bots/migrations/0056_message_response_time.py b/bots/migrations/0056_message_response_time.py new file mode 100644 index 000000000..06730e931 --- /dev/null +++ b/bots/migrations/0056_message_response_time.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.7 on 2024-02-01 20:15 + +import datetime +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0055_workflowmetadata'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='response_time', + field=models.DurationField(default=datetime.timedelta(days=-1, seconds=86399), help_text='The time it took for the bot to respond to the corresponding user message'), + ), + ] diff --git a/bots/models.py b/bots/models.py index 204a625bd..22008019f 100644 --- a/bots/models.py +++ b/bots/models.py @@ -1048,6 +1048,11 @@ class Message(models.Model): help_text="Subject of given question (DEPRECATED)", ) + response_time = models.DurationField( + default=datetime.timedelta(seconds=-1), + help_text="The time it took for the bot to respond to the corresponding user message", + ) + _analysis_started = False objects = MessageQuerySet.as_manager() @@ -1063,24 +1068,6 @@ def __str__(self): def local_lang(self): return Truncator(self.display_content).words(30) - @property - def response_time(self): - return self.saved_run.run_time - # import pandas as pd - - # if self.role == CHATML_ROLE_USER: - # return pd.NaT - # return ( - # self.created_at - # - Message.objects.filter( - # conversation=self.conversation, - # role=CHATML_ROLE_USER, - # created_at__lt=self.created_at, - # ) - # .latest() - # .created_at - # ) - class MessageAttachment(models.Model): message = models.ForeignKey( diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 8cc1d69e1..324affbb7 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -3,11 +3,14 @@ import typing from urllib.parse import parse_qs +import pytz +from datetime import datetime from django.db import transaction from fastapi import HTTPException, Request from furl import furl from sentry_sdk import capture_exception +from daras_ai_v2 import settings from app_users.models import AppUser from bots.models import ( Platform, @@ -202,6 +205,7 @@ def _on_msg(bot: BotInterface): speech_run = None input_images = None input_documents = None + recieved_time: datetime = datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) if not bot.page_cls: bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR) return @@ -286,6 +290,7 @@ def _on_msg(bot: BotInterface): input_documents=input_documents, input_text=input_text, speech_run=speech_run, + recieved_time=recieved_time, ) @@ -324,6 +329,7 @@ def _process_and_send_msg( input_images: list[str] | None, input_documents: list[str] | None, input_text: str, + recieved_time: datetime, speech_run: str | None, ): try: @@ -369,6 +375,7 @@ def _process_and_send_msg( platform_msg_id=msg_id, response=response, url=url, + received_time=recieved_time, ) @@ -381,6 +388,7 @@ def _save_msgs( platform_msg_id: str | None, response: VideoBotsPage.ResponseModel, url: str, + received_time: datetime, ): # create messages for future context user_msg = Message( @@ -396,6 +404,8 @@ def _save_msgs( if speech_run else None ), + response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) + - received_time, ) attachments = [] for f_url in (input_images or []) + (input_documents or []): @@ -412,6 +422,8 @@ def _save_msgs( saved_run=SavedRun.objects.get_or_create( workflow=Workflow.VIDEO_BOTS, **furl(url).query.params )[0], + response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) + - received_time, ) # save the messages & attachments with transaction.atomic(): diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c82989c02..a265f5a92 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -408,7 +408,7 @@ def calculate_stats_binned_by_time( distinct=True, ) ) - .annotate(Average_response_time=Avg("saved_run__run_time")) + .annotate(Average_response_time=Avg("response_time")) .annotate(Unique_feedback_givers=Count("feedbacks", distinct=True)) .values( "date", From fb9d32cebe0bf3d441acc966fdcd99088e108320 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 13:26:34 -0800 Subject: [PATCH 39/85] enable show all --- recipes/VideoBotsStats.py | 51 ++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index a265f5a92..12eabf51f 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -136,7 +136,7 @@ def render(self): view, factor, trunc_fn, - ) = self.render_date_view_inputs() + ) = self.render_date_view_inputs(bi) df = self.calculate_stats_binned_by_time( bi, start_date, end_date, factor, trunc_fn @@ -260,27 +260,34 @@ def render(self): ).tostr() st.change_url(new_url, self.request) - def render_date_view_inputs(self): - start_of_year_date = datetime.now().replace(month=1, day=1) - st.session_state.setdefault( - "start_date", - self.request.query_params.get( - "start_date", start_of_year_date.strftime("%Y-%m-%d") - ), - ) - start_date: datetime = ( - st.date_input("Start date", key="start_date") or start_of_year_date - ) - st.session_state.setdefault( - "end_date", - self.request.query_params.get( - "end_date", datetime.now().strftime("%Y-%m-%d") - ), - ) - end_date: datetime = st.date_input("End date", key="end_date") or datetime.now() - st.session_state.setdefault( - "view", self.request.query_params.get("view", "Weekly") - ) + def render_date_view_inputs(self, bi): + if st.checkbox("Show All"): + start_date = bi.created_at + end_date = datetime.now() + else: + start_of_year_date = datetime.now().replace(month=1, day=1) + st.session_state.setdefault( + "start_date", + self.request.query_params.get( + "start_date", start_of_year_date.strftime("%Y-%m-%d") + ), + ) + start_date: datetime = ( + st.date_input("Start date", key="start_date") or start_of_year_date + ) + st.session_state.setdefault( + "end_date", + self.request.query_params.get( + "end_date", datetime.now().strftime("%Y-%m-%d") + ), + ) + end_date: datetime = ( + st.date_input("End date", key="end_date") or datetime.now() + ) + st.session_state.setdefault( + "view", self.request.query_params.get("view", "Weekly") + ) + st.write("---") view = st.horizontal_radio( "### View", options=["Daily", "Weekly", "Monthly"], From 570963c1e99dfedbb6a90192be4c1d601450bb9a Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 13:53:49 -0800 Subject: [PATCH 40/85] tables adhere to data filters --- recipes/VideoBotsStats.py | 60 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 12eabf51f..f41add98f 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -210,7 +210,15 @@ def render(self): sort_by = st.session_state["sort_by"] df = self.get_tabular_data( - bi, run_url, conversations, messages, details, sort_by, rows=500 + bi, + run_url, + conversations, + messages, + details, + sort_by, + rows=500, + start_date=start_date, + end_date=end_date, ) if not df.empty: @@ -233,7 +241,14 @@ def render(self): st.html("
") if st.checkbox("Export"): df = self.get_tabular_data( - bi, run_url, conversations, messages, details, sort_by + bi, + run_url, + conversations, + messages, + details, + sort_by, + start_date=start_date, + end_date=end_date, ) csv = df.to_csv() b64 = base64.b64encode(csv.encode()).decode() @@ -642,12 +657,29 @@ def plot_graphs(self, view, df): st.plotly_chart(fig) def get_tabular_data( - self, bi, run_url, conversations, messages, details, sort_by, rows=10000 + self, + bi, + run_url, + conversations, + messages, + details, + sort_by, + rows=10000, + start_date=None, + end_date=None, ): df = pd.DataFrame() if details == "Conversations": + if start_date and end_date: + conversations = conversations.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = conversations.to_df_format(row_limit=rows) elif details == "Messages": + if start_date and end_date: + messages = messages.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = messages.order_by("-created_at", "conversation__id").to_df_format( row_limit=rows ) @@ -658,6 +690,10 @@ def get_tabular_data( message__conversation__bot_integration=bi, rating=Feedback.Rating.RATING_THUMBS_UP, ) # type: ignore + if start_date and end_date: + pos_feedbacks = pos_feedbacks.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = pos_feedbacks.to_df_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name @@ -665,7 +701,13 @@ def get_tabular_data( neg_feedbacks: FeedbackQuerySet = Feedback.objects.filter( message__conversation__bot_integration=bi, rating=Feedback.Rating.RATING_THUMBS_DOWN, + created_at__date__gte=start_date, + created_at__date__lte=end_date, ) # type: ignore + if start_date and end_date: + neg_feedbacks = neg_feedbacks.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = neg_feedbacks.to_df_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name @@ -673,7 +715,13 @@ def get_tabular_data( successful_messages: MessageQuerySet = Message.objects.filter( conversation__bot_integration=bi, analysis_result__contains={"Answered": True}, + created_at__date__gte=start_date, + created_at__date__lte=end_date, ) # type: ignore + if start_date and end_date: + successful_messages = successful_messages.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = successful_messages.to_df_analysis_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name @@ -681,7 +729,13 @@ def get_tabular_data( unsuccessful_messages: MessageQuerySet = Message.objects.filter( conversation__bot_integration=bi, analysis_result__contains={"Answered": False}, + created_at__date__gte=start_date, + created_at__date__lte=end_date, ) # type: ignore + if start_date and end_date: + unsuccessful_messages = unsuccessful_messages.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = unsuccessful_messages.to_df_analysis_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name From 747cead27e41e1f683c4b9c658d5e304208c0bf9 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 14:00:15 -0800 Subject: [PATCH 41/85] round response time --- bots/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bots/models.py b/bots/models.py index 22008019f..edb7e183a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -927,7 +927,7 @@ def to_df_format( else None ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, - "Response Time": message.response_time.total_seconds(), + "Response Time": round(message.response_time.total_seconds(), 1), } rows.append(row) df = pd.DataFrame.from_records( From 819cf1f80fc7fe8fd5ddd8bdaa91051a4224b39b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 22:02:03 -0800 Subject: [PATCH 42/85] split of graphs --- recipes/VideoBotsStats.py | 80 +++++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 8 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index f41add98f..08787f1b3 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -499,6 +499,12 @@ def calculate_stats_binned_by_time( df["Unique_feedback_givers"] = df["Unique_feedback_givers"] * factor df["Pos_feedback"] = df["Pos_feedback"] * factor df["Neg_feedback"] = df["Neg_feedback"] * factor + df["Percentage_positive_feedback"] = ( + df["Pos_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + ) * 100 + df["Percentage_negative_feedback"] = ( + df["Neg_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] df["Average_response_time"] = df["Average_response_time"] * factor @@ -602,14 +608,6 @@ def plot_graphs(self, view, df): text=list(df["Msgs_per_user"]), hovertemplate="Messages per User: %{y:.0f}", ), - go.Scatter( - name="Average Response Time", - mode="lines+markers", - x=list(df["date"]), - y=list(df["Average_response_time"]), - text=list(df["Average_response_time"]), - hovertemplate="Average Response Time: %{y:.0f}", - ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), @@ -655,6 +653,72 @@ def plot_graphs(self, view, df): ], ) st.plotly_chart(fig) + st.markdown("
", unsafe_allow_html=True) + fig = go.Figure( + data=[ + go.Scatter( + name="Positive Feedback", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_positive_feedback"]), + text=list(df["Percentage_positive_feedback"]), + hovertemplate="Positive Feedback: %{y:.0f}\\%", + ), + go.Scatter( + name="Negative Feedback", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_negative_feedback"]), + text=list(df["Percentage_negative_feedback"]), + hovertemplate="Negative Feedback: %{y:.0f}\\%", + ), + ], + layout=dict( + margin=dict(l=0, r=0, t=28, b=0), + yaxis=dict( + title="Percentage", + range=[0, 100], + tickvals=[ + *range( + 0, + 101, + 10, + ) + ], + ), + title=dict( + text=f"{view} Feedback Distribution", + ), + height=300, + template="plotly_white", + ), + ) + st.plotly_chart(fig) + st.write("---") + fig = go.Figure( + data=[ + go.Scatter( + name="Average Response Time", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Average_response_time"]), + text=list(df["Average_response_time"]), + hovertemplate="Average Response Time: %{y:.0f}", + ), + ], + layout=dict( + margin=dict(l=0, r=0, t=28, b=0), + yaxis=dict( + title="Seconds", + ), + title=dict( + text=f"{view} Performance Metrics", + ), + height=300, + template="plotly_white", + ), + ) + st.plotly_chart(fig) def get_tabular_data( self, From 2a5d31969c86a07e0a6180c5c75829c268253e8d Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 22:50:30 -0800 Subject: [PATCH 43/85] sean tweaks and reordering of graphs --- recipes/VideoBotsStats.py | 134 ++++++++++++++++++++++++++++---------- 1 file changed, 99 insertions(+), 35 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 08787f1b3..718b580f1 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -119,6 +119,7 @@ def render(self): if int(bid) not in allowed_bids: bid = allowed_bids[0] bi = BotIntegration.objects.get(id=bid) + has_analysis_run = bi.analysis_run is not None run_title, run_url = self.parse_run_info(bi) self.show_title_breadcrumb_share(run_title, run_url, bi) @@ -153,14 +154,22 @@ def render(self): st.session_state.setdefault("details", self.request.query_params.get("details")) details = st.horizontal_radio( "### Details", - options=[ - "Conversations", - "Messages", - "Feedback Positive", - "Feedback Negative", - "Answered Successfully", - "Answered Unsuccessfully", - ], + options=( + [ + "Conversations", + "Messages", + "Feedback Positive", + "Feedback Negative", + ] + + ( + [ + "Answered Successfully", + "Answered Unsuccessfully", + ] + if has_analysis_run + else [] + ) + ), key="details", ) @@ -430,7 +439,9 @@ def calculate_stats_binned_by_time( distinct=True, ) ) + .annotate(Average_runtime=Avg("saved_run__run_time")) .annotate(Average_response_time=Avg("response_time")) + .annotate(Average_analysis_time=Avg("analysis_run__run_time")) .annotate(Unique_feedback_givers=Count("feedbacks", distinct=True)) .values( "date", @@ -439,6 +450,8 @@ def calculate_stats_binned_by_time( "Senders", "Unique_feedback_givers", "Average_response_time", + "Average_runtime", + "Average_analysis_time", ) ) @@ -470,6 +483,20 @@ def calculate_stats_binned_by_time( .values("date", "Neg_feedback") ) + successfully_answered = ( + Message.objects.filter( + conversation__bot_integration=bi, + analysis_result__contains={"Answered": True}, + created_at__date__gte=start_date, + created_at__date__lte=end_date, + ) + .order_by() + .annotate(date=trunc_fn("created_at")) + .values("date") + .annotate(Successfully_Answered=Count("id")) + .values("date", "Successfully_Answered") + ) + df = pd.DataFrame( messages_received, columns=[ @@ -479,6 +506,8 @@ def calculate_stats_binned_by_time( "Senders", "Unique_feedback_givers", "Average_response_time", + "Average_runtime", + "Average_analysis_time", ], ) df = df.merge( @@ -493,6 +522,14 @@ def calculate_stats_binned_by_time( left_on="date", right_on="date", ) + df = df.merge( + pd.DataFrame( + successfully_answered, columns=["date", "Successfully_Answered"] + ), + how="outer", + left_on="date", + right_on="date", + ) df["Messages_Sent"] = df["Messages_Sent"] * factor df["Convos"] = df["Convos"] * factor df["Senders"] = df["Senders"] * factor @@ -500,10 +537,13 @@ def calculate_stats_binned_by_time( df["Pos_feedback"] = df["Pos_feedback"] * factor df["Neg_feedback"] = df["Neg_feedback"] * factor df["Percentage_positive_feedback"] = ( - df["Pos_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + df["Pos_feedback"] / df["Messages_Sent"] ) * 100 df["Percentage_negative_feedback"] = ( - df["Neg_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + df["Neg_feedback"] / df["Messages_Sent"] + ) * 100 + df["Percentage_successfully_answered"] = ( + df["Successfully_Answered"] / df["Messages_Sent"] ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] @@ -653,41 +693,41 @@ def plot_graphs(self, view, df): ], ) st.plotly_chart(fig) - st.markdown("
", unsafe_allow_html=True) + st.write("---") fig = go.Figure( data=[ go.Scatter( - name="Positive Feedback", + name="Average Response Time", mode="lines+markers", x=list(df["date"]), - y=list(df["Percentage_positive_feedback"]), - text=list(df["Percentage_positive_feedback"]), - hovertemplate="Positive Feedback: %{y:.0f}\\%", + y=list(df["Average_response_time"]), + text=list(df["Average_response_time"]), + hovertemplate="Average Response Time: %{y:.0f}", ), go.Scatter( - name="Negative Feedback", + name="Average Run Time", mode="lines+markers", x=list(df["date"]), - y=list(df["Percentage_negative_feedback"]), - text=list(df["Percentage_negative_feedback"]), - hovertemplate="Negative Feedback: %{y:.0f}\\%", + y=list(df["Average_runtime"]), + text=list(df["Average_runtime"]), + hovertemplate="Average Runtime: %{y:.0f}", + ), + go.Scatter( + name="Average Analysis Time", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Average_analysis_time"]), + text=list(df["Average_analysis_time"]), + hovertemplate="Average Analysis Time: %{y:.0f}", ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), yaxis=dict( - title="Percentage", - range=[0, 100], - tickvals=[ - *range( - 0, - 101, - 10, - ) - ], + title="Seconds", ), title=dict( - text=f"{view} Feedback Distribution", + text=f"{view} Performance Metrics", ), height=300, template="plotly_white", @@ -698,21 +738,45 @@ def plot_graphs(self, view, df): fig = go.Figure( data=[ go.Scatter( - name="Average Response Time", + name="Positive Feedback", mode="lines+markers", x=list(df["date"]), - y=list(df["Average_response_time"]), - text=list(df["Average_response_time"]), - hovertemplate="Average Response Time: %{y:.0f}", + y=list(df["Percentage_positive_feedback"]), + text=list(df["Percentage_positive_feedback"]), + hovertemplate="Positive Feedback: %{y:.0f}%", + ), + go.Scatter( + name="Negative Feedback", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_negative_feedback"]), + text=list(df["Percentage_negative_feedback"]), + hovertemplate="Negative Feedback: %{y:.0f}%", + ), + go.Scatter( + name="Successfully Answered", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_successfully_answered"]), + text=list(df["Percentage_successfully_answered"]), + hovertemplate="Successfully Answered: %{y:.0f}%", ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), yaxis=dict( - title="Seconds", + title="Percentage", + range=[0, 100], + tickvals=[ + *range( + 0, + 101, + 10, + ) + ], ), title=dict( - text=f"{view} Performance Metrics", + text=f"{view} Feedback Distribution", ), height=300, template="plotly_white", From 896e2b808fd9d6d5d5d5514e1b731d7f99b6c867 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 1 Feb 2024 00:53:03 +0530 Subject: [PATCH 44/85] optional streaming support for slack & whatsapp add finish_reason to streaming --- README.md | 7 + bots/admin.py | 1 + bots/models.py | 5 + bots/tasks.py | 2 +- daras_ai_v2/base.py | 20 ++- daras_ai_v2/bot_integration_widgets.py | 5 + daras_ai_v2/bots.py | 212 ++++++++++++++----------- daras_ai_v2/facebook_bots.py | 106 +++++-------- daras_ai_v2/language_model.py | 35 ++-- daras_ai_v2/search_ref.py | 2 + daras_ai_v2/slack_bot.py | 94 +++++++---- daras_ai_v2/vector_search.py | 8 +- gooey_ui/pubsub.py | 29 ++++ poetry.lock | 8 +- pyproject.toml | 2 +- recipes/CompareLLM.py | 4 +- recipes/VideoBots.py | 67 +++++--- routers/api.py | 2 +- scripts/test_wa_msg_send.py | 124 +++++++++++++++ 19 files changed, 487 insertions(+), 246 deletions(-) create mode 100644 scripts/test_wa_msg_send.py diff --git a/README.md b/README.md index 89b0c5c34..a2fbe6a7a 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,12 @@ ngrok http 8080 5. Copy the temporary access token there and set env var `WHATSAPP_ACCESS_TOKEN = XXXX` +**(Optional) Use the test script to send yourself messages** + +```bash +python manage.py runscript test_wa_msg_send --script-args 104696745926402 +918764022384 +``` +Replace `+918764022384` with your number and `104696745926402` with the test number ID ## Dangerous postgres commands @@ -178,3 +184,4 @@ rsync -P -a @captain.us-1.gooey.ai:/home//fixture.json . createdb -T template0 $PGDATABASE pg_dump $SOURCE_DATABASE | psql -q $PGDATABASE ``` + diff --git a/bots/admin.py b/bots/admin.py index 49321ba35..4f69a1081 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -168,6 +168,7 @@ class BotIntegrationAdmin(admin.ModelAdmin): "Settings", { "fields": [ + "streaming_enabled", "show_feedback_buttons", "analysis_run", "view_analysis_results", diff --git a/bots/models.py b/bots/models.py index fd6071f58..467131fa9 100644 --- a/bots/models.py +++ b/bots/models.py @@ -571,6 +571,11 @@ class BotIntegration(models.Model): help_text="If provided, the message content will be analyzed for this bot using this saved run", ) + streaming_enabled = models.BooleanField( + default=False, + help_text="If set, the bot will stream messages to the frontend (Slack only)", + ) + created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) diff --git a/bots/tasks.py b/bots/tasks.py index 9afe2c343..5dda17e90 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -128,7 +128,7 @@ def send_broadcast_msg( channel_is_personal=convo.slack_channel_is_personal, username=bi.name, token=bi.slack_access_token, - ) + )[0] case _: raise NotImplementedError( f"Platform {bi.platform} doesn't support broadcasts yet" diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 7187178b7..f05a0c71c 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -189,7 +189,7 @@ def setup_render(self): def refresh_state(self): _, run_id, uid = extract_query_params(gooey_get_query_params()) - channel = f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" + channel = self.realtime_channel_name(run_id, uid) output = realtime_pull([channel])[0] if output: st.session_state.update(output) @@ -197,7 +197,7 @@ def refresh_state(self): def render(self): self.setup_render() - if self.get_run_state() == RecipeRunState.running: + if self.get_run_state(st.session_state) == RecipeRunState.running: self.refresh_state() else: realtime_clear_subs() @@ -1307,12 +1307,13 @@ def _render_input_col(self): ) return submitted - def get_run_state(self) -> RecipeRunState: - if st.session_state.get(StateKeys.run_status): + @classmethod + def get_run_state(cls, state: dict[str, typing.Any]) -> RecipeRunState: + if state.get(StateKeys.run_status): return RecipeRunState.running - elif st.session_state.get(StateKeys.error_msg): + elif state.get(StateKeys.error_msg): return RecipeRunState.failed - elif st.session_state.get(StateKeys.run_time): + elif state.get(StateKeys.run_time): return RecipeRunState.completed else: # when user is at a recipe root, and not running anything @@ -1331,7 +1332,7 @@ def _render_output_col(self, submitted: bool): self._render_before_output() - run_state = self.get_run_state() + run_state = self.get_run_state(st.session_state) match run_state: case RecipeRunState.completed: self._render_completed_output() @@ -1458,13 +1459,16 @@ def call_runner_task(self, example_id, run_id, uid, is_api_call=False): run_id=run_id, uid=uid, state=st.session_state, - channel=f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}", + channel=self.realtime_channel_name(run_id, uid), query_params=self.clean_query_params( example_id=example_id, run_id=run_id, uid=uid ), is_api_call=is_api_call, ) + def realtime_channel_name(self, run_id, uid): + return f"gooey-outputs/{self.slug_versions[0]}/{uid}/{run_id}" + def generate_credit_error_message(self, example_id, run_id, uid) -> str: account_url = furl(settings.APP_BASE_URL) / "account/" if self.request.user.is_anonymous: diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 3880f9fe7..309d32b44 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -24,6 +24,11 @@ def general_integration_settings(bi: BotIntegration): ] = BotIntegration._meta.get_field("show_feedback_buttons").default st.session_state[f"_bi_analysis_url_{bi.id}"] = None + bi.streaming_enabled = st.checkbox( + "**📡 Streaming Enabled**", + value=bi.streaming_enabled, + key=f"_bi_streaming_enabled_{bi.id}", + ) bi.show_feedback_buttons = st.checkbox( "**👍🏾 👎🏽 Show Feedback Buttons**", value=bi.show_feedback_buttons, diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index dfc00c794..3d106e9ae 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -20,11 +20,13 @@ MessageAttachment, ) from daras_ai_v2.asr import AsrModels, run_google_translate -from daras_ai_v2.base import BasePage +from daras_ai_v2.base import BasePage, RecipeRunState, StateKeys from daras_ai_v2.language_model import CHATML_ROLE_USER, CHATML_ROLE_ASSISTANT from daras_ai_v2.vector_search import doc_url_to_file_metadata +from gooey_ui.pubsub import realtime_subscribe from gooeysite.bg_db_conn import db_middleware from recipes.VideoBots import VideoBotsPage, ReplyButton +from routers.api import submit_api_call PAGE_NOT_CONNECTED_ERROR = ( "💔 Looks like you haven't connected this page to a gooey.ai workflow. " @@ -47,7 +49,7 @@ """.strip() ERROR_MSG = """ -`{0!r}` +`{}` ⚠️ Sorry, I ran into an error while processing your request. Please try again, or type "Reset" to start over. """.strip() @@ -60,6 +62,8 @@ TAPPED_SKIP_MSG = "🌱 Alright. What else can I help you with?" +SLACK_MAX_SIZE = 3000 + async def request_json(request: Request): return await request.json() @@ -80,33 +84,13 @@ class BotInterface: input_type: str language: str show_feedback_buttons: bool = False + streaming_enabled: bool = False + can_update_message: bool = False convo: Conversation recieved_msg_id: str = None input_glossary: str | None = None output_glossary: str | None = None - def send_msg_or_default( - self, - *, - text: str | None = None, - audio: str = None, - video: str = None, - buttons: list[ReplyButton] = None, - documents: list[str] = None, - should_translate: bool = False, - default: str = DEFAULT_RESPONSE, - ): - if not (text or audio or video or documents): - text = default - return self.send_msg( - text=text, - audio=audio, - video=video, - buttons=buttons, - documents=documents, - should_translate=should_translate, - ) - def send_msg( self, *, @@ -116,6 +100,7 @@ def send_msg( buttons: list[ReplyButton] = None, documents: list[str] = None, should_translate: bool = False, + update_msg_id: str = None, ) -> str | None: raise NotImplementedError @@ -166,6 +151,7 @@ def _unpack_bot_integration(self): self.billing_account_uid = bi.billing_account_uid self.language = bi.user_language self.show_feedback_buttons = bi.show_feedback_buttons + self.streaming_enabled = bi.streaming_enabled def get_interactive_msg_info(self) -> tuple[str, str]: raise NotImplementedError("This bot does not support interactive messages.") @@ -326,39 +312,112 @@ def _process_and_send_msg( input_text: str, speech_run: str | None, ): - try: - # # mock testing - # msgs_to_save, response_audio, response_text, response_video = _echo( - # bot, input_text - # ) - # make API call to gooey bots to get the response - response, url = _process_msg( - page_cls=bot.page_cls, - api_user=billing_account_user, - query_params=bot.query_params, - convo=bot.convo, - input_text=input_text, - user_language=bot.language, - speech_run=speech_run, - input_images=input_images, - input_documents=input_documents, - ) - except HTTPException as e: - traceback.print_exc() - capture_exception(e) - # send error msg as repsonse - bot.send_msg(text=ERROR_MSG.format(e)) - return + # get latest messages for context (upto 100) + saved_msgs = bot.convo.messages.all().as_llm_context() - # send the response to the user - msg_id = bot.send_msg_or_default( - text=response.output_text and response.output_text[0], - audio=response.output_audio and response.output_audio[0], - video=response.output_video and response.output_video[0], - documents=response.output_documents or [], - buttons=_feedback_start_buttons() if bot.show_feedback_buttons else None, + # # mock testing + # result = _mock_api_output(input_text) + page, result, run_id, uid = submit_api_call( + page_cls=bot.page_cls, + user=billing_account_user, + request_body={ + "input_prompt": input_text, + "input_images": input_images, + "input_documents": input_documents, + "messages": saved_msgs, + "user_language": bot.language, + }, + query_params=bot.query_params, ) + if bot.show_feedback_buttons: + buttons = _feedback_start_buttons() + else: + buttons = None + + update_msg_id = None # this is the message id to update during streaming + sent_msg_id = None # this is the message id to record in the db + last_idx = 0 # this is the last index of the text sent to the user + if bot.streaming_enabled: + # subscribe to the realtime channel for updates + channel = page.realtime_channel_name(run_id, uid) + with realtime_subscribe(channel) as realtime_gen: + for state in realtime_gen: + run_state = page.get_run_state(state) + run_status = state.get(StateKeys.run_status) or "" + # check for errors + if run_state == RecipeRunState.failed: + err_msg = state.get(StateKeys.error_msg) + bot.send_msg(text=ERROR_MSG.format(err_msg)) + return # abort + if run_state != RecipeRunState.running: + break # we're done running, abort + text = state.get("output_text") and state.get("output_text")[0] + if not text: + # if no text, send the run status + if bot.can_update_message: + update_msg_id = bot.send_msg( + text=run_status, update_msg_id=update_msg_id + ) + continue # no text, wait for the next update + streaming_done = not run_status.lower().startswith("streaming") + # send the response to the user + if bot.can_update_message: + update_msg_id = bot.send_msg( + text=text.strip() + "...", + update_msg_id=update_msg_id, + buttons=buttons if streaming_done else None, + ) + last_idx = len(text) + else: + next_chunk = text[last_idx:] + last_idx = len(text) + if not next_chunk: + continue # no chunk, wait for the next update + update_msg_id = bot.send_msg( + text=next_chunk, + buttons=buttons if streaming_done else None, + ) + if streaming_done and not bot.can_update_message: + # if we send the buttons, this is the ID we need to record in the db for lookups later when the button is pressed + sent_msg_id = update_msg_id + # don't show buttons again + buttons = None + if streaming_done: + break # we're done streaming, abort + + # wait for the celery task to finish + result.get(disable_sync_subtasks=False) + # get the final state from db + state = page.run_doc_sr(run_id, uid).to_dict() + # check for errors + err_msg = state.get(StateKeys.error_msg) + if err_msg: + bot.send_msg(text=ERROR_MSG.format(err_msg)) + return + + text = (state.get("output_text") and state.get("output_text")[0]) or "" + audio = state.get("output_audio") and state.get("output_audio")[0] + video = state.get("output_video") and state.get("output_video")[0] + documents = state.get("output_documents") or [] + # check for empty response + if not (text or audio or video or documents or buttons): + bot.send_msg(text=DEFAULT_RESPONSE) + return + # if in-place updates are enabled, update the message, otherwise send the remaining text + if not bot.can_update_message: + text = text[last_idx:] + # send the response to the user if there is any remaining + if text or audio or video or documents or buttons: + update_msg_id = bot.send_msg( + text=text, + audio=audio, + video=video, + documents=documents, + buttons=buttons, + update_msg_id=update_msg_id, + ) + # save msgs to db _save_msgs( bot=bot, @@ -366,9 +425,9 @@ def _process_and_send_msg( input_documents=input_documents, input_text=input_text, speech_run=speech_run, - platform_msg_id=msg_id, - response=response, - url=url, + platform_msg_id=sent_msg_id or update_msg_id, + response=VideoBotsPage.ResponseModel.parse_obj(state), + url=page.app_url(run_id=run_id, uid=uid), ) @@ -420,45 +479,6 @@ def _save_msgs( assistant_msg.save() -def _process_msg( - *, - page_cls, - api_user: AppUser, - query_params: dict, - convo: Conversation, - input_images: list[str] | None, - input_documents: list[str] | None, - input_text: str, - user_language: str, - speech_run: str | None, -) -> tuple[VideoBotsPage.ResponseModel, str]: - from routers.api import call_api - - # get latest messages for context (upto 100) - saved_msgs = convo.messages.all().as_llm_context() - - # # mock testing - # result = _mock_api_output(input_text) - - # call the api with provided input - result = call_api( - page_cls=page_cls, - user=api_user, - request_body={ - "input_prompt": input_text, - "input_images": input_images, - "input_documents": input_documents, - "messages": saved_msgs, - "user_language": user_language, - }, - query_params=query_params, - ) - # parse result - response = page_cls.ResponseModel.parse_obj(result["output"]) - url = result.get("url", "") - return response, url - - def _handle_interactive_msg(bot: BotInterface): try: button_id, context_msg_id = bot.get_interactive_msg_info() diff --git a/daras_ai_v2/facebook_bots.py b/daras_ai_v2/facebook_bots.py index 2a22bbddd..abf33b8a0 100644 --- a/daras_ai_v2/facebook_bots.py +++ b/daras_ai_v2/facebook_bots.py @@ -102,12 +102,12 @@ def send_msg( buttons: list[ReplyButton] = None, documents: list[str] = None, should_translate: bool = False, + update_msg_id: str = None, ) -> str | None: if text and should_translate and self.language and self.language != "en": text = run_google_translate( [text], self.language, glossary_url=self.output_glossary )[0] - text = text or "\u200b" # handle empty text with zero-width space return self.send_msg_to( bot_number=self.bot_id, user_number=self.user_id, @@ -123,9 +123,9 @@ def mark_read(self): @classmethod def send_msg_to( - self, + cls, *, - text: str, + text: str = None, audio: str = None, video: str = None, documents: list[str] = None, @@ -137,7 +137,7 @@ def send_msg_to( # see https://developers.facebook.com/docs/whatsapp/api/messages/media/ # split text into chunks if too long - if len(text) > WA_MSG_MAX_SIZE: + if text and len(text) > WA_MSG_MAX_SIZE: splits = text_splitter( text, chunk_size=WA_MSG_MAX_SIZE, length_function=len ) @@ -160,6 +160,7 @@ def send_msg_to( ], ) + messages = [] if video: if buttons: messages = [ @@ -168,7 +169,7 @@ def send_msg_to( buttons, { "body": { - "text": text, + "text": text or "\u200b", }, "header": { "type": "video", @@ -188,74 +189,38 @@ def send_msg_to( }, }, ] - elif audio: - if buttons: - # audio can't be sent as an interaction, so send text and audio separately - messages = [ - # simple audio msg + elif buttons: + # interactive text msg + messages = [ + _build_msg_buttons( + buttons, { - "type": "audio", - "audio": {"link": audio}, + "body": { + "text": text or "\u200b", + } }, - ] - send_wa_msgs_raw( - bot_number=bot_number, - user_number=user_number, - messages=messages, - ) - messages = [ - # interactive text msg - _build_msg_buttons( - buttons, - { - "body": { - "text": text, - }, - }, - ) - ] - else: - # audio doesn't support captions, so send text and audio separately - messages = [ - # simple text msg - { - "type": "text", - "text": { - "body": text, - "preview_url": True, - }, - }, - # simple audio msg - { - "type": "audio", - "audio": {"link": audio}, - }, - ] - else: - # text message - if buttons: - messages = [ - # interactive text msg - _build_msg_buttons( - buttons, - { - "body": { - "text": text, - } - }, - ), - ] - else: - messages = [ - # simple text msg - { - "type": "text", - "text": { - "body": text, - "preview_url": True, - }, + ), + ] + elif text: + # simple text msg + messages = [ + { + "type": "text", + "text": { + "body": text, + "preview_url": True, }, - ] + }, + ] + + if audio and not video: # video already has audio + # simple audio msg + messages.append( + { + "type": "audio", + "audio": {"link": audio}, + } + ) if documents: messages += [ @@ -398,6 +363,7 @@ def send_msg( buttons: list[ReplyButton] = None, documents: list[str] = None, should_translate: bool = False, + update_msg_id: str = None, ) -> str | None: if text and should_translate and self.language and self.language != "en": text = run_google_translate( diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 99452a193..faf364e63 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -340,7 +340,7 @@ def run_language_model( ) -> ( list[str] | tuple[list[str], list[list[dict]]] - | typing.Generator[list[str], None, None] + | typing.Generator[list[dict], None, None] ): assert bool(prompt) != bool( messages @@ -376,7 +376,7 @@ def run_language_model( stream=stream and not (tools or response_format_type), ) if stream: - return _stream_llm_outputs(entries, is_chatml, response_format_type, tools) + return _stream_llm_outputs(entries, response_format_type) else: return _parse_entries(entries, is_chatml, response_format_type, tools) else: @@ -396,15 +396,28 @@ def run_language_model( ) ret = [msg.strip() for msg in msgs] if stream: - ret = [ret] + ret = [ + [ + format_chat_entry(role=CHATML_ROLE_ASSISTANT, content=msg) + for msg in ret + ] + ] return ret -def _stream_llm_outputs(result, is_chatml, response_format_type, tools): +def _stream_llm_outputs( + result: list | typing.Generator[list[ConversationEntry], None, None], + response_format_type: typing.Literal["text", "json_object"] | None, +): if isinstance(result, list): # compatibility with non-streaming apis result = [result] for entries in result: - yield _parse_entries(entries, is_chatml, response_format_type, tools) + if response_format_type == "json_object": + for i, entry in enumerate(entries): + entries[i] = json.loads(entry["content"]) + for i, entry in enumerate(entries): + entries[i]["content"] = entry.get("content") or "" + yield entries def _parse_entries( @@ -576,24 +589,28 @@ def _stream_openai_chunked( start_chunk_size: int = 50, stop_chunk_size: int = 400, step_chunk_size: int = 150, -): +) -> typing.Generator[list[ConversationEntry], None, None]: ret = [] chunk_size = start_chunk_size for completion_chunk in r: changed = False for choice in completion_chunk.choices: + delta = choice.delta try: + # get the entry for this choice entry = ret[choice.index] except IndexError: # initialize the entry - entry = choice.delta.dict() | {"content": "", "chunk": ""} + entry = delta.dict() | {"content": "", "chunk": ""} ret.append(entry) + # this is to mark the end of streaming + entry["finish_reason"] = choice.finish_reason # append the delta to the current chunk - if not choice.delta.content: + if not delta.content: continue - entry["chunk"] += choice.delta.content + entry["chunk"] += delta.content # if the chunk is too small, we need to wait for more data chunk = entry["chunk"] if len(chunk) < chunk_size: diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py index 7cc810b1a..e95fb2ab2 100644 --- a/daras_ai_v2/search_ref.py +++ b/daras_ai_v2/search_ref.py @@ -161,6 +161,8 @@ def format_citations( def format_footnotes( all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles ) -> str: + if not all_refs: + return formatted match citation_style: case CitationStyles.number_markdown: formatted += "\n\n" diff --git a/daras_ai_v2/slack_bot.py b/daras_ai_v2/slack_bot.py index f75863133..10c88f7d7 100644 --- a/daras_ai_v2/slack_bot.py +++ b/daras_ai_v2/slack_bot.py @@ -11,7 +11,7 @@ from bots.models import BotIntegration, Platform, Conversation from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2.asr import run_google_translate, audio_bytes_to_wav -from daras_ai_v2.bots import BotInterface +from daras_ai_v2.bots import BotInterface, SLACK_MAX_SIZE from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.functional import fetch_parallel from daras_ai_v2.text_splitter import text_splitter @@ -27,11 +27,10 @@ I have been configured for $user_language and will respond to you in that language. """.strip() -SLACK_MAX_SIZE = 3000 - class SlackBot(BotInterface): platform = Platform.SLACK + can_update_message = True _read_rcpt_ts: str | None = None @@ -135,6 +134,7 @@ def send_msg( buttons: list[ReplyButton] = None, documents: list[str] = None, should_translate: bool = False, + update_msg_id: str | None = None, ) -> str | None: if text and should_translate and self.language and self.language != "en": text = run_google_translate( @@ -150,7 +150,9 @@ def send_msg( ) self._read_rcpt_ts = None - self._msg_ts = self.send_msg_to( + if not self.can_update_message: + update_msg_id = None + self._msg_ts, num_splits = self.send_msg_to( text=text, audio=audio, video=video, @@ -160,7 +162,10 @@ def send_msg( username=self.convo.bot_integration.name, token=self._access_token, thread_ts=self._msg_ts, + update_msg_ts=update_msg_id, ) + if num_splits > 1: + self.can_update_message = False return self._msg_ts @classmethod @@ -178,7 +183,8 @@ def send_msg_to( username: str, token: str, thread_ts: str = None, - ) -> str | None: + update_msg_ts: str = None, + ) -> tuple[str | None, int]: splits = text_splitter(text, chunk_size=SLACK_MAX_SIZE, length_function=len) for doc in splits[:-1]: thread_ts = chat_post_message( @@ -186,9 +192,11 @@ def send_msg_to( channel=channel, channel_is_personal=channel_is_personal, thread_ts=thread_ts, + update_msg_ts=update_msg_ts, username=username, token=token, ) + update_msg_ts = None thread_ts = chat_post_message( text=splits[-1].text, audio=audio, @@ -197,10 +205,11 @@ def send_msg_to( channel=channel, channel_is_personal=channel_is_personal, thread_ts=thread_ts, + update_msg_ts=update_msg_ts, username=username, token=token, ) - return thread_ts + return thread_ts, len(splits) def mark_read(self): text = self.convo.bot_integration.slack_read_receipt_msg.strip() @@ -524,9 +533,10 @@ def chat_post_message( channel: str, thread_ts: str, token: str, + update_msg_ts: str = None, channel_is_personal: bool = False, - audio: str | None = None, - video: str | None = None, + audio: str = None, + video: str = None, username: str = "Video Bot", buttons: list[ReplyButton] = None, ) -> str | None: @@ -535,28 +545,52 @@ def chat_post_message( if channel_is_personal: # don't thread in personal channels thread_ts = None - res = requests.post( - "https://slack.com/api/chat.postMessage", - json={ - "channel": channel, - "thread_ts": thread_ts, - "text": text, - "username": username, - "icon_emoji": ":robot_face:", - "blocks": [ - { - "type": "section", - "text": {"type": "mrkdwn", "text": text}, - }, - ] - + create_file_block("Audio", token, audio) - + create_file_block("Video", token, video) - + create_button_block(buttons), - }, - headers={ - "Authorization": f"Bearer {token}", - }, - ) + if update_msg_ts: + res = requests.post( + "https://slack.com/api/chat.update", + json={ + "channel": channel, + "ts": update_msg_ts, + "text": text, + "username": username, + "icon_emoji": ":robot_face:", + "blocks": [ + { + "type": "section", + "text": {"type": "mrkdwn", "text": text}, + }, + ] + + create_file_block("Audio", token, audio) + + create_file_block("Video", token, video) + + create_button_block(buttons), + }, + headers={ + "Authorization": f"Bearer {token}", + }, + ) + else: + res = requests.post( + "https://slack.com/api/chat.postMessage", + json={ + "channel": channel, + "thread_ts": thread_ts, + "text": text, + "username": username, + "icon_emoji": ":robot_face:", + "blocks": [ + { + "type": "section", + "text": {"type": "mrkdwn", "text": text}, + }, + ] + + create_file_block("Audio", token, audio) + + create_file_block("Video", token, video) + + create_button_block(buttons), + }, + headers={ + "Authorization": f"Bearer {token}", + }, + ) data = parse_slack_response(res) return data.get("ts") diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 394d55b55..b2cc187d0 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -88,11 +88,11 @@ def get_top_k_references( Returns: the top k documents """ - yield "Checking docs..." + yield "Fetching latest knowledge docs..." input_docs = request.documents or [] doc_metas = map_parallel(doc_url_to_metadata, input_docs) - yield "Getting embeddings..." + yield "Creating knowledge embeddings..." embeds: list[tuple[SearchReference, np.ndarray]] = flatmap_parallel( lambda f_url, doc_meta: get_embeds_for_doc( f_url=f_url, @@ -107,7 +107,7 @@ def get_top_k_references( ) dense_query_embeds = openai_embedding_create([request.search_query])[0] - yield "Searching documents..." + yield "Searching knowledge base..." dense_weight = request.dense_weight if dense_weight is None: # for backwards compatibility @@ -133,7 +133,7 @@ def get_top_k_references( dense_ranks = np.zeros(len(embeds)) if sparse_weight: - yield "Getting sparse scores..." + yield "Considering results..." # get sparse scores bm25_corpus = flatmap_parallel( lambda f_url, doc_meta: get_bm25_embeds_for_doc( diff --git a/gooey_ui/pubsub.py b/gooey_ui/pubsub.py index 5e210c956..ef2ba1539 100644 --- a/gooey_ui/pubsub.py +++ b/gooey_ui/pubsub.py @@ -2,6 +2,7 @@ import json import threading import typing +from contextlib import contextmanager from time import time import redis @@ -49,6 +50,34 @@ def realtime_push(channel: str, value: typing.Any = "ping"): logger.info(f"publish {channel=}") +@contextmanager +def realtime_subscribe(channel: str) -> typing.Generator: + channel = f"gooey-gui/state/{channel}" + pubsub = r.pubsub() + pubsub.subscribe(channel) + logger.info(f"subscribe {channel=}") + try: + yield _realtime_sub_gen(channel, pubsub) + finally: + logger.info(f"unsubscribe {channel=}") + pubsub.unsubscribe(channel) + pubsub.close() + + +def _realtime_sub_gen(channel: str, pubsub: redis.client.PubSub) -> typing.Generator: + while True: + message = pubsub.get_message(timeout=10) + if not (message and message["type"] == "message"): + continue + value = json.loads(r.get(channel)) + if isinstance(value, dict): + run_status = value.get("__run_status") + logger.info(f"realtime_subscribe: {channel=} {run_status=}") + else: + logger.info(f"realtime_subscribe: {channel=}") + yield value + + # def use_state( # value: T = None, *, key: str = None # ) -> tuple[T, typing.Callable[[T], None]]: diff --git a/poetry.lock b/poetry.lock index edbc668d1..a5c39450e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -13,13 +13,13 @@ files = [ [[package]] name = "aifail" -version = "0.1.0" +version = "0.2.0" description = "" optional = false python-versions = ">=3.10,<4.0" files = [ - {file = "aifail-0.1.0-py3-none-any.whl", hash = "sha256:d51fa56b1b8531a298a38cf91a3979f7d427afc88f5af635f86865b8f721bd23"}, - {file = "aifail-0.1.0.tar.gz", hash = "sha256:40f95fd45c07f9a7f0478d9702e3ea0d811f8582da7f6b841618de2a52803177"}, + {file = "aifail-0.2.0-py3-none-any.whl", hash = "sha256:83f3a842dbe523ee10a4d53ee00f06e794176122b93b57752566fb98d60db603"}, + {file = "aifail-0.2.0.tar.gz", hash = "sha256:d3e19e16740577181922055883a00e894d11413d12803e8d443094846040d6df"}, ] [package.dependencies] @@ -6510,4 +6510,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a5ef77e11ff5b9bb9a5ab5ec5a07c2d3ffa0ef3204d332653d3039e8aba04e14" +content-hash = "a0d934b2a9b3b5c54d3b14ad5e524ff4b163eaa5eaeb794fe38bc3a8d8d60e04" diff --git a/pyproject.toml b/pyproject.toml index c6a699024..555778a69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ ua-parser = "^0.18.0" user-agents = "^2.2.0" openpyxl = "^3.1.2" loguru = "^0.7.2" -aifail = "^0.1.0" +aifail = "0.2.0" pytest-playwright = "^0.4.3" [tool.poetry.group.dev.dependencies] diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index c07bcffa6..bb246d6c3 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -106,8 +106,8 @@ def run(self, state: dict) -> typing.Iterator[str | None]: avoid_repetition=request.avoid_repetition, stream=True, ) - for i, item in enumerate(ret): - output_text[selected_model] = item + for i, entries in enumerate(ret): + output_text[selected_model] = [e["content"] for e in entries] yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." def render_output(self): diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 44994b7ff..452db9f17 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -713,7 +713,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: query_instructions = (request.query_instructions or "").strip() if query_instructions: - yield "Generating search query..." + yield "Creating search query..." state["final_search_query"] = generate_final_search_query( request=request, instructions=query_instructions, @@ -727,7 +727,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: keyword_instructions = (request.keyword_instructions or "").strip() if keyword_instructions: - yield "Extracting keywords..." + yield "Finding keywords..." k_request = request.copy() # other models dont support JSON mode k_request.selected_model = LargeLanguageModels.gpt_4_turbo.name @@ -809,7 +809,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: if max_allowed_tokens < 0: raise ValueError("Input Script is too long! Please reduce the script size.") - yield f"Running {model.value}..." + yield f"Summarizing with {model.value}..." if is_chat_model: chunks = run_language_model( model=request.selected_model, @@ -842,15 +842,16 @@ def run(self, state: dict) -> typing.Iterator[str | None]: citation_style = ( request.citation_style and CitationStyles[request.citation_style] ) or None - all_refs_list = [] - for i, output_text in enumerate(chunks): + for i, entries in enumerate(chunks): + if not entries: + continue + output_text = [entry["content"] for entry in entries] if request.tools: - output_text, tool_call_choices = output_text + # output_text, tool_call_choices = output_text state["output_documents"] = output_documents = [] - for tool_calls in tool_call_choices: - for call in tool_calls: - result = yield from exec_tool_call(call) - output_documents.append(result) + for call in entries[0].get("tool_calls") or []: + result = yield from exec_tool_call(call) + output_documents.append(result) # save model response state["raw_output_text"] = [ @@ -876,13 +877,20 @@ def run(self, state: dict) -> typing.Iterator[str | None]: all_refs_list = apply_response_formattings_prefix( output_text, references, citation_style ) + else: + all_refs_list = None + state["output_text"] = output_text - yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." + if all(entry.get("finish_reason") for entry in entries): + if all_refs_list: + apply_response_formattings_suffix( + all_refs_list, state["output_text"], citation_style + ) + finish_reason = entries[0]["finish_reason"] + yield f"Completed with {finish_reason=}" # avoid changing this message since it's used to detect end of stream + else: + yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." - if all_refs_list: - apply_response_formattings_suffix( - all_refs_list, state["output_text"], citation_style - ) state["output_audio"] = [] state["output_video"] = [] @@ -1042,12 +1050,27 @@ def messenger_bot_integration(self): col1, col2, col3, *_ = st.columns([1, 1, 2]) with col1: favicon = Platform(bi.platform).get_favicon() + if bi.published_run: + url = self.app_url( + example_id=bi.published_run.published_run_id, + tab_name=MenuTabs.paths[MenuTabs.integrations], + ) + elif bi.saved_run: + url = self.app_url( + run_id=bi.saved_run.run_id, + uid=bi.saved_run.uid, + example_id=bi.saved_run.example_id, + tab_name=MenuTabs.paths[MenuTabs.integrations], + ) + else: + url = None + if url: + href = f'{bi}' + else: + f"{bi}" with st.div(className="mt-2"): st.markdown( - f'  ' - f'{bi}' - if bi.saved_run - else f"{bi}", + f'  {href}', unsafe_allow_html=True, ) with col2: @@ -1082,7 +1105,10 @@ def messenger_bot_integration(self): st.session_state.get("user_language") or bi.user_language ) bi.saved_run = current_run - bi.published_run = published_run + if published_run and published_run.saved_run_id == current_run.id: + bi.published_run = published_run + else: + bi.published_run = None if bi.platform == Platform.SLACK: from daras_ai_v2.slack_bot import send_confirmation_msg @@ -1121,6 +1147,7 @@ def slack_specific_settings(self, bi: BotIntegration): value=bi.name, key=f"_bi_name_{bi.id}", ) + st.caption("Enable streaming messages to Slack in real-time.") def show_landbot_widget(): diff --git a/routers/api.py b/routers/api.py index c57d948a6..e58ae3b2b 100644 --- a/routers/api.py +++ b/routers/api.py @@ -367,7 +367,7 @@ def build_api_response( run_async: bool, created_at: str, ): - web_url = str(furl(self.app_url(run_id=run_id, uid=uid))) + web_url = self.app_url(run_id=run_id, uid=uid) if run_async: status_url = str( furl(settings.API_BASE_URL, query_params=dict(run_id=run_id)) diff --git a/scripts/test_wa_msg_send.py b/scripts/test_wa_msg_send.py new file mode 100644 index 000000000..c31fca0af --- /dev/null +++ b/scripts/test_wa_msg_send.py @@ -0,0 +1,124 @@ +from time import sleep + +from daras_ai_v2.bots import _feedback_start_buttons +from daras_ai_v2.facebook_bots import WhatsappBot + + +def run(bot_number: str, user_number: str): + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="", + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Text With Buttons", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Audio with text", + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Audio + Video", + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Video with text", + video="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/6f019f2a-b714-11ee-82f3-02420a000172/gooey.ai%20lipsync.mp4#t=0.001", + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Some Docs", + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Some Docs", + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + buttons=_feedback_start_buttons(), + ) + sleep(1) + WhatsappBot.send_msg_to( + bot_number=bot_number, + user_number=user_number, + text="Audio + Docs", + audio="https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d949d330-95cb-11ee-9a21-02420a00012e/google_tts_gen.mp3", + documents=[ + "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/d30155b8-b438-11ed-85e8-02420a0001f7/Chetan%20Bhagat%20-three%20mistakes%20of%20my%20life.pdf" + ], + buttons=_feedback_start_buttons(), + ) From 2b54d88f5a1e25d6dac6d2fc67acde95f51e9ad1 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 2 Feb 2024 18:22:36 +0530 Subject: [PATCH 45/85] migrations --- .../0056_botintegration_streaming_enabled.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 bots/migrations/0056_botintegration_streaming_enabled.py diff --git a/bots/migrations/0056_botintegration_streaming_enabled.py b/bots/migrations/0056_botintegration_streaming_enabled.py new file mode 100644 index 000000000..dcf1366fa --- /dev/null +++ b/bots/migrations/0056_botintegration_streaming_enabled.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.7 on 2024-01-31 19:14 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0055_workflowmetadata"), + ] + + operations = [ + migrations.AddField( + model_name="botintegration", + name="streaming_enabled", + field=models.BooleanField( + default=False, + help_text="If set, the bot will stream messages to the frontend", + ), + ), + ] From 4783f7db107d49bc02692b7283bd1fa1bdf7cec9 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 2 Feb 2024 18:40:49 +0530 Subject: [PATCH 46/85] fix local variable 'href' referenced before assignment --- recipes/VideoBots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 452db9f17..17d05e770 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1067,7 +1067,7 @@ def messenger_bot_integration(self): if url: href = f'{bi}' else: - f"{bi}" + href = f"{bi}" with st.div(className="mt-2"): st.markdown( f'  {href}', From d2da2d10d84c6ca7e07e964286333cb7847f39dc Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 20:33:57 +0530 Subject: [PATCH 47/85] use timezone.now() --- daras_ai_v2/bots.py | 13 +++++-------- recipes/VideoBots.py | 6 +++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index edaf0ed32..9f2ebd555 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -1,16 +1,15 @@ import mimetypes import traceback import typing +from datetime import datetime from urllib.parse import parse_qs -import pytz -from datetime import datetime from django.db import transaction +from django.utils import timezone from fastapi import HTTPException, Request from furl import furl from sentry_sdk import capture_exception -from daras_ai_v2 import settings from app_users.models import AppUser from bots.models import ( Platform, @@ -191,7 +190,7 @@ def _on_msg(bot: BotInterface): speech_run = None input_images = None input_documents = None - recieved_time: datetime = datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) + recieved_time: datetime = timezone.now() if not bot.page_cls: bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR) return @@ -463,8 +462,7 @@ def _save_msgs( if speech_run else None ), - response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) - - received_time, + response_time=timezone.now() - received_time, ) attachments = [] for f_url in (input_images or []) + (input_documents or []): @@ -481,8 +479,7 @@ def _save_msgs( saved_run=SavedRun.objects.get_or_create( workflow=Workflow.VIDEO_BOTS, **furl(url).query.params )[0], - response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) - - received_time, + response_time=timezone.now() - received_time, ) # save the messages & attachments with transaction.atomic(): diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 5d23502ff..75615e6a7 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1124,9 +1124,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( - BotIntegration._meta.get_field("slack_read_receipt_msg").default - ) + st.session_state[ + f"_bi_slack_read_receipt_msg_{bi.id}" + ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default bi.slack_read_receipt_msg = st.text_input( """ From b063439037904c706e4101db916f1b68e9fbc366 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 20:42:30 +0530 Subject: [PATCH 48/85] make response_time null by default --- bots/migrations/0056_message_response_time.py | 19 --------------- .../0057_message_response_time_and_more.py | 23 +++++++++++++++++++ bots/models.py | 8 +++++-- 3 files changed, 29 insertions(+), 21 deletions(-) delete mode 100644 bots/migrations/0056_message_response_time.py create mode 100644 bots/migrations/0057_message_response_time_and_more.py diff --git a/bots/migrations/0056_message_response_time.py b/bots/migrations/0056_message_response_time.py deleted file mode 100644 index 06730e931..000000000 --- a/bots/migrations/0056_message_response_time.py +++ /dev/null @@ -1,19 +0,0 @@ -# Generated by Django 4.2.7 on 2024-02-01 20:15 - -import datetime -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('bots', '0055_workflowmetadata'), - ] - - operations = [ - migrations.AddField( - model_name='message', - name='response_time', - field=models.DurationField(default=datetime.timedelta(days=-1, seconds=86399), help_text='The time it took for the bot to respond to the corresponding user message'), - ), - ] diff --git a/bots/migrations/0057_message_response_time_and_more.py b/bots/migrations/0057_message_response_time_and_more.py new file mode 100644 index 000000000..5bdecd39c --- /dev/null +++ b/bots/migrations/0057_message_response_time_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.7 on 2024-02-05 15:11 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0056_botintegration_streaming_enabled'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='response_time', + field=models.DurationField(default=None, help_text='The time it took for the bot to respond to the corresponding user message', null=True), + ), + migrations.AlterField( + model_name='botintegration', + name='streaming_enabled', + field=models.BooleanField(default=False, help_text='If set, the bot will stream messages to the frontend (Slack only)'), + ), + ] diff --git a/bots/models.py b/bots/models.py index 4aa2c3d4a..8e2a41dc8 100644 --- a/bots/models.py +++ b/bots/models.py @@ -932,7 +932,10 @@ def to_df_format( else None ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, - "Response Time": round(message.response_time.total_seconds(), 1), + "Response Time": ( + message.response_time + and round(message.response_time.total_seconds(), 1) + ), } rows.append(row) df = pd.DataFrame.from_records( @@ -1054,7 +1057,8 @@ class Message(models.Model): ) response_time = models.DurationField( - default=datetime.timedelta(seconds=-1), + default=None, + null=True, help_text="The time it took for the bot to respond to the corresponding user message", ) From f5466022b9eb943cf5d89bbb72f503b9a22b8df3 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 20:56:38 +0530 Subject: [PATCH 49/85] fix run_time display in copilot stats add resposne_time to admin --- bots/admin.py | 2 ++ recipes/VideoBotsStats.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index 4f69a1081..0b6b475cc 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -492,6 +492,7 @@ class MessageAdmin(admin.ModelAdmin): "prev_msg_content", "prev_msg_display_content", "prev_msg_saved_run", + "response_time", ] ordering = ["created_at"] actions = [export_to_csv, export_to_excel] @@ -551,6 +552,7 @@ def get_fieldsets(self, request, msg: Message = None): "Analysis", { "fields": [ + "response_time", "analysis_result", "analysis_run", "question_answered", diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 718b580f1..db88f4dde 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -338,15 +338,12 @@ def render_date_view_inputs(self, bi): def parse_run_info(self, bi): saved_run = bi.get_active_saved_run() - run_title = ( - bi.published_run.title - if bi.published_run - else ( - saved_run.page_title - if saved_run and saved_run.page_title - else "This Copilot Run" if saved_run else "No Run Connected" - ) - ) + if bi.published_run: + run_title = bi.published_run.title + elif saved_run: + run_title = "This Copilot Run" + else: + run_title = "No Run Connected" run_url = furl(saved_run.get_app_url()).tostr() if saved_run else "" return run_title, run_url @@ -424,7 +421,7 @@ def calculate_stats_binned_by_time( created_at__date__gte=start_date, created_at__date__lte=end_date, conversation__bot_integration=bi, - role=CHATML_ROLE_USER, + role=CHATML_ROLE_ASSISTANT, ) .order_by() .annotate(date=trunc_fn("created_at")) From 2be2378a5e1142c4b8b8f503850d76154d954882 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:02:33 +0530 Subject: [PATCH 50/85] use timezone.now() instead of datetime.now() --- recipes/VideoBotsStats.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index db88f4dde..c18683b96 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -1,3 +1,5 @@ +from django.utils import timezone + from daras_ai_v2.base import BasePage, MenuTabs import gooey_ui as st from furl import furl @@ -287,9 +289,9 @@ def render(self): def render_date_view_inputs(self, bi): if st.checkbox("Show All"): start_date = bi.created_at - end_date = datetime.now() + end_date = timezone.now() else: - start_of_year_date = datetime.now().replace(month=1, day=1) + start_of_year_date = timezone.now().replace(month=1, day=1) st.session_state.setdefault( "start_date", self.request.query_params.get( @@ -302,11 +304,11 @@ def render_date_view_inputs(self, bi): st.session_state.setdefault( "end_date", self.request.query_params.get( - "end_date", datetime.now().strftime("%Y-%m-%d") + "end_date", timezone.now().strftime("%Y-%m-%d") ), ) end_date: datetime = ( - st.date_input("End date", key="end_date") or datetime.now() + st.date_input("End date", key="end_date") or timezone.now() ) st.session_state.setdefault( "view", self.request.query_params.get("view", "Weekly") @@ -359,7 +361,7 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): num_active_users_last_7_days = ( user_messages.filter( conversation__in=users, - created_at__gte=datetime.now() - timedelta(days=7), + created_at__gte=timezone.now() - timedelta(days=7), ) .distinct( *ID_COLUMNS, @@ -369,7 +371,7 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): num_active_users_last_30_days = ( user_messages.filter( conversation__in=users, - created_at__gte=datetime.now() - timedelta(days=30), + created_at__gte=timezone.now() - timedelta(days=30), ) .distinct( *ID_COLUMNS, @@ -544,7 +546,9 @@ def calculate_stats_binned_by_time( ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] - df["Average_response_time"] = df["Average_response_time"] * factor + df["Average_response_time"] = ( + df["Average_response_time"].dt.total_seconds() * factor + ) df.fillna(0, inplace=True) df = df.round(0).astype("int32", errors="ignore") return df From 3ac7299c9ed18ac9948ac88768d9d8f74daae0ec Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:06:32 +0530 Subject: [PATCH 51/85] silence black --- recipes/VideoBots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 75615e6a7..5d23502ff 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1124,9 +1124,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[ - f"_bi_slack_read_receipt_msg_{bi.id}" - ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default + st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + BotIntegration._meta.get_field("slack_read_receipt_msg").default + ) bi.slack_read_receipt_msg = st.text_input( """ From 4aa53e70bfae3f0b1ee650f4cf945566135ebee9 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:27:43 +0530 Subject: [PATCH 52/85] fix AttributeError: Can only use .dt accessor with datetimelike values --- recipes/VideoBotsStats.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c18683b96..00c7f07dd 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -546,9 +546,11 @@ def calculate_stats_binned_by_time( ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] - df["Average_response_time"] = ( - df["Average_response_time"].dt.total_seconds() * factor - ) + try: + df["Average_response_time"] = df["Average_response_time"].dt.total_seconds() + except AttributeError: + pass + df["Average_response_time"] = df["Average_response_time"] * factor df.fillna(0, inplace=True) df = df.round(0).astype("int32", errors="ignore") return df From 84a77a4f1efb945a4aa7e30caec9bf3d73b45b35 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:13:16 +0530 Subject: [PATCH 53/85] update black --- poetry.lock | 44 ++++++++++++++++++++++++-------------------- pyproject.toml | 2 +- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/poetry.lock b/poetry.lock index a5c39450e..0f225fd4e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -323,29 +323,33 @@ files = [ [[package]] name = "black" -version = "23.10.1" +version = "23.12.1" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.10.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:ec3f8e6234c4e46ff9e16d9ae96f4ef69fa328bb4ad08198c8cee45bb1f08c69"}, - {file = "black-23.10.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1b917a2aa020ca600483a7b340c165970b26e9029067f019e3755b56e8dd5916"}, - {file = "black-23.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c74de4c77b849e6359c6f01987e94873c707098322b91490d24296f66d067dc"}, - {file = "black-23.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b4d10b0f016616a0d93d24a448100adf1699712fb7a4efd0e2c32bbb219b173"}, - {file = "black-23.10.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b15b75fc53a2fbcac8a87d3e20f69874d161beef13954747e053bca7a1ce53a0"}, - {file = "black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace"}, - {file = "black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb"}, - {file = "black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce"}, - {file = "black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a"}, - {file = "black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1"}, - {file = "black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad"}, - {file = "black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884"}, - {file = "black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9"}, - {file = "black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7"}, - {file = "black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d"}, - {file = "black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982"}, - {file = "black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe"}, - {file = "black-23.10.1.tar.gz", hash = "sha256:1f8ce316753428ff68749c65a5f7844631aa18c8679dfd3ca9dc1a289979c258"}, + {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, + {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, + {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, + {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, + {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, + {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, + {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, + {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, + {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, + {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, + {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, + {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, + {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, + {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, + {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, + {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, + {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, + {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, + {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, + {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, + {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, + {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, ] [package.dependencies] @@ -359,7 +363,7 @@ typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} [package.extras] colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] diff --git a/pyproject.toml b/pyproject.toml index aac74b09c..44dbff5da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ django-phonenumber-field = { extras = ["phonenumberslite"], version = "^7.0.2" } gunicorn = "^20.1.0" psycopg2-binary = "^2.9.6" whitenoise = "^6.4.0" -black = "^23.3.0" +black = "^24.1.1" django-extensions = "^3.2.1" pytest-django = "^4.5.2" celery = "^5.3.1" From 3b321f5e1e974538121cb148be9ace1bede53190 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:26:43 +0530 Subject: [PATCH 54/85] fix test urls --- conftest.py | 5 +++-- tests/test_apis.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/conftest.py b/conftest.py index e8534d5f2..dba333885 100644 --- a/conftest.py +++ b/conftest.py @@ -47,8 +47,9 @@ def _mock_gui_runner( def threadpool_subtest(subtests, max_workers: int = 8): ts = [] - def submit(fn, *args, **kwargs): - msg = "--".join(map(str, [*args, *kwargs.values()])) + def submit(fn, *args, msg=None, **kwargs): + if not msg: + msg = "--".join(map(str, [*args, *kwargs.values()])) @wraps(fn) def runner(*args, **kwargs): diff --git a/tests/test_apis.py b/tests/test_apis.py index 6412e9eda..96c27c7ef 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -73,7 +73,7 @@ def test_apis_examples(mock_gui_runner, force_authentication, threadpool_subtest hidden=False, example_id__isnull=False, ): - threadpool_subtest(_test_apis_examples, sr) + threadpool_subtest(_test_apis_examples, sr, msg=sr.get_app_url()) def _test_apis_examples(sr: SavedRun): From 1c0c990bb450934b767bc972e66df9f12f41293c Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 6 Feb 2024 19:35:11 +0530 Subject: [PATCH 55/85] Copilot: Turn off collecting detailed feedback off thumbs up or down --- daras_ai_v2/bots.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 9f2ebd555..30acf07a3 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -509,13 +509,25 @@ def _handle_interactive_msg(bot: BotInterface): return if button_id == ButtonIds.feedback_thumbs_up: rating = Feedback.Rating.RATING_THUMBS_UP - bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_UP - response_text = FEEDBACK_THUMBS_UP_MSG + # bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_UP + # response_text = FEEDBACK_THUMBS_UP_MSG else: rating = Feedback.Rating.RATING_THUMBS_DOWN - bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_DOWN - response_text = FEEDBACK_THUMBS_DOWN_MSG + # bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_DOWN + # response_text = FEEDBACK_THUMBS_DOWN_MSG + response_text = FEEDBACK_CONFIRMED_MSG.format( + bot_name=str(bot.convo.bot_integration.name) + ) bot.convo.save() + # save the feedback + Feedback.objects.create(message=context_msg, rating=rating) + # send a confirmation msg + post click buttons + bot.send_msg( + text=response_text, + # buttons=_feedback_post_click_buttons(), + should_translate=True, + ) + # handle skip case ButtonIds.action_skip: bot.send_msg(text=TAPPED_SKIP_MSG, should_translate=True) @@ -523,6 +535,7 @@ def _handle_interactive_msg(bot: BotInterface): bot.convo.state = ConvoState.INITIAL bot.convo.save() return + # not sure what button was pressed, ignore case _: bot_name = str(bot.convo.bot_integration.name) @@ -534,14 +547,6 @@ def _handle_interactive_msg(bot: BotInterface): bot.convo.state = ConvoState.INITIAL bot.convo.save() return - # save the feedback - Feedback.objects.create(message=context_msg, rating=rating) - # send a confirmation msg + post click buttons - bot.send_msg( - text=response_text, - buttons=_feedback_post_click_buttons(), - should_translate=True, - ) def _handle_audio_msg(billing_account_user, bot: BotInterface): From 5d82e5db8ca5c702d16f1118705e80b0625e4e66 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 6 Feb 2024 22:46:08 +0530 Subject: [PATCH 56/85] lockfile --- poetry.lock | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0f225fd4e..587185ab8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -323,33 +323,33 @@ files = [ [[package]] name = "black" -version = "23.12.1" +version = "24.1.1" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, - {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, - {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, - {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, - {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, - {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, - {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, - {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, - {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, - {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, - {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, - {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, - {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, - {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, - {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, - {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, - {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, - {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, - {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, - {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, - {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, - {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, + {file = "black-24.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2588021038bd5ada078de606f2a804cadd0a3cc6a79cb3e9bb3a8bf581325a4c"}, + {file = "black-24.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a95915c98d6e32ca43809d46d932e2abc5f1f7d582ffbe65a5b4d1588af7445"}, + {file = "black-24.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fa6a0e965779c8f2afb286f9ef798df770ba2b6cee063c650b96adec22c056a"}, + {file = "black-24.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:5242ecd9e990aeb995b6d03dc3b2d112d4a78f2083e5a8e86d566340ae80fec4"}, + {file = "black-24.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fc1ec9aa6f4d98d022101e015261c056ddebe3da6a8ccfc2c792cbe0349d48b7"}, + {file = "black-24.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0269dfdea12442022e88043d2910429bed717b2d04523867a85dacce535916b8"}, + {file = "black-24.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3d64db762eae4a5ce04b6e3dd745dcca0fb9560eb931a5be97472e38652a161"}, + {file = "black-24.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5d7b06ea8816cbd4becfe5f70accae953c53c0e53aa98730ceccb0395520ee5d"}, + {file = "black-24.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e2c8dfa14677f90d976f68e0c923947ae68fa3961d61ee30976c388adc0b02c8"}, + {file = "black-24.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a21725862d0e855ae05da1dd25e3825ed712eaaccef6b03017fe0853a01aa45e"}, + {file = "black-24.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07204d078e25327aad9ed2c64790d681238686bce254c910de640c7cc4fc3aa6"}, + {file = "black-24.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:a83fe522d9698d8f9a101b860b1ee154c1d25f8a82ceb807d319f085b2627c5b"}, + {file = "black-24.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08b34e85170d368c37ca7bf81cf67ac863c9d1963b2c1780c39102187ec8dd62"}, + {file = "black-24.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7258c27115c1e3b5de9ac6c4f9957e3ee2c02c0b39222a24dc7aa03ba0e986f5"}, + {file = "black-24.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40657e1b78212d582a0edecafef133cf1dd02e6677f539b669db4746150d38f6"}, + {file = "black-24.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e298d588744efda02379521a19639ebcd314fba7a49be22136204d7ed1782717"}, + {file = "black-24.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:34afe9da5056aa123b8bfda1664bfe6fb4e9c6f311d8e4a6eb089da9a9173bf9"}, + {file = "black-24.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:854c06fb86fd854140f37fb24dbf10621f5dab9e3b0c29a690ba595e3d543024"}, + {file = "black-24.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3897ae5a21ca132efa219c029cce5e6bfc9c3d34ed7e892113d199c0b1b444a2"}, + {file = "black-24.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:ecba2a15dfb2d97105be74bbfe5128bc5e9fa8477d8c46766505c1dda5883aac"}, + {file = "black-24.1.1-py3-none-any.whl", hash = "sha256:5cdc2e2195212208fbcae579b931407c1fa9997584f0a415421748aeafff1168"}, + {file = "black-24.1.1.tar.gz", hash = "sha256:48b5760dcbfe5cf97fd4fba23946681f3a81514c6ab8a45b50da67ac8fbc6c7b"}, ] [package.dependencies] @@ -6514,4 +6514,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a0d934b2a9b3b5c54d3b14ad5e524ff4b163eaa5eaeb794fe38bc3a8d8d60e04" +content-hash = "798d4cfff5a447e83bd2aa675525c89d6258c927839260a6f20aa71a633a7e81" From 570da24520935bc18457a740ac715740ceeaa7d1 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 6 Feb 2024 22:24:43 +0530 Subject: [PATCH 57/85] support openai streaming squash migrations better lookup for model when saving cost rename costs -> usage_costs, provider pricing -> model pricing add init script for llm pricing --- ...alter_savedrun_unique_together_and_more.py | 21 ++ bots/models.py | 3 +- celeryapp/tasks.py | 1 - costs/admin.py | 33 -- costs/cost_utils.py | 36 -- costs/migrations/0001_initial.py | 59 --- costs/migrations/0002_providerpricing.py | 79 ---- ...t_model_remove_usagecost_param_and_more.py | 64 ---- .../0004_remove_usagecost_saved_run.py | 16 - .../0005_remove_usagecost_provider_pricing.py | 16 - ...st_provider_pricing_usagecost_saved_run.py | 38 -- .../0007_alter_providerpricing_provider.py | 25 -- ...08_alter_providerpricing_param_and_more.py | 33 -- .../0009_alter_providerpricing_product.py | 19 - ...0010_remove_usagecost_calculation_notes.py | 16 - .../0011_alter_providerpricing_product.py | 35 -- .../0012_alter_providerpricing_product.py | 37 -- .../0013_alter_providerpricing_product.py | 37 -- .../0014_alter_providerpricing_product.py | 40 --- .../0015_alter_providerpricing_product.py | 43 --- costs/models.py | 113 ------ daras_ai_v2/language_model.py | 159 +++++--- daras_ai_v2/settings.py | 2 +- scripts/init_llm_pricing.py | 338 ++++++++++++++++++ {costs => usage_costs}/__init__.py | 0 usage_costs/admin.py | 75 ++++ {costs => usage_costs}/apps.py | 4 +- usage_costs/cost_utils.py | 34 ++ usage_costs/migrations/0001_initial.py | 56 +++ {costs => usage_costs}/migrations/__init__.py | 0 usage_costs/models.py | 114 ++++++ {costs => usage_costs}/tests.py | 0 {costs => usage_costs}/views.py | 0 33 files changed, 752 insertions(+), 794 deletions(-) create mode 100644 bots/migrations/0058_alter_savedrun_unique_together_and_more.py delete mode 100644 costs/admin.py delete mode 100644 costs/cost_utils.py delete mode 100644 costs/migrations/0001_initial.py delete mode 100644 costs/migrations/0002_providerpricing.py delete mode 100644 costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py delete mode 100644 costs/migrations/0004_remove_usagecost_saved_run.py delete mode 100644 costs/migrations/0005_remove_usagecost_provider_pricing.py delete mode 100644 costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py delete mode 100644 costs/migrations/0007_alter_providerpricing_provider.py delete mode 100644 costs/migrations/0008_alter_providerpricing_param_and_more.py delete mode 100644 costs/migrations/0009_alter_providerpricing_product.py delete mode 100644 costs/migrations/0010_remove_usagecost_calculation_notes.py delete mode 100644 costs/migrations/0011_alter_providerpricing_product.py delete mode 100644 costs/migrations/0012_alter_providerpricing_product.py delete mode 100644 costs/migrations/0013_alter_providerpricing_product.py delete mode 100644 costs/migrations/0014_alter_providerpricing_product.py delete mode 100644 costs/migrations/0015_alter_providerpricing_product.py delete mode 100644 costs/models.py create mode 100644 scripts/init_llm_pricing.py rename {costs => usage_costs}/__init__.py (100%) create mode 100644 usage_costs/admin.py rename {costs => usage_costs}/apps.py (67%) create mode 100644 usage_costs/cost_utils.py create mode 100644 usage_costs/migrations/0001_initial.py rename {costs => usage_costs}/migrations/__init__.py (100%) create mode 100644 usage_costs/models.py rename {costs => usage_costs}/tests.py (100%) rename {costs => usage_costs}/views.py (100%) diff --git a/bots/migrations/0058_alter_savedrun_unique_together_and_more.py b/bots/migrations/0058_alter_savedrun_unique_together_and_more.py new file mode 100644 index 000000000..a4e137127 --- /dev/null +++ b/bots/migrations/0058_alter_savedrun_unique_together_and_more.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.7 on 2024-02-06 18:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0057_message_response_time_and_more'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='savedrun', + unique_together={('run_id', 'uid'), ('workflow', 'example_id')}, + ), + migrations.AddIndex( + model_name='savedrun', + index=models.Index(fields=['run_id', 'uid'], name='bots_savedr_run_id_7b0b34_idx'), + ), + ] diff --git a/bots/models.py b/bots/models.py index 8e2a41dc8..205dfb5ff 100644 --- a/bots/models.py +++ b/bots/models.py @@ -260,7 +260,7 @@ class Meta: ordering = ["-updated_at"] unique_together = [ ["workflow", "example_id"], - ["workflow", "run_id", "uid"], + ["run_id", "uid"], ] constraints = [ models.CheckConstraint( @@ -273,6 +273,7 @@ class Meta: models.Index(fields=["-created_at"]), models.Index(fields=["-updated_at"]), models.Index(fields=["workflow"]), + models.Index(fields=["run_id", "uid"]), models.Index(fields=["workflow", "run_id", "uid"]), models.Index(fields=["workflow", "example_id", "run_id", "uid"]), models.Index(fields=["workflow", "example_id", "hidden"]), diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 7bd144091..1868c7a18 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,7 +8,6 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun -from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/costs/admin.py b/costs/admin.py deleted file mode 100644 index f3d99d472..000000000 --- a/costs/admin.py +++ /dev/null @@ -1,33 +0,0 @@ -from django.contrib import admin -from costs import models - -# Register your models here. - - -@admin.register(models.UsageCost) -class CostsAdmin(admin.ModelAdmin): - list_display = [ - "saved_run", - "provider_pricing", - "quantity", - "notes", - "dollar_amt", - "created_at", - ] - - -@admin.register(models.ProviderPricing) -class ProviderAdmin(admin.ModelAdmin): - list_display = [ - "type", - "provider", - "product", - "param", - "cost", - "unit", - "notes", - "created_at", - "last_updated", - "updated_by", - "pricing_url", - ] diff --git a/costs/cost_utils.py b/costs/cost_utils.py deleted file mode 100644 index 25f6ec28b..000000000 --- a/costs/cost_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -from costs.models import UsageCost, ProviderPricing -from bots.models import SavedRun - - -def get_provider_pricing( - type: str, - provider: str, - product: str, - param: str, -) -> ProviderPricing: - print("get_provider_pricing", type, provider, product, param) - return ProviderPricing.objects.get( - type=type, - provider=provider, - product=product, - param=param, - ) - - -def record_cost( - run_id: str | None, - uid: str | None, - provider_pricing: ProviderPricing, - quantity: int, -) -> UsageCost: - saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) - cost = UsageCost( - saved_run=saved_run, - provider_pricing=provider_pricing, - quantity=quantity, - notes="", - dollar_amt=provider_pricing.cost * quantity, - created_at=saved_run.created_at, - ) - cost.save() - return cost diff --git a/costs/migrations/0001_initial.py b/costs/migrations/0001_initial.py deleted file mode 100644 index d3f6d1948..000000000 --- a/costs/migrations/0001_initial.py +++ /dev/null @@ -1,59 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-04 03:59 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), - ] - - operations = [ - migrations.CreateModel( - name="UsageCost", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "provider", - models.IntegerField( - choices=[ - (1, "OpenAI"), - (2, "GPT-4"), - (3, "dalle-e"), - (4, "whisper"), - (5, "GPT-3.5"), - ] - ), - ), - ("model", models.TextField(blank=True, verbose_name="model")), - ("param", models.TextField(blank=True, verbose_name="param")), - ("notes", models.TextField(blank=True, default="")), - ("calculation_notes", models.TextField(blank=True, default="")), - ("dollar_amt", models.DecimalField(decimal_places=8, max_digits=13)), - ("created_at", models.DateTimeField(auto_now_add=True)), - ( - "saved_run", - models.ForeignKey( - blank=True, - default=None, - help_text="The run that was last saved by the user.", - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="bots.savedrun", - ), - ), - ], - ), - ] diff --git a/costs/migrations/0002_providerpricing.py b/costs/migrations/0002_providerpricing.py deleted file mode 100644 index 6b0bb17a1..000000000 --- a/costs/migrations/0002_providerpricing.py +++ /dev/null @@ -1,79 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-04 04:00 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="ProviderPricing", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("type", models.TextField(choices=[("LLM", "Llm")])), - ( - "provider", - models.IntegerField( - choices=[ - (1, "OpenAI"), - (2, "GPT-4"), - (3, "dalle-e"), - (4, "whisper"), - (5, "GPT-3.5"), - ] - ), - ), - ( - "product", - models.TextField( - choices=[ - ("GPT-4 Vision (openai)", "gpt_4_vision"), - ("GPT-4 Turbo (openai)", "gpt_4_turbo"), - ("GPT-4 (openai)", "gpt_4"), - ("GPT-4 32K (openai)", "gpt_4_32k"), - ("ChatGPT (openai)", "gpt_3_5_turbo"), - ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), - ("Llama 2 (Meta AI)", "llama2_70b_chat"), - ("PaLM 2 Text (Google)", "palm2_chat"), - ("PaLM 2 Chat (Google)", "palm2_text"), - ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), - ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), - ("Curie (openai)", "text_curie_001"), - ("Babbage (openai)", "text_babbage_001"), - ("Ada (openai)", "text_ada_001"), - ("Codex [Deprecated] (openai)", "code_davinci_002"), - ] - ), - ), - ( - "param", - models.TextField( - choices=[ - ("input", "Input"), - ("output", "Output"), - ("input image", "Input Image"), - ("output image", "Output Image"), - ] - ), - ), - ("cost", models.DecimalField(decimal_places=8, max_digits=13)), - ("unit", models.TextField(default="")), - ("notes", models.TextField(default="")), - ("created_at", models.DateTimeField(auto_now_add=True)), - ("last_updated", models.DateTimeField(auto_now=True)), - ("updated_by", models.TextField(default="")), - ("pricing_url", models.TextField(default="")), - ], - ), - ] diff --git a/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py b/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py deleted file mode 100644 index 4aa89cadd..000000000 --- a/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py +++ /dev/null @@ -1,64 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 06:02 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0002_providerpricing"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="model", - ), - migrations.RemoveField( - model_name="usagecost", - name="param", - ), - migrations.RemoveField( - model_name="usagecost", - name="provider", - ), - migrations.AddField( - model_name="usagecost", - name="provider_pricing", - field=models.ForeignKey( - default=None, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="costs.providerpricing", - ), - ), - migrations.AddField( - model_name="usagecost", - name="quantity", - field=models.IntegerField(default=1), - ), - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("GPT-4 Vision (openai)", "gpt_4_vision"), - ("GPT-4 Turbo (openai)", "gpt_4_turbo"), - ("GPT-4 (openai)", "gpt_4"), - ("GPT-4 32K (openai)", "gpt_4_32k"), - ("ChatGPT (openai)", "gpt_3_5_turbo"), - ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), - ("Llama 2 (Meta AI)", "llama2_70b_chat"), - ("PaLM 2 Chat (Google)", "palm2_chat"), - ("PaLM 2 Text (Google)", "palm2_text"), - ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), - ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), - ("Curie (openai)", "text_curie_001"), - ("Babbage (openai)", "text_babbage_001"), - ("Ada (openai)", "text_ada_001"), - ("Codex [Deprecated] (openai)", "code_davinci_002"), - ] - ), - ), - ] diff --git a/costs/migrations/0004_remove_usagecost_saved_run.py b/costs/migrations/0004_remove_usagecost_saved_run.py deleted file mode 100644 index 2e68a519e..000000000 --- a/costs/migrations/0004_remove_usagecost_saved_run.py +++ /dev/null @@ -1,16 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 07:33 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0003_remove_usagecost_model_remove_usagecost_param_and_more"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="saved_run", - ), - ] diff --git a/costs/migrations/0005_remove_usagecost_provider_pricing.py b/costs/migrations/0005_remove_usagecost_provider_pricing.py deleted file mode 100644 index 2ff682ab0..000000000 --- a/costs/migrations/0005_remove_usagecost_provider_pricing.py +++ /dev/null @@ -1,16 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 08:06 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0004_remove_usagecost_saved_run"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="provider_pricing", - ), - ] diff --git a/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py b/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py deleted file mode 100644 index ba207c963..000000000 --- a/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py +++ /dev/null @@ -1,38 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 08:06 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - dependencies = [ - ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), - ("costs", "0005_remove_usagecost_provider_pricing"), - ] - - operations = [ - migrations.AddField( - model_name="usagecost", - name="provider_pricing", - field=models.ForeignKey( - default=None, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="costs.providerpricing", - ), - ), - migrations.AddField( - model_name="usagecost", - name="saved_run", - field=models.ForeignKey( - blank=True, - default=None, - help_text="The run that was last saved by the user.", - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="bots.savedrun", - ), - ), - ] diff --git a/costs/migrations/0007_alter_providerpricing_provider.py b/costs/migrations/0007_alter_providerpricing_provider.py deleted file mode 100644 index e6220e391..000000000 --- a/costs/migrations/0007_alter_providerpricing_provider.py +++ /dev/null @@ -1,25 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 00:51 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0006_usagecost_provider_pricing_usagecost_saved_run"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="provider", - field=models.TextField( - choices=[ - ("OpenAI", "Openai"), - ("GPT-4", "Gpt 4"), - ("dalle-e", "Dalle E"), - ("whisper", "Whisper"), - ("GPT-3.5", "Gpt 3 5"), - ] - ), - ), - ] diff --git a/costs/migrations/0008_alter_providerpricing_param_and_more.py b/costs/migrations/0008_alter_providerpricing_param_and_more.py deleted file mode 100644 index bcc9a0030..000000000 --- a/costs/migrations/0008_alter_providerpricing_param_and_more.py +++ /dev/null @@ -1,33 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 01:10 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0007_alter_providerpricing_provider"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="param", - field=models.TextField(choices=[("Input", "input"), ("Output", "output")]), - ), - migrations.AlterField( - model_name="providerpricing", - name="provider", - field=models.TextField( - choices=[ - ("vertex_ai", "Vertex AI"), - ("openai", "OpenAI"), - ("together", "Together"), - ] - ), - ), - migrations.AlterField( - model_name="providerpricing", - name="type", - field=models.TextField(choices=[("LLM", "LLM")]), - ), - ] diff --git a/costs/migrations/0009_alter_providerpricing_product.py b/costs/migrations/0009_alter_providerpricing_product.py deleted file mode 100644 index 5e5a51019..000000000 --- a/costs/migrations/0009_alter_providerpricing_product.py +++ /dev/null @@ -1,19 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 01:28 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0008_alter_providerpricing_param_and_more"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[("gpt-4-vision-preview", "gpt-4-vision-preview")] - ), - ), - ] diff --git a/costs/migrations/0010_remove_usagecost_calculation_notes.py b/costs/migrations/0010_remove_usagecost_calculation_notes.py deleted file mode 100644 index 3eba7ace1..000000000 --- a/costs/migrations/0010_remove_usagecost_calculation_notes.py +++ /dev/null @@ -1,16 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 02:02 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0009_alter_providerpricing_product"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="calculation_notes", - ), - ] diff --git a/costs/migrations/0011_alter_providerpricing_product.py b/costs/migrations/0011_alter_providerpricing_product.py deleted file mode 100644 index b49d3f308..000000000 --- a/costs/migrations/0011_alter_providerpricing_product.py +++ /dev/null @@ -1,35 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 08:50 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0010_remove_usagecost_calculation_notes"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ("gpt-4-1106-preview", "gpt-4-1106-preview"), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0012_alter_providerpricing_product.py b/costs/migrations/0012_alter_providerpricing_product.py deleted file mode 100644 index 3bc92c01c..000000000 --- a/costs/migrations/0012_alter_providerpricing_product.py +++ /dev/null @@ -1,37 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 16:45 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0011_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ("openai-gpt-4-turbo-prod-ca-1", "openai-gpt-4-turbo-prod-ca-1"), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0013_alter_providerpricing_product.py b/costs/migrations/0013_alter_providerpricing_product.py deleted file mode 100644 index 9449fbf91..000000000 --- a/costs/migrations/0013_alter_providerpricing_product.py +++ /dev/null @@ -1,37 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 16:50 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0012_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ("openai-gpt-4-turbo-prod-ca-1", "gpt-4-1106-preview"), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0014_alter_providerpricing_product.py b/costs/migrations/0014_alter_providerpricing_product.py deleted file mode 100644 index 22b5439ed..000000000 --- a/costs/migrations/0014_alter_providerpricing_product.py +++ /dev/null @@ -1,40 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 16:58 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0013_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ( - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - ), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0015_alter_providerpricing_product.py b/costs/migrations/0015_alter_providerpricing_product.py deleted file mode 100644 index 2658de04f..000000000 --- a/costs/migrations/0015_alter_providerpricing_product.py +++ /dev/null @@ -1,43 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 17:02 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0014_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ( - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - ), - ( - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - ), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/models.py b/costs/models.py deleted file mode 100644 index 18700537f..000000000 --- a/costs/models.py +++ /dev/null @@ -1,113 +0,0 @@ -from django.db import models -from daras_ai_v2.language_model import LLMApis - - -class Product(models.TextChoices): - gpt_4_vision = ( - "gpt-4-vision-preview", - "gpt-4-vision-preview", - ) - gpt_4_turbo = ( - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - ) - - gpt_4_32k = ( - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - ) - gpt_3_5_turbo = ( - "gpt-3.5-turbo", - "gpt-3.5-turbo", - ) - gpt_3_5_turbo_16k = ( - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k", - ) - text_davinci_003 = ( - "text-davinci-003", - "text-davinci-003", - ) - text_davinci_002 = ( - "text-davinci-002", - "text-davinci-002", - ) - code_davinci_002 = ( - "code-davinci-002", - "code-davinci-002", - ) - text_curie_001 = ( - "text-curie-001", - "text-curie-001", - ) - text_babbage_001 = ( - "text-babbage-001", - "text-babbage-001", - ) - text_ada_001 = ( - "text-ada-001", - "text-ada-001", - ) - palm2_text = ( - "text-bison", - "text-bison", - ) - palm2_chat = ( - "chat-bison", - "chat-bison", - ) - llama2_70b_chat = ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ) - - -class UsageCost(models.Model): - saved_run = models.ForeignKey( - "bots.SavedRun", - on_delete=models.CASCADE, - related_name="usage_costs", - null=True, - default=None, - blank=True, - help_text="The run that was last saved by the user.", - ) - - provider_pricing = models.ForeignKey( - "costs.ProviderPricing", - on_delete=models.CASCADE, - related_name="usage_costs", - null=True, - default=None, - ) - quantity = models.IntegerField(default=1) - notes = models.TextField(default="", blank=True) - dollar_amt = models.DecimalField( - max_digits=13, - decimal_places=8, - ) - created_at = models.DateTimeField(auto_now_add=True) - - -class ProviderPricing(models.Model): - class Group(models.TextChoices): # change to different name than type - LLM = "LLM", "LLM" - - class Param(models.TextChoices): - input = "Input", "input" - output = "Output", "output" - - type = models.TextField(choices=Group.choices) - provider = models.TextField(choices=LLMApis.choices()) - product = models.TextField(choices=Product.choices) - param = models.TextField(choices=Param.choices) - cost = models.DecimalField(max_digits=13, decimal_places=8) - unit = models.TextField(default="") - notes = models.TextField(default="") - created_at = models.DateTimeField(auto_now_add=True) - last_updated = models.DateTimeField(auto_now=True) - updated_by = models.TextField(default="") - pricing_url = models.TextField(default="") - - def __str__(self): - return self.type + " " + self.provider + " " + self.product + " " + self.param diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 58a26f9b6..fa8b48d86 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -4,7 +4,7 @@ import re import typing from enum import Enum -from functools import partial +from functools import partial, wraps import numpy as np import requests @@ -112,11 +112,11 @@ def is_chat_model(self) -> bool: LargeLanguageModels.gpt_4_32k: "openai-gpt-4-32k-prod-ca-1", LargeLanguageModels.gpt_3_5_turbo: ( "openai-gpt-35-turbo-prod-ca-1", - "gpt-3.5-turbo", + "gpt-3.5-turbo-0613", ), LargeLanguageModels.gpt_3_5_turbo_16k: ( "openai-gpt-35-turbo-16k-prod-ca-1", - "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-16k-0613", ), LargeLanguageModels.text_davinci_003: "text-davinci-003", LargeLanguageModels.text_davinci_002: "text-davinci-002", @@ -378,34 +378,6 @@ def run_language_model( stream=stream and not (tools or response_format_type), ) - ## incorportate down the call chain - # provider_pricing_in = get_provider_pricing( - # type="LLM", - # provider=api.name, - # product=model_name, - # param="Input", - # ) - # - # provider_pricing_out = get_provider_pricing( - # type="LLM", - # provider=api.name, - # product=model_name, - # param="Output", - # ) - # example_id, run_id, uid = extract_query_params(gooey_get_query_params()) - # record_cost( - # run_id=run_id, - # uid=uid, - # provider_pricing=provider_pricing_in, - # quantity=input_token, - # ) - # record_cost( - # run_id=run_id, - # uid=uid, - # provider_pricing=provider_pricing_out, - # quantity=output_token, - # ) - if stream: return _stream_llm_outputs(entries, response_format_type) else: @@ -588,10 +560,9 @@ def _run_openai_chat( presence_penalty = 0 if isinstance(model, str): model = [model] - r = try_all( + r, used_model = try_all( *[ - partial( - _get_openai_client(model_str).chat.completions.create, + _get_chat_completions_create( model=model_str, messages=messages, max_tokens=max_tokens, @@ -612,15 +583,39 @@ def _run_openai_chat( ], ) if stream: - return _stream_openai_chunked(r) + return _stream_openai_chunked(r, used_model, messages) else: - r.usage.completion_tokens - r.usage.prompt_tokens + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=used_model, + sku=ModelSku.llm_prompt, + quantity=r.usage.prompt_tokens, + ) + record_cost_auto( + model=used_model, + sku=ModelSku.llm_completion, + quantity=r.usage.completion_tokens, + ) return [choice.message.dict() for choice in r.choices] +def _get_chat_completions_create(model: str, **kwargs): + client = _get_openai_client(model) + + @wraps(client.chat.completions.create) + def wrapper(): + return client.chat.completions.create(model=model, **kwargs), model + + return wrapper + + def _stream_openai_chunked( r: Stream[ChatCompletionChunk], + used_model: str, + messages: list[ConversationEntry], + *, start_chunk_size: int = 50, stop_chunk_size: int = 400, step_chunk_size: int = 150, @@ -679,6 +674,20 @@ def _stream_openai_chunked( entry["content"] += entry["chunk"] yield ret + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=used_model, + sku=ModelSku.llm_prompt, + quantity=sum(default_length_function(entry["content"]) for entry in messages), + ) + record_cost_auto( + model=used_model, + sku=ModelSku.llm_completion, + quantity=sum(default_length_function(entry["content"]) for entry in ret), + ) + @retry_if(openai_should_retry) def _run_openai_text( @@ -702,8 +711,21 @@ def _run_openai_text( frequency_penalty=0.1 if avoid_repetition else 0, presence_penalty=0.25 if avoid_repetition else 0, ) - r.usage.completion_tokens, - r.usage.prompt_tokens, + + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=model, + sku=ModelSku.llm_prompt, + quantity=r.usage.prompt_tokens, + ) + record_cost_auto( + model=model, + sku=ModelSku.llm_completion, + quantity=r.usage.completion_tokens, + ) + return [choice.text for choice in r.choices] @@ -761,15 +783,13 @@ def _run_together_chat( range(num_outputs), ) ret = [] - total_out_tokens = 0 - total_in_tokens = 0 + prompt_tokens = 0 + completion_tokens = 0 for r in results: raise_for_status(r) data = r.json() output = data["output"] error = output.get("error") - total_out_tokens += output.get("usage", {}).get("completion_tokens", 0) - total_in_tokens += output.get("usage", {}).get("prompt_tokens", 0) if error: raise ValueError(error) ret.append( @@ -778,6 +798,21 @@ def _run_together_chat( "content": output["choices"][0]["text"], } ) + prompt_tokens += output.get("usage", {}).get("prompt_tokens", 0) + completion_tokens += output.get("usage", {}).get("completion_tokens", 0) + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=model, + sku=ModelSku.llm_prompt, + quantity=prompt_tokens, + ) + record_cost_auto( + model=model, + sku=ModelSku.llm_completion, + quantity=completion_tokens, + ) return ret @@ -828,16 +863,28 @@ def _run_palm_chat( }, ) raise_for_status(r) - - r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"] - r.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"] + out = r.json() + + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=model_id, + sku=ModelSku.llm_prompt, + quantity=out["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) + record_cost_auto( + model=model_id, + sku=ModelSku.llm_completion, + quantity=out["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + ) return [ { "role": msg["author"], "content": msg["content"], } - for pred in r.json()["predictions"] + for pred in out["predictions"] for msg in pred["candidates"] ] @@ -876,11 +923,23 @@ def _run_palm_text( }, ) raise_for_status(res) + out = res.json() + + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku - res.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"] - res.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"] + record_cost_auto( + model=model_id, + sku=ModelSku.llm_prompt, + quantity=out["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) + record_cost_auto( + model=model_id, + sku=ModelSku.llm_completion, + quantity=out["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + ) - return [prediction["content"] for prediction in res.json()["predictions"]] + return [prediction["content"] for prediction in out["predictions"]] def format_chatml_message(entry: ConversationEntry) -> str: diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 9df4d7673..deda81500 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -55,7 +55,7 @@ "files", "url_shortener", "glossary_resources", - "costs", + "usage_costs", ] MIDDLEWARE = [ diff --git a/scripts/init_llm_pricing.py b/scripts/init_llm_pricing.py new file mode 100644 index 000000000..a1f55e964 --- /dev/null +++ b/scripts/init_llm_pricing.py @@ -0,0 +1,338 @@ +from daras_ai_v2.language_model import LargeLanguageModels +from usage_costs.models import ModelSku, ModelCategory, ModelProvider, ModelPricing + + +def run(): + category = ModelCategory.LLM + + # GPT-4-Turbo + + for model in ["gpt-4-0125-preview", "gpt-4-1106-preview"]: + ModelPricing.objects.create( + model_id=model, + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id=model, + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-4-Turbo-Vision + + ModelPricing.objects.create( + model_id="gpt-4-vision-preview", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-4-vision-preview", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-vision-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-vision-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-4 + + ModelPricing.objects.create( + model_id="gpt-4", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_prompt, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-4", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_completion, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-prod-ca-1", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_prompt, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-prod-ca-1", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_completion, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-4-32k + + ModelPricing.objects.create( + model_id="gpt-4-32k", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-4-32k", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_completion, + unit_cost=0.12, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-32k-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-32k-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_completion, + unit_cost=0.12, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-3.5-Turbo + + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.0015, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.002, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.0015, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.002, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-3.5-Turbo-16k + + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-16k-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.003, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-16k-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_completion, + unit_cost=0.004, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-16k-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.003, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-16k-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_completion, + unit_cost=0.004, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # Palm2 + + ModelPricing.objects.create( + model_id="text-bison", + model_name=LargeLanguageModels.palm2_text.name, + sku=ModelSku.llm_prompt, + unit_cost=0.00025, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + ModelPricing.objects.create( + model_id="text-bison", + model_name=LargeLanguageModels.palm2_text.name, + sku=ModelSku.llm_completion, + unit_cost=0.0005, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + + ModelPricing.objects.create( + model_id="chat-bison", + model_name=LargeLanguageModels.palm2_chat.name, + sku=ModelSku.llm_prompt, + unit_cost=0.00025, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + ModelPricing.objects.create( + model_id="chat-bison", + model_name=LargeLanguageModels.palm2_chat.name, + sku=ModelSku.llm_completion, + unit_cost=0.0005, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + + # Llama2 + + ModelPricing.objects.create( + model_id="togethercomputer/llama-2-70b-chat", + model_name=LargeLanguageModels.llama2_70b_chat.name, + sku=ModelSku.llm_prompt, + unit_cost=0.9, + unit_quantity=10**6, + category=category, + provider=ModelProvider.together_ai, + pricing_url="https://www.together.ai/pricing", + ) + ModelPricing.objects.create( + model_id="togethercomputer/llama-2-70b-chat", + model_name=LargeLanguageModels.llama2_70b_chat.name, + sku=ModelSku.llm_completion, + unit_cost=0.9, + unit_quantity=10**6, + category=category, + provider=ModelProvider.together_ai, + pricing_url="https://www.together.ai/pricing", + ) diff --git a/costs/__init__.py b/usage_costs/__init__.py similarity index 100% rename from costs/__init__.py rename to usage_costs/__init__.py diff --git a/usage_costs/admin.py b/usage_costs/admin.py new file mode 100644 index 000000000..611624ab1 --- /dev/null +++ b/usage_costs/admin.py @@ -0,0 +1,75 @@ +from decimal import Decimal + +from django.contrib import admin + +from bots.admin_links import open_in_new_tab, change_obj_url +from usage_costs import models + + +class CostQtyMixin: + @admin.display(description="Cost / Qty", ordering="unit_cost") + def cost_qty(self, obj): + return f"${obj.unit_cost.normalize()} / {obj.unit_quantity}" + + +@admin.register(models.UsageCost) +class UsageCostAdmin(admin.ModelAdmin, CostQtyMixin): + list_display = [ + "__str__", + "cost_qty", + "quantity", + "display_dollar_amount", + "view_pricing", + "view_saved_run", + "notes", + "created_at", + ] + autocomplete_fields = ["saved_run", "pricing"] + list_filter = [ + "saved_run__workflow", + "pricing__category", + "pricing__provider", + "pricing__model_name", + "pricing__sku", + "created_at", + ] + search_fields = ["saved_run", "pricing"] + readonly_fields = ["created_at"] + ordering = ["-created_at"] + + @admin.display(description="Amount", ordering="dollar_amount") + def display_dollar_amount(self, obj): + return f"${obj.dollar_amount.normalize()}" + + @admin.display(description="Saved Run", ordering="saved_run") + def view_saved_run(self, obj): + return change_obj_url( + obj.saved_run, + label=f"{obj.saved_run.get_workflow_display()} - {obj.saved_run} ({obj.saved_run.run_id})", + ) + + @admin.display(description="Pricing", ordering="pricing") + def view_pricing(self, obj): + return change_obj_url(obj.pricing) + + +@admin.register(models.ModelPricing) +class ModelPricingAdmin(admin.ModelAdmin, CostQtyMixin): + list_display = [ + "__str__", + "cost_qty", + "model_id", + "category", + "provider", + "model_name", + "sku", + "view_pricing_url", + "notes", + ] + list_filter = ["category", "provider", "model_name", "sku"] + search_fields = ["category", "provider", "model_name", "sku", "model_id"] + readonly_fields = ["created_at", "updated_at"] + + @admin.display(description="Pricing URL", ordering="pricing_url") + def view_pricing_url(self, obj): + return open_in_new_tab(obj.pricing_url, label=obj.pricing_url) diff --git a/costs/apps.py b/usage_costs/apps.py similarity index 67% rename from costs/apps.py rename to usage_costs/apps.py index 55bed0ca1..b807e363b 100644 --- a/costs/apps.py +++ b/usage_costs/apps.py @@ -3,5 +3,5 @@ class CostsConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "costs" - verbose_name = "Costs" + name = "usage_costs" + verbose_name = "Usage Costs" diff --git a/usage_costs/cost_utils.py b/usage_costs/cost_utils.py new file mode 100644 index 000000000..ea622e06a --- /dev/null +++ b/usage_costs/cost_utils.py @@ -0,0 +1,34 @@ +from loguru import logger + +from daras_ai_v2.query_params import gooey_get_query_params +from daras_ai_v2.query_params_util import extract_query_params +from usage_costs.models import ( + UsageCost, + ModelSku, + ModelPricing, +) + + +def record_cost_auto(model: str, sku: ModelSku, quantity: int): + from bots.models import SavedRun + + _, run_id, uid = extract_query_params(gooey_get_query_params()) + if not run_id or not uid: + return + + try: + pricing = ModelPricing.objects.get(model_id=model, sku=sku) + except ModelPricing.DoesNotExist as e: + logger.warning(f"Cant find pricing for {model=} {sku=}: {e=}") + return + + saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) + + UsageCost.objects.create( + saved_run=saved_run, + pricing=pricing, + quantity=quantity, + unit_cost=pricing.unit_cost, + unit_quantity=pricing.unit_quantity, + dollar_amount=pricing.unit_cost * quantity / pricing.unit_quantity, + ) diff --git a/usage_costs/migrations/0001_initial.py b/usage_costs/migrations/0001_initial.py new file mode 100644 index 000000000..bb02723f7 --- /dev/null +++ b/usage_costs/migrations/0001_initial.py @@ -0,0 +1,56 @@ +# Generated by Django 4.2.7 on 2024-02-06 18:22 + +import bots.custom_fields +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('bots', '0057_message_response_time_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='ModelPricing', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('model_id', models.TextField(help_text='The model ID. Model ID + SKU should be unique together.')), + ('sku', models.IntegerField(choices=[(1, 'LLM Prompt'), (2, 'LLM Completion')], help_text="The model's SKU. Model ID + SKU should be unique together.")), + ('unit_cost', models.DecimalField(decimal_places=10, help_text='The cost per unit.', max_digits=15)), + ('unit_quantity', models.PositiveIntegerField(help_text='The quantity of the unit. (e.g. 1000 tokens)')), + ('category', models.IntegerField(choices=[(1, 'LLM')], help_text='The category of the model. Only used for Display purposes.')), + ('provider', models.IntegerField(choices=[(1, 'OpenAI'), (2, 'Google'), (3, 'TogetherAI'), (4, 'Azure OpenAI')], help_text='The provider of the model. Only used for Display purposes.')), + ('model_name', models.CharField(choices=[('gpt_4_vision', 'GPT-4 Vision (openai)'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai)'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('llama2_70b_chat', 'Llama 2 (Meta AI)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 (openai)'), ('text_curie_001', 'Curie (openai)'), ('text_babbage_001', 'Babbage (openai)'), ('text_ada_001', 'Ada (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)')], help_text='The name of the model. Only used for Display purposes.', max_length=255)), + ('notes', models.TextField(blank=True, default='', help_text='Any notes about the pricing. (e.g. how the pricing was calculated)')), + ('pricing_url', bots.custom_fields.CustomURLField(blank=True, default='', help_text='The URL of the pricing.', max_length=2048)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ], + ), + migrations.CreateModel( + name='UsageCost', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('quantity', models.PositiveIntegerField()), + ('unit_cost', models.DecimalField(decimal_places=10, help_text='The cost per unit, recorded at the time of the usage.', max_digits=15)), + ('unit_quantity', models.PositiveIntegerField(help_text='The quantity of the unit (e.g. 1000 tokens), recorded at the time of the usage.')), + ('dollar_amount', models.DecimalField(decimal_places=10, help_text='The dollar amount, calculated as unit_cost x quantity / unit_quantity.', max_digits=15)), + ('notes', models.TextField(blank=True, default='')), + ('created_at', models.DateTimeField(auto_now_add=True, db_index=True)), + ('pricing', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='usage_costs', to='usage_costs.modelpricing')), + ('saved_run', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='usage_costs', to='bots.savedrun')), + ], + ), + migrations.AddIndex( + model_name='modelpricing', + index=models.Index(fields=['model_id', 'sku'], name='usage_costs_model_i_d7b80a_idx'), + ), + migrations.AlterUniqueTogether( + name='modelpricing', + unique_together={('model_id', 'sku')}, + ), + ] diff --git a/costs/migrations/__init__.py b/usage_costs/migrations/__init__.py similarity index 100% rename from costs/migrations/__init__.py rename to usage_costs/migrations/__init__.py diff --git a/usage_costs/models.py b/usage_costs/models.py new file mode 100644 index 000000000..67b5c69e0 --- /dev/null +++ b/usage_costs/models.py @@ -0,0 +1,114 @@ +from django.db import models + +from bots.custom_fields import CustomURLField + +max_digits = 15 +decimal_places = 10 + + +class UsageCost(models.Model): + saved_run = models.ForeignKey( + "bots.SavedRun", on_delete=models.CASCADE, related_name="usage_costs" + ) + pricing = models.ForeignKey( + "usage_costs.ModelPricing", on_delete=models.CASCADE, related_name="usage_costs" + ) + + quantity = models.PositiveIntegerField() + unit_cost = models.DecimalField( + max_digits=max_digits, + decimal_places=decimal_places, + help_text="The cost per unit, recorded at the time of the usage.", + ) + unit_quantity = models.PositiveIntegerField( + help_text="The quantity of the unit (e.g. 1000 tokens), recorded at the time of the usage." + ) + + dollar_amount = models.DecimalField( + max_digits=max_digits, + decimal_places=decimal_places, + help_text="The dollar amount, calculated as unit_cost x quantity / unit_quantity.", + ) + + notes = models.TextField(default="", blank=True) + + created_at = models.DateTimeField(auto_now_add=True, db_index=True) + + def __str__(self): + return f"{self.saved_run} - {self.pricing} - {self.quantity}" + + +class ModelCategory(models.IntegerChoices): + LLM = 1, "LLM" + + +class ModelProvider(models.IntegerChoices): + openai = 1, "OpenAI" + google = 2, "Google" + together_ai = 3, "TogetherAI" + azure_openai = 4, "Azure OpenAI" + + +def get_model_choices(): + from daras_ai_v2.language_model import LargeLanguageModels + + return [(api.name, api.value) for api in LargeLanguageModels] + + +class ModelSku(models.IntegerChoices): + llm_prompt = 1, "LLM Prompt" + llm_completion = 2, "LLM Completion" + + +class ModelPricing(models.Model): + model_id = models.TextField( + help_text="The model ID. Model ID + SKU should be unique together." + ) + sku = models.IntegerField( + choices=ModelSku.choices, + help_text="The model's SKU. Model ID + SKU should be unique together.", + ) + + unit_cost = models.DecimalField( + max_digits=max_digits, + decimal_places=decimal_places, + help_text="The cost per unit.", + ) + unit_quantity = models.PositiveIntegerField( + help_text="The quantity of the unit. (e.g. 1000 tokens)" + ) + + category = models.IntegerField( + choices=ModelCategory.choices, + help_text="The category of the model. Only used for Display purposes.", + ) + provider = models.IntegerField( + choices=ModelProvider.choices, + help_text="The provider of the model. Only used for Display purposes.", + ) + model_name = models.CharField( + max_length=255, + choices=get_model_choices(), + help_text="The name of the model. Only used for Display purposes.", + ) + + notes = models.TextField( + default="", + blank=True, + help_text="Any notes about the pricing. (e.g. how the pricing was calculated)", + ) + pricing_url = CustomURLField( + default="", blank=True, help_text="The URL of the pricing." + ) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + unique_together = ["model_id", "sku"] + indexes = [ + models.Index(fields=["model_id", "sku"]), + ] + + def __str__(self): + return f"{self.get_provider_display()} / {self.get_model_name_display()} / {self.get_sku_display()}" diff --git a/costs/tests.py b/usage_costs/tests.py similarity index 100% rename from costs/tests.py rename to usage_costs/tests.py diff --git a/costs/views.py b/usage_costs/views.py similarity index 100% rename from costs/views.py rename to usage_costs/views.py From 08d46ae8bd91f7efa45b5ea4d1ff6edfe0585f9a Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 00:06:33 +0530 Subject: [PATCH 58/85] fix gpt4-v crash --- daras_ai_v2/language_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index fa8b48d86..3d48b1068 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -680,7 +680,9 @@ def _stream_openai_chunked( record_cost_auto( model=used_model, sku=ModelSku.llm_prompt, - quantity=sum(default_length_function(entry["content"]) for entry in messages), + quantity=sum( + default_length_function(get_entry_text(entry)) for entry in messages + ), ) record_cost_auto( model=used_model, From 61b88c11253c6d46e3585477c0bb09c034fd2bb4 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 00:17:44 +0530 Subject: [PATCH 59/85] show usage cost in saved run admin --- bots/admin.py | 14 +++++++++++--- bots/admin_links.py | 3 +++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index 0b6b475cc..38424dc79 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -5,7 +5,7 @@ from django import forms from django.conf import settings from django.contrib import admin -from django.db.models import Max, Count, F +from django.db.models import Max, Count, F, Sum from django.template import loader from django.utils import dateformat from django.utils.safestring import mark_safe @@ -28,8 +28,6 @@ WorkflowMetadata, ) from bots.tasks import create_personal_channels_for_all_members -from daras_ai.image_input import truncate_text_words -from daras_ai_v2.base import BasePage from gooeysite.custom_actions import export_to_excel, export_to_csv from gooeysite.custom_filters import ( related_json_field_summary, @@ -278,6 +276,7 @@ class SavedRunAdmin(admin.ModelAdmin): "parent", "view_bots", "price", + "view_usage_cost", "transaction", "created_at", "updated_at", @@ -313,6 +312,15 @@ def view_parent_published_run(self, saved_run: SavedRun): pr = saved_run.parent_published_run() return pr and change_obj_url(pr) + @admin.display(description="Usage Costs") + def view_usage_cost(self, saved_run: SavedRun): + total_cost = saved_run.usage_costs.aggregate(total_cost=Sum("dollar_amount"))[ + "total_cost" + ] + return list_related_html_url( + saved_run.usage_costs, extra_label=f"${total_cost.normalize()}" + ) + @admin.register(PublishedRunVersion) class PublishedRunVersionAdmin(admin.ModelAdmin): diff --git a/bots/admin_links.py b/bots/admin_links.py index 94086e348..c06601953 100644 --- a/bots/admin_links.py +++ b/bots/admin_links.py @@ -35,6 +35,7 @@ def list_related_html_url( query_param: str = None, instance_id: int = None, show_add: bool = True, + extra_label: str = None, ) -> typing.Optional[str]: num = manager.all().count() @@ -60,6 +61,8 @@ def list_related_html_url( ).url label = f"{num} {meta.verbose_name if num == 1 else meta.verbose_name_plural}" + if extra_label: + label = f"{label} ({extra_label})" if show_add: add_related_url = furl( From 9899741ddbb425fee7424dbcc95e0e45aa086a86 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 00:22:17 +0530 Subject: [PATCH 60/85] increase run id size --- daras_ai_v2/crypto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daras_ai_v2/crypto.py b/daras_ai_v2/crypto.py index 30bef7f4f..ec0047bbd 100644 --- a/daras_ai_v2/crypto.py +++ b/daras_ai_v2/crypto.py @@ -65,7 +65,7 @@ def safe_preview(password: str) -> str: def get_random_doc_id() -> str: return get_random_string( - length=8, allowed_chars=string.ascii_lowercase + string.digits + length=12, allowed_chars=string.ascii_lowercase + string.digits ) From cc8c24ba00b0f8c125b9891dc10dea6d0f940ff1 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 21:54:08 +0530 Subject: [PATCH 61/85] record costs for all gpu tasks --- daras_ai_v2/gpu_server.py | 12 +++++- scripts/init_self_hosted_pricing.py | 2 + usage_costs/admin.py | 8 +++- ...02_alter_modelpricing_category_and_more.py | 38 +++++++++++++++++++ usage_costs/models.py | 12 +++++- 5 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 scripts/init_self_hosted_pricing.py create mode 100644 usage_costs/migrations/0002_alter_modelpricing_category_and_more.py diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index fa638da89..3050cd98b 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -2,6 +2,7 @@ import datetime import os import typing +from time import time import requests from furl import furl @@ -138,6 +139,7 @@ def get_celery(): _app = Celery() _app.conf.broker_url = settings.GPU_CELERY_BROKER_URL _app.conf.result_backend = settings.GPU_CELERY_RESULT_BACKEND + _app.conf.result_extended = True return _app @@ -149,8 +151,16 @@ def call_celery_task( inputs: dict, queue_prefix: str = "gooey-gpu", ): + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + queue = os.path.join(queue_prefix, pipeline["model_id"].strip()).strip("/") result = get_celery().send_task( task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue ) - return result.get(disable_sync_subtasks=False) + s = time() + ret = result.get(disable_sync_subtasks=False) + record_cost_auto( + model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000) + ) + return ret diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py new file mode 100644 index 000000000..d04d42367 --- /dev/null +++ b/scripts/init_self_hosted_pricing.py @@ -0,0 +1,2 @@ +def run(): + pass diff --git a/usage_costs/admin.py b/usage_costs/admin.py index 611624ab1..fbe367243 100644 --- a/usage_costs/admin.py +++ b/usage_costs/admin.py @@ -1,5 +1,3 @@ -from decimal import Decimal - from django.contrib import admin from bots.admin_links import open_in_new_tab, change_obj_url @@ -21,6 +19,7 @@ class UsageCostAdmin(admin.ModelAdmin, CostQtyMixin): "display_dollar_amount", "view_pricing", "view_saved_run", + "view_parent_published_run", "notes", "created_at", ] @@ -41,6 +40,11 @@ class UsageCostAdmin(admin.ModelAdmin, CostQtyMixin): def display_dollar_amount(self, obj): return f"${obj.dollar_amount.normalize()}" + @admin.display(description="Published Run") + def view_parent_published_run(self, obj): + pr = obj.saved_run.parent_published_run() + return pr and change_obj_url(pr) + @admin.display(description="Saved Run", ordering="saved_run") def view_saved_run(self, obj): return change_obj_url( diff --git a/usage_costs/migrations/0002_alter_modelpricing_category_and_more.py b/usage_costs/migrations/0002_alter_modelpricing_category_and_more.py new file mode 100644 index 000000000..6097175f6 --- /dev/null +++ b/usage_costs/migrations/0002_alter_modelpricing_category_and_more.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2.7 on 2024-02-07 15:44 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='category', + field=models.IntegerField(choices=[(1, 'LLM'), (2, 'Self-Hosted')], help_text='The category of the model. Only used for Display purposes.'), + ), + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_vision', 'GPT-4 Vision (openai)'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai)'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('llama2_70b_chat', 'Llama 2 (Meta AI)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 (openai)'), ('text_curie_001', 'Curie (openai)'), ('text_babbage_001', 'Babbage (openai)'), ('text_ada_001', 'Ada (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + migrations.AlterField( + model_name='modelpricing', + name='provider', + field=models.IntegerField(choices=[(1, 'OpenAI'), (2, 'Google'), (3, 'TogetherAI'), (4, 'Azure OpenAI'), (5, 'Azure Kubernetes Service')], help_text='The provider of the model. Only used for Display purposes.'), + ), + migrations.AlterField( + model_name='modelpricing', + name='sku', + field=models.IntegerField(choices=[(1, 'LLM Prompt'), (2, 'LLM Completion'), (3, 'GPU Milliseconds')], help_text="The model's SKU. Model ID + SKU should be unique together."), + ), + migrations.AlterField( + model_name='modelpricing', + name='unit_quantity', + field=models.PositiveIntegerField(default=1, help_text='The quantity of the unit. (e.g. 1000 tokens)'), + ), + ] diff --git a/usage_costs/models.py b/usage_costs/models.py index 67b5c69e0..3cdee247f 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -40,6 +40,7 @@ def __str__(self): class ModelCategory(models.IntegerChoices): LLM = 1, "LLM" + SELF_HOSTED = 2, "Self-Hosted" class ModelProvider(models.IntegerChoices): @@ -48,17 +49,24 @@ class ModelProvider(models.IntegerChoices): together_ai = 3, "TogetherAI" azure_openai = 4, "Azure OpenAI" + aks = 5, "Azure Kubernetes Service" + def get_model_choices(): from daras_ai_v2.language_model import LargeLanguageModels + from recipes.DeforumSD import AnimationModels - return [(api.name, api.value) for api in LargeLanguageModels] + return [(api.name, api.value) for api in LargeLanguageModels] + [ + (model.name, model.label) for model in AnimationModels + ] class ModelSku(models.IntegerChoices): llm_prompt = 1, "LLM Prompt" llm_completion = 2, "LLM Completion" + gpu_ms = 3, "GPU Milliseconds" + class ModelPricing(models.Model): model_id = models.TextField( @@ -75,7 +83,7 @@ class ModelPricing(models.Model): help_text="The cost per unit.", ) unit_quantity = models.PositiveIntegerField( - help_text="The quantity of the unit. (e.g. 1000 tokens)" + help_text="The quantity of the unit. (e.g. 1000 tokens)", default=1 ) category = models.IntegerField( From b89905ba0912abec84233f1be09d9f77fa0bb280 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 22:07:29 +0530 Subject: [PATCH 62/85] init self hosted pricing --- daras_ai_v2/gpu_server.py | 6 +++++- scripts/init_self_hosted_pricing.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index 3050cd98b..752879f8d 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -154,7 +154,7 @@ def call_celery_task( from usage_costs.cost_utils import record_cost_auto from usage_costs.models import ModelSku - queue = os.path.join(queue_prefix, pipeline["model_id"].strip()).strip("/") + queue = build_queue_name(queue_prefix, pipeline["model_id"]) result = get_celery().send_task( task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue ) @@ -164,3 +164,7 @@ def call_celery_task( model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000) ) return ret + + +def build_queue_name(queue_prefix: str, model_id: str) -> str: + return os.path.join(queue_prefix, model_id.strip()).strip("/") diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index d04d42367..05947eb47 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -1,2 +1,29 @@ +from decimal import Decimal + +from daras_ai_v2.gpu_server import build_queue_name +from recipes.DeforumSD import AnimationModels +from usage_costs.models import ModelPricing +from usage_costs.models import ModelSku, ModelCategory, ModelProvider + +category = ModelCategory.SELF_HOSTED + + def run(): - pass + for model in AnimationModels: + add_model(model.value, model.name) + + +def add_model(model_id, model_name): + ModelPricing.objects.get_or_create( + model_id=build_queue_name("gooey-gpu", model_id), + sku=ModelSku.gpu_ms, + defaults=dict( + model_name=model_name, + unit_cost=Decimal("3.673"), + unit_quantity=3600000, + category=category, + provider=ModelProvider.aks, + notes="NC24ads A100 v4 - 1 X A100 - Pay as you go - $3.6730/hour", + pricing_url="https://azure.microsoft.com/en-in/pricing/details/virtual-machines/linux/#pricing", + ), + ) From f4f4c1fcb2833b966e87d7eb5413309060df058f Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 22:19:09 +0530 Subject: [PATCH 63/85] img2img/text2img/inpaint --- scripts/init_self_hosted_pricing.py | 20 +++++++++++++++++++ .../0003_alter_modelpricing_model_name.py | 18 +++++++++++++++++ usage_costs/models.py | 14 +++++++++---- 3 files changed, 48 insertions(+), 4 deletions(-) create mode 100644 usage_costs/migrations/0003_alter_modelpricing_model_name.py diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index 05947eb47..7bc16376a 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -1,6 +1,14 @@ from decimal import Decimal from daras_ai_v2.gpu_server import build_queue_name +from daras_ai_v2.stable_diffusion import ( + Text2ImgModels, + Img2ImgModels, + InpaintingModels, + text2img_model_ids, + img2img_model_ids, + inpaint_model_ids, +) from recipes.DeforumSD import AnimationModels from usage_costs.models import ModelPricing from usage_costs.models import ModelSku, ModelCategory, ModelProvider @@ -11,6 +19,18 @@ def run(): for model in AnimationModels: add_model(model.value, model.name) + for model_enum, model_ids in [ + (Text2ImgModels, text2img_model_ids), + (Img2ImgModels, img2img_model_ids), + (InpaintingModels, inpaint_model_ids), + ]: + for m in model_enum: + if "dall_e" in m.name: + continue + try: + add_model(model_ids[m], m.name) + except KeyError: + pass def add_model(model_id, model_name): diff --git a/usage_costs/migrations/0003_alter_modelpricing_model_name.py b/usage_costs/migrations/0003_alter_modelpricing_model_name.py new file mode 100644 index 000000000..a4cbe05fa --- /dev/null +++ b/usage_costs/migrations/0003_alter_modelpricing_model_name.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-07 16:48 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0002_alter_modelpricing_category_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_vision', 'GPT-4 Vision (openai)'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai)'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('llama2_70b_chat', 'Llama 2 (Meta AI)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 (openai)'), ('text_curie_001', 'Curie (openai)'), ('text_babbage_001', 'Babbage (openai)'), ('text_ada_001', 'Ada (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALL·E 2 (OpenAI)'), ('dall_e_3', 'DALL·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + ] diff --git a/usage_costs/models.py b/usage_costs/models.py index 3cdee247f..5a3bf5171 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -1,6 +1,7 @@ from django.db import models from bots.custom_fields import CustomURLField +from daras_ai_v2.stable_diffusion import InpaintingModels max_digits = 15 decimal_places = 10 @@ -55,10 +56,15 @@ class ModelProvider(models.IntegerChoices): def get_model_choices(): from daras_ai_v2.language_model import LargeLanguageModels from recipes.DeforumSD import AnimationModels - - return [(api.name, api.value) for api in LargeLanguageModels] + [ - (model.name, model.label) for model in AnimationModels - ] + from daras_ai_v2.stable_diffusion import Text2ImgModels, Img2ImgModels + + return ( + [(api.name, api.value) for api in LargeLanguageModels] + + [(model.name, model.label) for model in AnimationModels] + + [(model.name, model.value) for model in Text2ImgModels] + + [(model.name, model.value) for model in Img2ImgModels] + + [(model.name, model.value) for model in InpaintingModels] + ) class ModelSku(models.IntegerChoices): From b0d9125d4caa576e3626ee2a8f969b3b703bb7c0 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 22:26:08 +0530 Subject: [PATCH 64/85] wav2lip pricing --- scripts/init_self_hosted_pricing.py | 1 + usage_costs/models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index 7bc16376a..6ea45c7d3 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -31,6 +31,7 @@ def run(): add_model(model_ids[m], m.name) except KeyError: pass + add_model("gooey-gpu/wav2lip_gan.pth", "wav2lip") def add_model(model_id, model_name): diff --git a/usage_costs/models.py b/usage_costs/models.py index 5a3bf5171..e3fce075f 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -64,6 +64,7 @@ def get_model_choices(): + [(model.name, model.value) for model in Text2ImgModels] + [(model.name, model.value) for model in Img2ImgModels] + [(model.name, model.value) for model in InpaintingModels] + + [("wav2lip", "LipSync (wav2lip)")] ) From 93d786da35d83a6b25ea97bf7fce3dd09b99a8ec Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 8 Feb 2024 18:39:52 +0530 Subject: [PATCH 65/85] expand gdrive folders in parallel avoid mutating the list that is being iterated --- daras_ai_v2/doc_search_settings_widgets.py | 26 +++++++++----- daras_ai_v2/gdrive_downloader.py | 41 +++++++++++----------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 39cc152db..f175cd911 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -1,10 +1,14 @@ import os import typing +from furl import furl +from sentry_sdk import capture_exception + import gooey_ui as st from daras_ai_v2 import settings from daras_ai_v2.asr import AsrModels, google_translate_language_selector from daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder from daras_ai_v2.search_ref import CitationStyles _user_media_url_prefix = os.path.join( @@ -76,20 +80,24 @@ def document_uploader( accept_multiple_files=accept_multiple_files, ) documents = st.session_state.get(key, []) - for document in documents: - if not document.startswith("https://drive.google.com/drive/folders"): - continue - from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder - from furl import furl - - folder_content_urls = gdrive_list_urls_of_files_in_folder(furl(document)) - documents.remove(document) - documents.extend(folder_content_urls) + try: + documents = list(_expand_gdrive_folders(documents)) + except Exception as e: + capture_exception(e) + st.error(f"Error expanding gdrive folders: {e}") st.session_state[key] = documents st.session_state[custom_key] = "\n".join(documents) return documents +def _expand_gdrive_folders(documents: list[str]) -> list[str]: + for url in documents: + if url.startswith("https://drive.google.com/drive/folders"): + yield from gdrive_list_urls_of_files_in_folder(furl(url)) + else: + yield url + + def doc_search_settings( asr_allowed: bool = False, keyword_instructions_allowed: bool = False, diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 345099aba..5ae9fa176 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -5,6 +5,8 @@ from googleapiclient import discovery from googleapiclient.http import MediaIoBaseDownload +from daras_ai_v2.functional import flatmap_parallel + def is_gdrive_url(f: furl) -> bool: return f.host in ["drive.google.com", "docs.google.com"] @@ -25,34 +27,33 @@ def url_to_gdrive_file_id(f: furl) -> str: return file_id -def gdrive_list_urls_of_files_in_folder(f: furl, max_depth=10) -> list[str]: +def gdrive_list_urls_of_files_in_folder(f: furl, max_depth: int = 4) -> list[str]: if max_depth <= 0: return [] + assert f.host == "drive.google.com", f"Bad google drive folder url: {f}" # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] service = discovery.build("drive", "v3") - if f.host == "drive.google.com": - request = service.files().list( - supportsAllDrives=True, - includeItemsFromAllDrives=True, - q=f"'{folder_id}' in parents", - fields="files(mimeType,webViewLink)", - ) - else: - raise ValueError(f"Can't list files from non google folder url: {str(f)!r}") + request = service.files().list( + supportsAllDrives=True, + includeItemsFromAllDrives=True, + q=f"'{folder_id}' in parents", + fields="files(mimeType,webViewLink)", + ) response = request.execute() files = response.get("files", []) - urls = [] - for file in files: - mime_type = file.get("mimeType") - url = file.get("webViewLink") - if mime_type == "application/vnd.google-apps.folder": - urls += gdrive_list_urls_of_files_in_folder( - furl(url), max_depth=max_depth - 1 + urls = flatmap_parallel( + lambda file: ( + gdrive_list_urls_of_files_in_folder(furl(url), max_depth=max_depth - 1) + if ( + (url := file.get("webViewLink")) + and file.get("mimeType") == "application/vnd.google-apps.folder" ) - elif url: - urls.append(url) - return urls + else [url] + ), + files, + ) + return filter(None, urls) def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]: From f9723feef76382ee2713170e91b690112e41a864 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 9 Feb 2024 20:47:12 +0530 Subject: [PATCH 66/85] fix openai usage record for non streaming --- daras_ai_v2/language_model.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 3d48b1068..fd6142758 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -585,20 +585,9 @@ def _run_openai_chat( if stream: return _stream_openai_chunked(r, used_model, messages) else: - from usage_costs.cost_utils import record_cost_auto - from usage_costs.models import ModelSku - - record_cost_auto( - model=used_model, - sku=ModelSku.llm_prompt, - quantity=r.usage.prompt_tokens, - ) - record_cost_auto( - model=used_model, - sku=ModelSku.llm_completion, - quantity=r.usage.completion_tokens, - ) - return [choice.message.dict() for choice in r.choices] + ret = [choice.message.dict() for choice in r.choices] + record_openai_llm_usage(used_model, messages, ret) + return ret def _get_chat_completions_create(model: str, **kwargs): @@ -674,6 +663,12 @@ def _stream_openai_chunked( entry["content"] += entry["chunk"] yield ret + record_openai_llm_usage(used_model, messages, ret) + + +def record_openai_llm_usage( + used_model: str, messages: list[ConversationEntry], choices: list[ConversationEntry] +): from usage_costs.cost_utils import record_cost_auto from usage_costs.models import ModelSku @@ -687,7 +682,7 @@ def _stream_openai_chunked( record_cost_auto( model=used_model, sku=ModelSku.llm_completion, - quantity=sum(default_length_function(entry["content"]) for entry in ret), + quantity=sum(default_length_function(entry["content"]) for entry in choices), ) From d6481579f6e114dfaae25364efd3647e7eb067a6 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 8 Feb 2024 22:17:33 +0530 Subject: [PATCH 67/85] show custom urls in uppy rename Custom URLs -> Submit Links in Bulk --- daras_ai_v2/bot_integration_widgets.py | 3 +++ daras_ai_v2/doc_search_settings_widgets.py | 8 +++----- gooey_ui/components/__init__.py | 4 +--- recipes/QRCodeGenerator.py | 17 +++-------------- routers/root.py | 6 ++++++ 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 0790f0b43..73c6ccad0 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -19,6 +19,9 @@ def general_integration_settings(bi: BotIntegration): st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( "user_language" ).default + st.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( + BotIntegration._meta.get_field("streaming_enabled").default + ) st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( BotIntegration._meta.get_field("show_feedback_buttons").default ) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index f175cd911..a64572b99 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -40,11 +40,8 @@ def document_uploader( documents = st.session_state.get(key) or [] if isinstance(documents, str): documents = [documents] - has_custom_urls = not all(map(is_user_uploaded_url, documents)) custom_key = "__custom_" + key - if st.checkbox( - "Enter Custom URLs", key=f"__custom_checkbox_{key}", value=has_custom_urls - ): + if st.session_state.get(f"__custom_checkbox_{key}"): if not custom_key in st.session_state: st.session_state[custom_key] = "\n".join(documents) if accept_multiple_files: @@ -67,7 +64,7 @@ def document_uploader( **kwargs, ) if accept_multiple_files: - st.session_state[key] = text_value.strip().splitlines() + st.session_state[key] = filter(None, text_value.strip().splitlines()) else: st.session_state[key] = text_value else: @@ -79,6 +76,7 @@ def document_uploader( accept=accept, accept_multiple_files=accept_multiple_files, ) + st.checkbox("Submit Links in Bulk", key=f"__custom_checkbox_{key}") documents = st.session_state.get(key, []) try: documents = list(_expand_gdrive_folders(documents)) diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 408b950fb..0814d65e9 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -293,9 +293,7 @@ def text_area( # assert not value, "only one of value or key can be provided" # else: if not key: - key = md5_values( - "textarea", label, height, help, value, placeholder, label_visibility - ) + key = md5_values("textarea", label, height, help, placeholder, label_visibility) value = str(state.session_state.setdefault(key, value) or "") if label_visibility != "visible": label = None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 34b615225..741149afd 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -566,20 +566,9 @@ def vcard_form(*, key: str) -> VCARD: st.error("No contact info found for that email") else: vcard = imported_vcard - # clear inputs - st.js( - # language=js - """ - const form = document.getElementById("gooey-form"); - if (!form) return; - Object.entries(fields).forEach(([k, v]) => { - const field = form["__vcard_data__" + k]; - if (!field) return; - field.value = v; - }); - """, - fields=vcard.dict(), - ) + # update inputs + for k, v in vcard.dict().items(): + st.session_state[f"__vcard_data__{k}"] = v vcard.format_name = st.text_input( "Name*", diff --git a/routers/root.py b/routers/root.py index 6f854ba20..e5f27edc6 100644 --- a/routers/root.py +++ b/routers/root.py @@ -31,6 +31,7 @@ RedirectException, get_example_request_body, ) +from daras_ai_v2.bots import request_json from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_scripts from daras_ai_v2.db import FIREBASE_SESSION_COOKIE from daras_ai_v2.manage_api_keys_widget import manage_api_keys @@ -165,6 +166,11 @@ async def logout(request: Request): return RedirectResponse(request.query_params.get("next", DEFAULT_LOGOUT_REDIRECT)) +@app.post("/__/file-upload/url/meta") +async def file_upload(request: Request, body_json: dict = Depends(request_json)): + return dict(name=(body_json["url"]), type="url/undefined", size=None) + + @app.post("/__/file-upload/") def file_upload(request: Request, form_data: FormData = Depends(request_form_files)): from wand.image import Image From 25903559018aea9523e6e5998ede5aa048a6cacc Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:33:12 -0800 Subject: [PATCH 68/85] Added download buttons --- daras_ai_v2/base.py | 47 ++++++++++++++++--------- gooey_ui/components/__init__.py | 37 ++++++++++++++++++++ recipes/CompareText2Img.py | 1 + recipes/CompareUpscaler.py | 1 + recipes/DeforumSD.py | 3 +- recipes/FaceInpainting.py | 1 + recipes/GoogleImageGen.py | 1 + recipes/ImageSegmentation.py | 1 + recipes/Img2Img.py | 9 ++--- recipes/Lipsync.py | 62 +++++++++++++++++++-------------- recipes/LipsyncTTS.py | 37 ++++---------------- recipes/ObjectInpainting.py | 1 + recipes/QRCodeGenerator.py | 5 +-- recipes/Text2Audio.py | 5 +-- recipes/TextToSpeech.py | 1 + recipes/VideoBots.py | 4 +-- 16 files changed, 131 insertions(+), 85 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index f05a0c71c..8dab30806 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1264,7 +1264,9 @@ def _render_report_button(self): if not (self.request.user and run_id and uid): return - reported = st.button("❗Report") + reported = st.button( + ' Report', type="tertiary" + ) if not reported: return @@ -1351,7 +1353,6 @@ def _render_output_col(self, submitted: bool): def _render_completed_output(self): run_time = st.session_state.get(StateKeys.run_time, 0) - st.success(f"Success! Run Time: `{run_time:.2f}` seconds.") def _render_failed_output(self): err_msg = st.session_state.get(StateKeys.error_msg) @@ -1513,22 +1514,36 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): - col1, col2, col3 = st.columns([1, 1, 1], responsive=False) - col2.node.props[ - "className" - ] += " d-flex justify-content-center align-items-center" - col3.node.props["className"] += " d-flex justify-content-end align-items-center" + caption = "" + caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s' if "seed" in self.RequestModel.schema_json(): seed = st.session_state.get("seed") - with col1: - st.caption(f"*Seed\\\n`{seed}`*") - with col2: - randomize = st.button("♻️ Regenerate") - if randomize: - st.session_state[StateKeys.pressed_randomize] = True - st.experimental_rerun() - with col3: - self._render_report_button() + caption += f' with seed {seed} ' + try: + format_created_at = st.session_state.get( + "created_at", datetime.datetime.today() + ).strftime("%d %b %Y %-I:%M%p") + except: + format_created_at = format_created_at = st.session_state.get( + "created_at", datetime.datetime.today() + ) + caption += f' at {format_created_at}' + st.caption(caption, unsafe_allow_html=True) + + def render_buttons(self, url: str): + st.download_button( + label=' Download', + url=url, + type="secondary", + ) + if "seed" in self.RequestModel.schema_json(): + randomize = st.button( + ' Regenerate', type="tertiary" + ) + if randomize: + st.session_state[StateKeys.pressed_randomize] = True + st.experimental_rerun() + self._render_report_button() def state_to_doc(self, state: dict): ret = { diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 0814d65e9..7e435b725 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -453,6 +453,43 @@ def button( form_submit_button = button +def download_button( + label: str, + url: str, + key: str = None, + help: str = None, + *, + type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", + disabled: bool = False, + **props, +) -> bool: + """ + Example: + st.button("Primary", key="test0", type="primary") + st.button("Secondary", key="test1") + st.button("Tertiary", key="test3", type="tertiary") + st.button("Link Button", key="test3", type="link") + """ + if not key: + key = md5_values("button", label, help, type, props) + className = f"btn-{type} " + props.pop("className", "") + state.RenderTreeNode( + name="download-button", + props=dict( + type="submit", + value="yes", + url=url, + name=key, + label=dedent(label), + help=help, + disabled=disabled, + className=className, + **props, + ), + ).mount() + return bool(state.session_state.pop(key, False)) + + def expander(label: str, *, expanded: bool = False, **props): node = state.RenderTreeNode( name="expander", diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 5f44ca34b..79fd76273 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -247,6 +247,7 @@ def _render_outputs(self, state): output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: st.image(img, caption=Text2ImgModels[key].value) + self.render_buttons(img) def preview_description(self, state: dict) -> str: return "Create multiple AI photos from one prompt using Stable Diffusion (1.5 -> 2.1, Open/Midjourney), DallE, and other models. Find out which AI Image generator works best for your text prompt on comparing OpenAI, Stability.AI etc." diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index e4ed8865f..3206f532d 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -108,6 +108,7 @@ def _render_outputs(self, state): if not img: continue st.image(img, caption=UpscalerModels[key].value) + self.render_buttons(img) def get_raw_price(self, state: dict) -> int: selected_models = state.get("selected_models", []) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index 57b5bb9f7..d971e708c 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -423,8 +423,9 @@ def render_description(self): def render_output(self): output_video = st.session_state.get("output_video") if output_video: - st.write("Output Video") + st.write("#### Output Video") st.video(output_video, autoplay=True) + self.render_buttons(output_video) def estimate_run_duration(self): # in seconds diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index ee377b53d..3c14a8147 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -201,6 +201,7 @@ def render_output(self): url, caption="```" + text_prompt.replace("\n", "") + "```", ) + self.render_buttons(url) else: st.div() diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index c3dab52f0..f49657346 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -207,6 +207,7 @@ def render_output(self): if out_imgs: for img in out_imgs: st.image(img, caption="Generated Image") + self.render_buttons(img) else: st.div() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 751a0936b..0f5eacbee 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -343,6 +343,7 @@ def render_example(self, state: dict): input_image = state.get("input_image") if input_image: st.image(input_image, caption="Input Photo") + self.render_buttons(input_image) else: st.div() diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 68e696c08..7334ebb31 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -91,7 +91,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.file_uploader( """ - ### Input Image + #### Input Image """, key="input_image", upload_meta=dict(resize=f"{SD_IMG_MAX_SIZE[0] * SD_IMG_MAX_SIZE[1]}@>"), @@ -100,7 +100,7 @@ def render_form_v2(self): if st.session_state.get("selected_model") != InpaintingModels.dall_e.name: st.text_area( """ - ### Prompt + #### Prompt Describe your edits """, key="text_prompt", @@ -129,9 +129,10 @@ def render_usage_guide(self): def render_output(self): text_prompt = st.session_state.get("text_prompt", "") output_images = st.session_state.get("output_images", []) - + st.write("#### Output Image") for img in output_images: - st.image(img, caption="```" + text_prompt.replace("\n", "") + "```") + st.image(img) + self.render_buttons(img) def render_example(self, state: dict): col1, col2 = st.columns(2) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 5fe08a09b..8ea341aad 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -82,33 +82,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_example(self, state: dict): - col1, col2 = st.columns(2) - - with col1: - input_face = state.get("input_face") - if not input_face: - st.div() - elif input_face.endswith(".mp4") or input_face.endswith(".mov"): - st.write("Input Face (Video)") - st.video(input_face) - else: - st.write("Input Face (Image)") - st.image(input_face) - - input_audio = state.get("input_audio") - if input_audio: - st.write("Input Audio") - st.audio(input_audio) - else: - st.div() - - with col2: - output_video = state.get("output_video") - if output_video: - st.write("Output Video") - st.video(output_video, autoplay=True) - else: - st.div() + output_video = state.get("output_video") + if output_video: + st.write("#### Output Video") + st.video(output_video, autoplay=True) + self.render_buttons(output_video) + else: + st.div() def render_output(self): self.render_example(st.session_state) @@ -145,3 +125,31 @@ def get_raw_price(self, state: dict) -> float: total_mb = total_bytes / 1024 / 1024 return total_mb * CREDITS_PER_MB + + def download_blob(self, bucket_name, source_blob_name, destination_file_name): + """Downloads a blob from the bucket.""" + # The ID of your GCS bucket + # bucket_name = "your-bucket-name" + + # The ID of your GCS object + # source_blob_name = "storage-object-name" + + # The path to which the file should be downloaded + # destination_file_name = "local/path/to/file" + + storage_client = storage.Client() + + bucket = storage_client.bucket(bucket_name) + + # Construct a client side representation of a blob. + # Note `Bucket.blob` differs from `Bucket.get_blob` as it doesn't retrieve + # any content from Google Cloud Storage. As we don't need additional data, + # using `Bucket.blob` is preferred here. + blob = bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + + print( + "Downloaded storage object {} from bucket {} to local file {}.".format( + source_blob_name, bucket_name, destination_file_name + ) + ) diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 60886c017..d463d7483 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -127,37 +127,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield from LipsyncPage.run(self, state) def render_example(self, state: dict): - col1, col2 = st.columns(2) - - with col1: - input_face = state.get("input_face") - if not input_face: - pass - elif input_face.endswith(".mp4") or input_face.endswith(".mov"): - st.video(input_face, caption="Input Face (Video)") - else: - st.image(input_face, caption="Input Face (Image)") - - input_text = state.get("text_prompt") - if input_text: - st.write("**Input Text**") - st.write(input_text) - else: - st.div() - - # input_audio = state.get("input_audio") - # if input_audio: - # st.write("Synthesized Voice") - # st.audio(input_audio) - # else: - # st.empty() - - with col2: - output_video = state.get("output_video") - if output_video: - st.video(output_video, caption="Output Video", autoplay=True) - else: - st.div() + output_video = state.get("output_video") + if output_video: + st.video(output_video, caption="#### Output Video", autoplay=True) + self.render_buttons(output_video) + else: + st.div() def render_output(self): self.render_example(st.session_state) diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index d84c59069..4d7f7342b 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -202,6 +202,7 @@ def render_output(self): if output_images: for url in output_images: st.image(url, caption=f"{text_prompt}") + self.render_buttons(url) else: st.div() diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 741149afd..1fe283825 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -144,7 +144,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.text_area( """ - ##### 👩‍💻 Prompt + #### 👩‍💻 Prompt Describe the subject/scene of the QR Code. Choose clear prompts and distinguishable visuals to ensure optimal readability. """, @@ -208,7 +208,7 @@ def render_form_v2(self): st.file_uploader( """ - ##### 🏞️ Reference Image *[optional]* + #### 🏞️ Reference Image *[optional]* This image will be used as inspiration to blend with the QR Code. """, key="image_prompt", @@ -458,6 +458,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None): st.caption(f"{shortened_url} → {qr_code_data} (Views: {clicks})") else: st.caption(f"{shortened_url} → {qr_code_data}") + self.render_buttons(img) def run(self, state: dict) -> typing.Iterator[str | None]: request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state) diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 18270cc8c..0e0991c35 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -128,7 +128,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_output(self): - _render_output(st.session_state) + _render_output(self, st.session_state) def render_example(self, state: dict): col1, col2 = st.columns(2) @@ -141,9 +141,10 @@ def preview_description(self, state: dict) -> str: return "Generate AI Music with text instruction prompts. AudiLDM is capable of generating realistic audio samples by process any text input. Learn more [here](https://huggingface.co/cvssp/audioldm-m-full)." -def _render_output(state): +def _render_output(self, state): selected_models = state.get("selected_models", []) for key in selected_models: output: dict = state.get("output_audios", {}).get(key, []) for audio in output: st.audio(audio, caption=Text2AudioModels[key].value) + self.render_buttons(audio) diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 721b2cb7a..de13bfbc0 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -135,6 +135,7 @@ def render_output(self): audio_url = st.session_state.get("audio_url") if audio_url: st.audio(audio_url) + self.render_buttons(audio_url) else: st.div() diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 5d23502ff..2d0bc26e0 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -312,7 +312,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ##### 📝 Prompt + #### 📝 Prompt High-level system instructions. """, key="bot_script", @@ -321,7 +321,7 @@ def render_form_v2(self): document_uploader( """ -##### 📄 Documents (*optional*) +#### 📄 Documents (*optional*) Upload documents or enter URLs to give your copilot a knowledge base. With each incoming user message, we'll search your documents via a vector DB query. """ ) From cd991b5664014b6d238e01419cf69d1f969a36ad Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:48:19 -0800 Subject: [PATCH 69/85] Fixed headings --- daras_ai_v2/img_model_settings_widgets.py | 2 +- daras_ai_v2/prompt_vars.py | 2 +- recipes/CompareText2Img.py | 2 +- recipes/CompareUpscaler.py | 2 +- recipes/DocExtract.py | 4 ++-- recipes/DocSearch.py | 6 +++--- recipes/DocSummary.py | 4 ++-- recipes/EmailFaceInpainting.py | 4 ++-- recipes/FaceInpainting.py | 10 ++++------ recipes/GoogleGPT.py | 2 +- recipes/GoogleImageGen.py | 6 +++--- recipes/ImageSegmentation.py | 2 +- recipes/ObjectInpainting.py | 4 ++-- recipes/SEOSummary.py | 4 ++-- recipes/SocialLookupEmail.py | 6 +++--- recipes/Text2Audio.py | 2 +- recipes/TextToSpeech.py | 2 +- recipes/asr.py | 4 ++-- 18 files changed, 33 insertions(+), 35 deletions(-) diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py index e86c18873..8249833a0 100644 --- a/daras_ai_v2/img_model_settings_widgets.py +++ b/daras_ai_v2/img_model_settings_widgets.py @@ -83,7 +83,7 @@ def model_selector( selected_model = enum_selector( models_enum, label=""" - ### 🤖 Choose your preferred AI Model + #### 🤖 Choose your preferred AI Model """, key="selected_model", use_selectbox=True, diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py index 80f13d906..fabb3928c 100644 --- a/daras_ai_v2/prompt_vars.py +++ b/daras_ai_v2/prompt_vars.py @@ -28,7 +28,7 @@ def prompt_vars_widget(*keys: str, variables_key: str = "variables"): if not (template_vars or err): return - st.write("##### ⌥ Variables") + st.write("#### ⌥ Variables") old_state = st.session_state.get(variables_key, {}) new_state = {} for name in sorted(template_vars): diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 79fd76273..47347bb65 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -96,7 +96,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.text_area( """ - ### 👩‍💻 Prompt + #### 👩‍💻 Prompt Describe the scene that you'd like to generate. """, key="text_prompt", diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 3206f532d..4aad4ffb2 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -34,7 +34,7 @@ class ResponseModel(BaseModel): def render_form_v2(self): st.file_uploader( """ - ### Input Image + #### Input Image """, key="input_image", upload_meta=dict(resize=f"{SD_IMG_MAX_SIZE[0] * SD_IMG_MAX_SIZE[1]}@>"), diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 2672fffdb..ade752624 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -94,11 +94,11 @@ def preview_image(self, state: dict) -> str | None: def render_form_v2(self): document_uploader( - "##### 🤖 Youtube URLS", + "#### 🤖 Youtube URLS", accept=("audio/*", "application/pdf", "video/*"), ) st.text_input( - "##### 📊 Google Sheets URL", + "#### 📊 Google Sheets URL", key="sheet_url", ) diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 0b95464e1..6d7977c25 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -82,8 +82,8 @@ class ResponseModel(BaseModel): final_search_query: str | None def render_form_v2(self): - st.text_area("##### Search Query", key="search_query") - document_uploader("##### Documents") + st.text_area("#### Search Query", key="search_query") + document_uploader("#### Documents") prompt_vars_widget("task_instructions", "query_instructions") def validate_form_v2(self): @@ -111,7 +111,7 @@ def render_output(self): def render_example(self, state: dict): render_documents(state) - st.write("**Search Query**") + st.html("**Search Query**") st.write("```properties\n" + state.get("search_query", "") + "\n```") render_output_with_refs(state, 200) diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 498f2d405..476bd5da6 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -84,8 +84,8 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_DOC_SUMMARY_META_IMG def render_form_v2(self): - document_uploader("##### 📎 Documents") - st.text_area("##### 👩‍💻 Instructions", key="task_instructions", height=150) + document_uploader("#### 📎 Documents") + st.text_area("#### 👩‍💻 Instructions", key="task_instructions", height=150) def render_settings(self): st.text_area( diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index 74749bdb5..4b8b50d1e 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -115,7 +115,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Describe the scene that you'd like to generate around the face. """, key="text_prompt", @@ -130,7 +130,7 @@ def render_form_v2(self): source = st.radio( """ - ### Photo Source + #### Photo Source From where we should get the photo?""", options=["Email Address", "Twitter Handle"], key="__photo_source", diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index 3c14a8147..f6a550d48 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -105,7 +105,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Describe the character that you'd like to generate. """, key="text_prompt", @@ -114,7 +114,7 @@ def render_form_v2(self): st.file_uploader( """ - ### Face Photo + #### Face Photo Give us a photo of yourself, or anyone else """, key="input_image", @@ -196,11 +196,9 @@ def render_output(self): output_images = st.session_state.get("output_images") if output_images: + st.write("#### Output Image") for url in output_images: - st.image( - url, - caption="```" + text_prompt.replace("\n", "") + "```", - ) + st.image(url) self.render_buttons(url) else: st.div() diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 7a2d2591f..35d650acb 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -111,7 +111,7 @@ class ResponseModel(BaseModel): final_search_query: str | None def render_form_v2(self): - st.text_area("##### Google Search Query", key="search_query") + st.text_area("#### Google Search Query", key="search_query") st.text_input("Search on a specific site *(optional)*", key="site_filter") prompt_vars_widget("task_instructions", "query_instructions") diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index f49657346..fd5d2db80 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -180,7 +180,7 @@ def run(self, state: dict): def render_form_v2(self): st.text_input( """ - ### 🔎 Google Image Search + #### 🔎 Google Image Search Type a query you'd use in [Google image search](https://images.google.com/?gws_rd=ssl) """, key="search_query", @@ -188,7 +188,7 @@ def render_form_v2(self): model_selector(Img2ImgModels) st.text_area( """ - ### 👩‍💻 Prompt + #### 👩‍💻 Prompt Describe how you want to edit the photo in words """, key="text_prompt", @@ -206,7 +206,7 @@ def render_output(self): out_imgs = st.session_state.get("output_images") if out_imgs: for img in out_imgs: - st.image(img, caption="Generated Image") + st.image(img, caption="#### Generated Image") self.render_buttons(img) else: st.div() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 0f5eacbee..39713cb2b 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -86,7 +86,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.file_uploader( """ - ### Input Photo + #### Input Photo Give us a photo of anything """, key="input_image", diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index 4d7f7342b..c2c3b29c4 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -95,7 +95,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Describe the scene that you'd like to generate. """, key="text_prompt", @@ -104,7 +104,7 @@ def render_form_v2(self): st.file_uploader( """ - ### Object Photo + #### Object Photo Give us a photo of anything """, key="input_image", diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index 78328d822..79d8092a8 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -160,7 +160,7 @@ def render_description(self): ) def render_form_v2(self): - st.write("### Inputs") + st.write("#### Inputs") st.text_input("Google Search Query", key="search_query") st.text_input("Website Name", key="title") st.text_input("Website URL", key="company_url") @@ -192,7 +192,7 @@ def render_settings(self): def render_output(self): output_content = st.session_state.get("output_content") if output_content: - st.write("### Generated Content") + st.write("#### Generated Content") for idx, text in enumerate(output_content): if st.session_state.get("enable_html"): scrollable_html(text) diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 1094cea53..b2518b1c8 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -120,7 +120,7 @@ def render_settings(self): def render_form_v2(self): st.text_input( """ - ### Email Address + #### Email Address Give us an email address and we'll try to get determine the profile data associated with it """, key="email_address", @@ -132,7 +132,7 @@ def render_form_v2(self): st.text_area( """ - ### Email Body + #### Email Body """, key="input_email_body", height=200, @@ -198,7 +198,7 @@ def _input_variables(self, state: dict): def render_output(self): st.text_area( """ - ### Email Body Output + #### Email Body Output """, disabled=True, value=st.session_state.get("output_email_body", ""), diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 0e0991c35..4d68b8eb5 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -64,7 +64,7 @@ def preview_image(self, state: dict) -> str | None: def render_form_v2(self): st.text_area( """ - ### 👩‍💻 Prompt + #### 👩‍💻 Prompt Describe the audio that you'd like to generate. """, key="text_prompt", diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index de13bfbc0..b90bb8a4b 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -100,7 +100,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Enter text you want to convert to speech """, key="text_prompt", diff --git a/recipes/asr.py b/recipes/asr.py index a297bf9c3..7dcea4249 100644 --- a/recipes/asr.py +++ b/recipes/asr.py @@ -82,14 +82,14 @@ def related_workflows(self) -> list: def render_form_v2(self): document_uploader( - "##### Audio Files", + "#### Audio Files", accept=("audio/*", "video/*", "application/octet-stream"), ) col1, col2 = st.columns(2, responsive=False) with col1: selected_model = enum_selector( AsrModels, - label="##### ASR Model", + label="#### ASR Model", key="selected_model", use_selectbox=True, ) From e985052c1dc3b01aa64abd647dbcea3d19aed906 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:56:46 -0800 Subject: [PATCH 70/85] Fixed datetime format --- daras_ai_v2/base.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 8dab30806..75068f911 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1519,14 +1519,12 @@ def _render_after_output(self): if "seed" in self.RequestModel.schema_json(): seed = st.session_state.get("seed") caption += f' with seed {seed} ' - try: - format_created_at = st.session_state.get( - "created_at", datetime.datetime.today() - ).strftime("%d %b %Y %-I:%M%p") - except: - format_created_at = format_created_at = st.session_state.get( - "created_at", datetime.datetime.today() - ) + created_at = st.session_state.get( + StateKeys.created_at, datetime.datetime.today() + ) + if not isinstance(created_at, datetime.datetime): + created_at = datetime.datetime.fromisoformat(created_at) + format_created_at = created_at.strftime("%d %b %Y %-I:%M%p") caption += f' at {format_created_at}' st.caption(caption, unsafe_allow_html=True) From 6f75eca0c212cc5242bf787d8f9fdef2b69b4c57 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Fri, 26 Jan 2024 13:15:03 -0800 Subject: [PATCH 71/85] Regen button appears once --- daras_ai_v2/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 75068f911..28756f96a 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1514,6 +1514,13 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): + if "seed" in self.RequestModel.schema_json(): + randomize = st.button( + ' Regenerate', type="tertiary" + ) + if randomize: + st.session_state[StateKeys.pressed_randomize] = True + st.experimental_rerun() caption = "" caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s' if "seed" in self.RequestModel.schema_json(): @@ -1534,13 +1541,6 @@ def render_buttons(self, url: str): url=url, type="secondary", ) - if "seed" in self.RequestModel.schema_json(): - randomize = st.button( - ' Regenerate', type="tertiary" - ) - if randomize: - st.session_state[StateKeys.pressed_randomize] = True - st.experimental_rerun() self._render_report_button() def state_to_doc(self, state: dict): From 64b7dfcefe19b8da6f3bb3caa84b313786a43799 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 12 Feb 2024 11:57:49 +0530 Subject: [PATCH 72/85] fix report button re-usability fuckup put download button inside the audio/video/image components consistent datetime format fix half ass caption showing up when the recipe is running make download_button re-use the button code --- bots/models.py | 14 +++---- daras_ai_v2/base.py | 55 ++++++++++++------------- daras_ai_v2/settings.py | 2 + gooey_ui/components/__init__.py | 72 +++++++++++++++++---------------- recipes/CompareText2Img.py | 5 ++- recipes/CompareUpscaler.py | 3 +- recipes/DeforumSD.py | 3 +- recipes/FaceInpainting.py | 3 +- recipes/GoogleImageGen.py | 3 +- recipes/ImageSegmentation.py | 3 +- recipes/Img2Img.py | 6 +-- recipes/Lipsync.py | 3 +- recipes/LipsyncTTS.py | 8 +++- recipes/ObjectInpainting.py | 3 +- recipes/QRCodeGenerator.py | 3 +- recipes/Text2Audio.py | 9 +++-- recipes/TextToSpeech.py | 7 +--- 17 files changed, 101 insertions(+), 101 deletions(-) diff --git a/bots/models.py b/bots/models.py index 205dfb5ff..558512e8c 100644 --- a/bots/models.py +++ b/bots/models.py @@ -700,8 +700,8 @@ def to_df_format( .replace(tzinfo=None) ) row |= { - "Last Sent": last_time.strftime("%b %d, %Y %I:%M %p"), - "First Sent": first_time.strftime("%b %d, %Y %I:%M %p"), + "Last Sent": last_time.strftime(settings.SHORT_DATETIME_FORMAT), + "First Sent": first_time.strftime(settings.SHORT_DATETIME_FORMAT), "A7": not convo.d7(), "A30": not convo.d30(), "R1": last_time - first_time < datetime.timedelta(days=1), @@ -926,7 +926,7 @@ def to_df_format( "Message (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Feedback": ( message.feedbacks.first().get_display_text() if message.feedbacks.first() @@ -968,7 +968,7 @@ def to_df_analysis_format( "Answer (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Analysis JSON": message.analysis_result, } rows.append(row) @@ -1153,16 +1153,16 @@ def to_df_format( "Question Sent": feedback.message.get_previous_by_created_at() .created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Answer (EN)": feedback.message.content, "Answer Sent": feedback.message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Rating": Feedback.Rating(feedback.rating).label, "Feedback (EN)": feedback.text_english, "Feedback Sent": feedback.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Question Answered": feedback.message.question_answered, } rows.append(row) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 28756f96a..0a885bf0a 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1348,11 +1348,11 @@ def _render_output_col(self, submitted: bool): # render outputs self.render_output() - if run_state != "waiting": + if run_state != RecipeRunState.running: self._render_after_output() def _render_completed_output(self): - run_time = st.session_state.get(StateKeys.run_time, 0) + pass def _render_failed_output(self): err_msg = st.session_state.get(StateKeys.error_msg) @@ -1368,12 +1368,10 @@ def render_extra_waiting_output(self): if not estimated_run_time: return if created_at := st.session_state.get("created_at"): - if isinstance(created_at, datetime.datetime): - start_time = created_at - else: - start_time = datetime.datetime.fromisoformat(created_at) + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) with st.countdown_timer( - end_time=start_time + datetime.timedelta(seconds=estimated_run_time), + end_time=created_at + datetime.timedelta(seconds=estimated_run_time), delay_text="Sorry for the wait. Your run is taking longer than we expected.", ): if self.is_current_user_owner() and self.request.user.email: @@ -1514,6 +1512,8 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): + self._render_report_button() + if "seed" in self.RequestModel.schema_json(): randomize = st.button( ' Regenerate', type="tertiary" @@ -1521,27 +1521,8 @@ def _render_after_output(self): if randomize: st.session_state[StateKeys.pressed_randomize] = True st.experimental_rerun() - caption = "" - caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s' - if "seed" in self.RequestModel.schema_json(): - seed = st.session_state.get("seed") - caption += f' with seed {seed} ' - created_at = st.session_state.get( - StateKeys.created_at, datetime.datetime.today() - ) - if not isinstance(created_at, datetime.datetime): - created_at = datetime.datetime.fromisoformat(created_at) - format_created_at = created_at.strftime("%d %b %Y %-I:%M%p") - caption += f' at {format_created_at}' - st.caption(caption, unsafe_allow_html=True) - def render_buttons(self, url: str): - st.download_button( - label=' Download', - url=url, - type="secondary", - ) - self._render_report_button() + render_output_caption() def state_to_doc(self, state: dict): ret = { @@ -1918,6 +1899,26 @@ def is_current_user_owner(self) -> bool: ) +def render_output_caption(): + caption = "" + + run_time = st.session_state.get(StateKeys.run_time, 0) + if run_time: + caption += f'Generated in {run_time :.2f}s' + + if seed := st.session_state.get("seed"): + caption += f' with seed {seed} ' + + created_at = st.session_state.get(StateKeys.created_at, datetime.datetime.today()) + if created_at: + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) + format_created_at = created_at.strftime(settings.SHORT_DATETIME_FORMAT) + caption += f' at {format_created_at}' + + st.caption(caption, unsafe_allow_html=True) + + def get_example_request_body( request_model: typing.Type[BaseModel], state: dict, diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index deda81500..512fb3834 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -162,6 +162,8 @@ es_formats.DATETIME_FORMAT = DATETIME_FORMAT +SHORT_DATETIME_FORMAT = "%b %d, %Y %-I:%M %p" + # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.2/howto/static-files/ diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 7e435b725..6322cf629 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -217,6 +217,7 @@ def image( caption: str = None, alt: str = None, href: str = None, + show_download_button: bool = False, **props, ): if isinstance(src, np.ndarray): @@ -241,9 +242,18 @@ def image( **props, ), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) -def video(src: str, caption: str = None, autoplay: bool = False): +def video( + src: str, + caption: str = None, + autoplay: bool = False, + show_download_button: bool = False, +): autoplay_props = {} if autoplay: autoplay_props = { @@ -266,15 +276,23 @@ def video(src: str, caption: str = None, autoplay: bool = False): name="video", props=dict(src=src, caption=dedent(caption), **autoplay_props), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) -def audio(src: str, caption: str = None): +def audio(src: str, caption: str = None, show_download_button: bool = False): if not src: return state.RenderTreeNode( name="audio", props=dict(src=src, caption=dedent(caption)), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) def text_area( @@ -415,8 +433,9 @@ def selectbox( return value -def button( +def download_button( label: str, + url: str, key: str = None, help: str = None, *, @@ -424,43 +443,26 @@ def button( disabled: bool = False, **props, ) -> bool: - """ - Example: - st.button("Primary", key="test0", type="primary") - st.button("Secondary", key="test1") - st.button("Tertiary", key="test3", type="tertiary") - st.button("Link Button", key="test3", type="link") - """ - if not key: - key = md5_values("button", label, help, type, props) - className = f"btn-{type} " + props.pop("className", "") - state.RenderTreeNode( - name="gui-button", - props=dict( - type="submit", - value="yes", - name=key, - label=dedent(label), - help=help, - disabled=disabled, - className=className, - **props, - ), - ).mount() - return bool(state.session_state.pop(key, False)) - - -form_submit_button = button + return button( + component="download-button", + url=url, + label=label, + key=key, + help=help, + type=type, + disabled=disabled, + **props, + ) -def download_button( +def button( label: str, - url: str, key: str = None, help: str = None, *, type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", disabled: bool = False, + component: typing.Literal["download-button", "gui-button"] = "gui-button", **props, ) -> bool: """ @@ -474,11 +476,10 @@ def download_button( key = md5_values("button", label, help, type, props) className = f"btn-{type} " + props.pop("className", "") state.RenderTreeNode( - name="download-button", + name=component, props=dict( type="submit", value="yes", - url=url, name=key, label=dedent(label), help=help, @@ -490,6 +491,9 @@ def download_button( return bool(state.session_state.pop(key, False)) +form_submit_button = button + + def expander(label: str, *, expanded: bool = False, **props): node = state.RenderTreeNode( name="expander", diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 47347bb65..dc5ea1ae2 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -246,8 +246,9 @@ def _render_outputs(self, state): for key in selected_models: output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: - st.image(img, caption=Text2ImgModels[key].value) - self.render_buttons(img) + st.image( + img, caption=Text2ImgModels[key].value, show_download_button=True + ) def preview_description(self, state: dict) -> str: return "Create multiple AI photos from one prompt using Stable Diffusion (1.5 -> 2.1, Open/Midjourney), DallE, and other models. Find out which AI Image generator works best for your text prompt on comparing OpenAI, Stability.AI etc." diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 4aad4ffb2..13f62de91 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -107,8 +107,7 @@ def _render_outputs(self, state): img: dict = state.get("output_images", {}).get(key) if not img: continue - st.image(img, caption=UpscalerModels[key].value) - self.render_buttons(img) + st.image(img, caption=UpscalerModels[key].value, show_download_button=True) def get_raw_price(self, state: dict) -> int: selected_models = state.get("selected_models", []) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index d971e708c..35a71a12c 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -424,8 +424,7 @@ def render_output(self): output_video = st.session_state.get("output_video") if output_video: st.write("#### Output Video") - st.video(output_video, autoplay=True) - self.render_buttons(output_video) + st.video(output_video, autoplay=True, show_download_button=True) def estimate_run_duration(self): # in seconds diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index f6a550d48..6d787bba6 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -198,8 +198,7 @@ def render_output(self): if output_images: st.write("#### Output Image") for url in output_images: - st.image(url) - self.render_buttons(url) + st.image(url, show_download_button=True) else: st.div() diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index fd5d2db80..278128a37 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -206,8 +206,7 @@ def render_output(self): out_imgs = st.session_state.get("output_images") if out_imgs: for img in out_imgs: - st.image(img, caption="#### Generated Image") - self.render_buttons(img) + st.image(img, caption="#### Generated Image", show_download_button=True) else: st.div() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 39713cb2b..c247fdeda 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -342,8 +342,7 @@ def render_example(self, state: dict): with col1: input_image = state.get("input_image") if input_image: - st.image(input_image, caption="Input Photo") - self.render_buttons(input_image) + st.image(input_image, caption="Input Photo", show_download_button=True) else: st.div() diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 7334ebb31..e2713aaa2 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -127,12 +127,12 @@ def render_usage_guide(self): youtube_video("narcZNyuNAg") def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") output_images = st.session_state.get("output_images", []) + if not output_images: + return st.write("#### Output Image") for img in output_images: - st.image(img) - self.render_buttons(img) + st.image(img, show_download_button=True) def render_example(self, state: dict): col1, col2 = st.columns(2) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 8ea341aad..c2ac64369 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -85,8 +85,7 @@ def render_example(self, state: dict): output_video = state.get("output_video") if output_video: st.write("#### Output Video") - st.video(output_video, autoplay=True) - self.render_buttons(output_video) + st.video(output_video, autoplay=True, show_download_button=True) else: st.div() diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index d463d7483..1fc0f6c64 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -129,8 +129,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: def render_example(self, state: dict): output_video = state.get("output_video") if output_video: - st.video(output_video, caption="#### Output Video", autoplay=True) - self.render_buttons(output_video) + st.video( + output_video, + caption="#### Output Video", + autoplay=True, + show_download_button=True, + ) else: st.div() diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index c2c3b29c4..a1e0c2449 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -201,8 +201,7 @@ def render_output(self): if output_images: for url in output_images: - st.image(url, caption=f"{text_prompt}") - self.render_buttons(url) + st.image(url, caption=f"{text_prompt}", show_download_button=True) else: st.div() diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 1fe283825..1512e4523 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -436,7 +436,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None): if max_count: output_images = output_images[:max_count] for img in output_images: - st.image(img) + st.image(img, show_download_button=True) qr_code_data = ( state.get(QrSources.qr_code_data.name) or state.get(QrSources.qr_code_input_image.name) @@ -458,7 +458,6 @@ def _render_outputs(self, state: dict, max_count: int | None = None): st.caption(f"{shortened_url} → {qr_code_data} (Views: {clicks})") else: st.caption(f"{shortened_url} → {qr_code_data}") - self.render_buttons(img) def run(self, state: dict) -> typing.Iterator[str | None]: request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state) diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 4d68b8eb5..f7ddf5aab 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -128,7 +128,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_output(self): - _render_output(self, st.session_state) + _render_output(st.session_state) def render_example(self, state: dict): col1, col2 = st.columns(2) @@ -141,10 +141,11 @@ def preview_description(self, state: dict) -> str: return "Generate AI Music with text instruction prompts. AudiLDM is capable of generating realistic audio samples by process any text input. Learn more [here](https://huggingface.co/cvssp/audioldm-m-full)." -def _render_output(self, state): +def _render_output(state): selected_models = state.get("selected_models", []) for key in selected_models: output: dict = state.get("output_audios", {}).get(key, []) for audio in output: - st.audio(audio, caption=Text2AudioModels[key].value) - self.render_buttons(audio) + st.audio( + audio, caption=Text2AudioModels[key].value, show_download_button=True + ) diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index b90bb8a4b..7494c992d 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -131,13 +131,8 @@ def render_usage_guide(self): # loom_video("2d853b7442874b9cbbf3f27b98594add") def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") audio_url = st.session_state.get("audio_url") - if audio_url: - st.audio(audio_url) - self.render_buttons(audio_url) - else: - st.div() + st.audio(audio_url, show_download_button=True) def _get_elevenlabs_price(self, state: dict): _, is_user_provided_key = self._get_elevenlabs_api_key(state) From dc7056900c6c4569a0e6a204f501f7fb92940c07 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 12 Feb 2024 13:04:00 +0530 Subject: [PATCH 73/85] fix expand gdrive folders and accept_multiple_files clash --- daras_ai_v2/doc_search_settings_widgets.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index a64572b99..fb69f22a9 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -78,11 +78,12 @@ def document_uploader( ) st.checkbox("Submit Links in Bulk", key=f"__custom_checkbox_{key}") documents = st.session_state.get(key, []) - try: - documents = list(_expand_gdrive_folders(documents)) - except Exception as e: - capture_exception(e) - st.error(f"Error expanding gdrive folders: {e}") + if accept_multiple_files: + try: + documents = list(_expand_gdrive_folders(documents)) + except Exception as e: + capture_exception(e) + st.error(f"Error expanding gdrive folders: {e}") st.session_state[key] = documents st.session_state[custom_key] = "\n".join(documents) return documents From 6cd5617e898a57528b15eaf1440d37a5e6ca11b0 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 12 Feb 2024 13:24:47 +0530 Subject: [PATCH 74/85] add is_api_call to savedrun show api call in breadcrumbs --- bots/migrations/0059_savedrun_is_api_call.py | 18 +++++++++++++++ bots/models.py | 2 ++ celeryapp/tasks.py | 1 + daras_ai_v2/base.py | 23 ++++++++----------- daras_ai_v2/breadcrumbs.py | 5 +++- routers/api.py | 2 +- .../0004_alter_modelpricing_model_name.py | 18 +++++++++++++++ 7 files changed, 54 insertions(+), 15 deletions(-) create mode 100644 bots/migrations/0059_savedrun_is_api_call.py create mode 100644 usage_costs/migrations/0004_alter_modelpricing_model_name.py diff --git a/bots/migrations/0059_savedrun_is_api_call.py b/bots/migrations/0059_savedrun_is_api_call.py new file mode 100644 index 000000000..dc93057fc --- /dev/null +++ b/bots/migrations/0059_savedrun_is_api_call.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-12 07:54 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0058_alter_savedrun_unique_together_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='savedrun', + name='is_api_call', + field=models.BooleanField(default=False), + ), + ] diff --git a/bots/models.py b/bots/models.py index 558512e8c..e2c66148a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -254,6 +254,8 @@ class SavedRun(models.Model): page_title = models.TextField(default="", blank=True, help_text="(Deprecated)") page_notes = models.TextField(default="", blank=True, help_text="(Deprecated)") + is_api_call = models.BooleanField(default=False) + objects = SavedRunQuerySet.as_manager() class Meta: diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..5e1690b31 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -32,6 +32,7 @@ def gui_runner( ): page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id))) sr = page.run_doc_sr(run_id, uid) + sr.is_api_call = is_api_call st.set_session_state(state) run_time = 0 diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 0a885bf0a..906b73a08 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -243,7 +243,7 @@ def _render_header(self): if tbreadcrumbs: with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): - render_breadcrumbs(tbreadcrumbs) + render_breadcrumbs(tbreadcrumbs, current_run.is_api_call) author = self.run_user or current_run.get_creator() if not is_root_example: @@ -996,9 +996,7 @@ def get_or_create_root_published_run(cls) -> PublishedRun: workflow=cls.workflow, published_run_id="", defaults={ - "saved_run": lambda: cls.run_doc_sr( - run_id="", uid="", create=True, parent=None, parent_version=None - ), + "saved_run": lambda: cls.run_doc_sr(run_id="", uid="", create=True), "created_by": None, "last_edited_by": None, "title": cls.title, @@ -1022,15 +1020,11 @@ def run_doc_sr( run_id: str, uid: str, create: bool = False, - parent: SavedRun | None = None, - parent_version: PublishedRunVersion | None = None, + defaults: dict = None, ) -> SavedRun: config = dict(workflow=cls.workflow, uid=uid, run_id=run_id) if create: - return SavedRun.objects.get_or_create( - **config, - defaults=dict(parent=parent, parent_version=parent_version), - )[0] + return SavedRun.objects.get_or_create(**config, defaults=defaults)[0] else: return SavedRun.objects.get(**config) @@ -1408,7 +1402,7 @@ def should_submit_after_login(self) -> bool: and not self.request.user.is_anonymous ) - def create_new_run(self): + def create_new_run(self, is_api_call: bool = False): st.session_state[StateKeys.run_status] = "Starting..." st.session_state.pop(StateKeys.error_msg, None) st.session_state.pop(StateKeys.run_time, None) @@ -1443,8 +1437,11 @@ def create_new_run(self): run_id, uid, create=True, - parent=parent, - parent_version=parent_version, + defaults=dict( + parent=parent, + parent_version=parent_version, + is_api_call=is_api_call, + ), ).set(self.state_to_doc(st.session_state)) return None, run_id, uid diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py index 96dc919e5..62ea058dd 100644 --- a/daras_ai_v2/breadcrumbs.py +++ b/daras_ai_v2/breadcrumbs.py @@ -31,7 +31,7 @@ def has_breadcrumbs(self): return bool(self.root_title or self.published_title) -def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs): +def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs, is_api_call: bool = False): st.html( """