From 70555b5f5fcdacec4c399e9f0002f23e961ea09b Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 13 Dec 2023 08:30:29 -0800 Subject: [PATCH 01/58] Added cost table --- .../0049_alter_savedrun_workflow.py | 52 +++++++++++++++++++ cost/__init__.py | 0 cost/admin.py | 18 +++++++ cost/apps.py | 7 +++ cost/migrations/__init__.py | 0 cost/models.py | 31 +++++++++++ cost/tests.py | 3 ++ cost/views.py | 3 ++ daras_ai_v2/settings.py | 1 + 9 files changed, 115 insertions(+) create mode 100644 bots/migrations/0049_alter_savedrun_workflow.py create mode 100644 cost/__init__.py create mode 100644 cost/admin.py create mode 100644 cost/apps.py create mode 100644 cost/migrations/__init__.py create mode 100644 cost/models.py create mode 100644 cost/tests.py create mode 100644 cost/views.py diff --git a/bots/migrations/0049_alter_savedrun_workflow.py b/bots/migrations/0049_alter_savedrun_workflow.py new file mode 100644 index 000000000..82780828c --- /dev/null +++ b/bots/migrations/0049_alter_savedrun_workflow.py @@ -0,0 +1,52 @@ +# Generated by Django 4.2.5 on 2023-12-12 08:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0048_alter_messageattachment_url"), + ] + + operations = [ + migrations.AlterField( + model_name="savedrun", + name="workflow", + field=models.IntegerField( + choices=[ + (1, "Doc Search"), + (2, "Doc Summary"), + (3, "Google GPT"), + (4, "Copilot"), + (5, "Lipysnc + TTS"), + (6, "Text to Speech"), + (7, "Speech Recognition"), + (8, "Lipsync"), + (9, "Deforum Animation"), + (10, "Compare Text2Img"), + (11, "Text2Audio"), + (12, "Img2Img"), + (13, "Face Inpainting"), + (14, "Google Image Gen"), + (15, "Compare AI Upscalers"), + (16, "SEO Summary"), + (17, "Email Face Inpainting"), + (18, "Social Lookup Email"), + (19, "Object Inpainting"), + (20, "Image Segmentation"), + (21, "Compare LLM"), + (22, "Chyron Plant"), + (23, "Letter Writer"), + (24, "Smart GPT"), + (25, "AI QR Code"), + (26, "Doc Extract"), + (27, "Related QnA Maker"), + (28, "Related QnA Maker Doc"), + (29, "Embeddings"), + (30, "Bulk Runner"), + (31, "Bulk Evaluator"), + ], + default=4, + ), + ), + ] diff --git a/cost/__init__.py b/cost/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cost/admin.py b/cost/admin.py new file mode 100644 index 000000000..a7ae9e50f --- /dev/null +++ b/cost/admin.py @@ -0,0 +1,18 @@ +from django.contrib import admin +from cost import models + +# Register your models here. + + +@admin.register(models.Cost) +class CostAdmin(admin.ModelAdmin): + list_display = [ + "run_id", + "provider", + "model", + "param", + "notes", + "quantity", + "calculation_notes", + "cost", + ] diff --git a/cost/apps.py b/cost/apps.py new file mode 100644 index 000000000..72b24d71a --- /dev/null +++ b/cost/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class CostConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "cost" + verbose_name = "Cost" diff --git a/cost/migrations/__init__.py b/cost/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cost/models.py b/cost/models.py new file mode 100644 index 000000000..9207f4b94 --- /dev/null +++ b/cost/models.py @@ -0,0 +1,31 @@ +import requests +from django.db import models, IntegrityError, transaction +from django.utils import timezone +from firebase_admin import auth + +from bots.custom_fields import CustomURLField +from daras_ai.image_input import upload_file_from_bytes, guess_ext_from_response +from daras_ai_v2 import settings, db +from gooeysite.bg_db_conn import db_middleware + + +# class CostQuerySet(models.QuerySet): + + +class Cost(models.Model): + run_id = models.ForeignKey( + "bots.SavedRun", + on_delete=models.SET_NULL, + related_name="cost", + null=True, + default=None, + blank=True, + help_text="The run that was last saved by the user.", + ) + provider = models.CharField(max_length=255, default="", blank=True) + model = models.TextField("model", blank=True) + param = models.TextField("param", blank=True) + notes = models.TextField(default="", blank=True) + quantity = models.IntegerField(default=1) + calculation_notes = models.TextField(default="", blank=True) + cost = models.FloatField(default=0.0) diff --git a/cost/tests.py b/cost/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/cost/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/cost/views.py b/cost/views.py new file mode 100644 index 000000000..91ea44a21 --- /dev/null +++ b/cost/views.py @@ -0,0 +1,3 @@ +from django.shortcuts import render + +# Create your views here. diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index a13565115..12de870a7 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -55,6 +55,7 @@ "files", "url_shortener", "glossary_resources", + "cost", ] MIDDLEWARE = [ From ea4542b34b8a1a79769f6f06f884c3208c8c949e Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 14 Dec 2023 08:39:20 -0800 Subject: [PATCH 02/58] Initial table --- cost/migrations/0001_initial.py | 48 +++++++++++++++++++++++++++++++++ cost/models.py | 1 + 2 files changed, 49 insertions(+) create mode 100644 cost/migrations/0001_initial.py diff --git a/cost/migrations/0001_initial.py b/cost/migrations/0001_initial.py new file mode 100644 index 000000000..dddbf1dfc --- /dev/null +++ b/cost/migrations/0001_initial.py @@ -0,0 +1,48 @@ +# Generated by Django 4.2.5 on 2023-12-13 20:14 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ("bots", "0049_alter_savedrun_workflow"), + ] + + operations = [ + migrations.CreateModel( + name="Cost", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("provider", models.CharField(blank=True, default="", max_length=255)), + ("model", models.TextField(blank=True, verbose_name="model")), + ("param", models.TextField(blank=True, verbose_name="param")), + ("notes", models.TextField(blank=True, default="")), + ("quantity", models.IntegerField(default=1)), + ("calculation_notes", models.TextField(blank=True, default="")), + ("cost", models.FloatField(default=0.0)), + ( + "run_id", + models.ForeignKey( + blank=True, + default=None, + help_text="The run that was last saved by the user.", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="cost", + to="bots.savedrun", + ), + ), + ], + ), + ] diff --git a/cost/models.py b/cost/models.py index 9207f4b94..19e20cf05 100644 --- a/cost/models.py +++ b/cost/models.py @@ -29,3 +29,4 @@ class Cost(models.Model): quantity = models.IntegerField(default=1) calculation_notes = models.TextField(default="", blank=True) cost = models.FloatField(default=0.0) + input = models. From d57b5c1488432b2c994c62001b8842761e3d8d02 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 14 Dec 2023 12:41:20 -0800 Subject: [PATCH 03/58] Incorrect names --- app_users/models.py | 1 + cost/models.py | 32 -------------------- {cost => costs}/__init__.py | 0 {cost => costs}/admin.py | 6 ++-- {cost => costs}/apps.py | 6 ++-- {cost => costs}/migrations/0001_initial.py | 0 {cost => costs}/migrations/__init__.py | 0 costs/models.py | 35 ++++++++++++++++++++++ {cost => costs}/tests.py | 0 {cost => costs}/views.py | 0 daras_ai_v2/settings.py | 2 +- 11 files changed, 43 insertions(+), 39 deletions(-) delete mode 100644 cost/models.py rename {cost => costs}/__init__.py (100%) rename {cost => costs}/admin.py (72%) rename {cost => costs}/apps.py (55%) rename {cost => costs}/migrations/0001_initial.py (100%) rename {cost => costs}/migrations/__init__.py (100%) create mode 100644 costs/models.py rename {cost => costs}/tests.py (100%) rename {cost => costs}/views.py (100%) diff --git a/app_users/models.py b/app_users/models.py index 5e373f414..8594a44fb 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -225,6 +225,7 @@ class AppUserTransaction(models.Model): amount = models.IntegerField() end_balance = models.IntegerField() created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now) + dollar_amt = models.DecimalField(default=decimal(0.0), blank=True) class Meta: verbose_name = "Transaction" diff --git a/cost/models.py b/cost/models.py deleted file mode 100644 index 19e20cf05..000000000 --- a/cost/models.py +++ /dev/null @@ -1,32 +0,0 @@ -import requests -from django.db import models, IntegrityError, transaction -from django.utils import timezone -from firebase_admin import auth - -from bots.custom_fields import CustomURLField -from daras_ai.image_input import upload_file_from_bytes, guess_ext_from_response -from daras_ai_v2 import settings, db -from gooeysite.bg_db_conn import db_middleware - - -# class CostQuerySet(models.QuerySet): - - -class Cost(models.Model): - run_id = models.ForeignKey( - "bots.SavedRun", - on_delete=models.SET_NULL, - related_name="cost", - null=True, - default=None, - blank=True, - help_text="The run that was last saved by the user.", - ) - provider = models.CharField(max_length=255, default="", blank=True) - model = models.TextField("model", blank=True) - param = models.TextField("param", blank=True) - notes = models.TextField(default="", blank=True) - quantity = models.IntegerField(default=1) - calculation_notes = models.TextField(default="", blank=True) - cost = models.FloatField(default=0.0) - input = models. diff --git a/cost/__init__.py b/costs/__init__.py similarity index 100% rename from cost/__init__.py rename to costs/__init__.py diff --git a/cost/admin.py b/costs/admin.py similarity index 72% rename from cost/admin.py rename to costs/admin.py index a7ae9e50f..4b3ed79a3 100644 --- a/cost/admin.py +++ b/costs/admin.py @@ -1,11 +1,11 @@ from django.contrib import admin -from cost import models +from costs import models # Register your models here. -@admin.register(models.Cost) -class CostAdmin(admin.ModelAdmin): +@admin.register(models.UsageCost) +class CostsAdmin(admin.ModelAdmin): list_display = [ "run_id", "provider", diff --git a/cost/apps.py b/costs/apps.py similarity index 55% rename from cost/apps.py rename to costs/apps.py index 72b24d71a..55bed0ca1 100644 --- a/cost/apps.py +++ b/costs/apps.py @@ -1,7 +1,7 @@ from django.apps import AppConfig -class CostConfig(AppConfig): +class CostsConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "cost" - verbose_name = "Cost" + name = "costs" + verbose_name = "Costs" diff --git a/cost/migrations/0001_initial.py b/costs/migrations/0001_initial.py similarity index 100% rename from cost/migrations/0001_initial.py rename to costs/migrations/0001_initial.py diff --git a/cost/migrations/__init__.py b/costs/migrations/__init__.py similarity index 100% rename from cost/migrations/__init__.py rename to costs/migrations/__init__.py diff --git a/costs/models.py b/costs/models.py new file mode 100644 index 000000000..bc75b8bc2 --- /dev/null +++ b/costs/models.py @@ -0,0 +1,35 @@ +from django.db import models + + +# class CostQuerySet(models.QuerySet): + + +class UsageCost(models.Model): + saved_run = models.ForeignKey( + "bots.SavedRun", + on_delete=models.CASCADE, + related_name="usage_costs", + null=True, + default=None, + blank=True, + help_text="The run that was last saved by the user.", + ) + + class Provider(models.IntegerChoices): + OpenAI = 1, "OpenAI" + GPT_4 = 2, "GPT-4" + dalle_e = 3, "dalle-e" + whisper = 4, "whisper" + GPT_3_5 = 5, "GPT-3.5" + + provider = models.IntegerField(choices=Provider.choices) + model = models.TextField("model", blank=True) + param = models.TextField("param", blank=True) # combine with quantity as JSON obj + notes = models.TextField(default="", blank=True) + quantity = models.JSONField(default=dict, blank=True) + calculation_notes = models.TextField(default="", blank=True) + dollar_amt = models.DecimalField( + max_digits=13, + decimal_places=8, + ) + created_at = models.DateTimeField(auto_now_add=True) diff --git a/cost/tests.py b/costs/tests.py similarity index 100% rename from cost/tests.py rename to costs/tests.py diff --git a/cost/views.py b/costs/views.py similarity index 100% rename from cost/views.py rename to costs/views.py diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 12de870a7..812ec32bc 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -55,7 +55,7 @@ "files", "url_shortener", "glossary_resources", - "cost", + "costs", ] MIDDLEWARE = [ From 956ebe14b09e05e7133bab91ad5c18607b3d8607 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:07:35 -0800 Subject: [PATCH 04/58] Change names --- app_users/models.py | 1 - costs/admin.py | 6 +++--- costs/migrations/0001_initial.py | 27 +++++++++++++++++++-------- costs/models.py | 8 +++----- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/app_users/models.py b/app_users/models.py index 8594a44fb..5e373f414 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -225,7 +225,6 @@ class AppUserTransaction(models.Model): amount = models.IntegerField() end_balance = models.IntegerField() created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now) - dollar_amt = models.DecimalField(default=decimal(0.0), blank=True) class Meta: verbose_name = "Transaction" diff --git a/costs/admin.py b/costs/admin.py index 4b3ed79a3..db3485d79 100644 --- a/costs/admin.py +++ b/costs/admin.py @@ -7,12 +7,12 @@ @admin.register(models.UsageCost) class CostsAdmin(admin.ModelAdmin): list_display = [ - "run_id", + "saved_run", "provider", "model", "param", "notes", - "quantity", "calculation_notes", - "cost", + "dollar_amt", + "created_at", ] diff --git a/costs/migrations/0001_initial.py b/costs/migrations/0001_initial.py index dddbf1dfc..4051d1109 100644 --- a/costs/migrations/0001_initial.py +++ b/costs/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.5 on 2023-12-13 20:14 +# Generated by Django 4.2.5 on 2023-12-14 20:52 from django.db import migrations, models import django.db.models.deletion @@ -13,7 +13,7 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name="Cost", + name="UsageCost", fields=[ ( "id", @@ -24,22 +24,33 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("provider", models.CharField(blank=True, default="", max_length=255)), + ( + "provider", + models.IntegerField( + choices=[ + (1, "OpenAI"), + (2, "GPT-4"), + (3, "dalle-e"), + (4, "whisper"), + (5, "GPT-3.5"), + ] + ), + ), ("model", models.TextField(blank=True, verbose_name="model")), ("param", models.TextField(blank=True, verbose_name="param")), ("notes", models.TextField(blank=True, default="")), - ("quantity", models.IntegerField(default=1)), ("calculation_notes", models.TextField(blank=True, default="")), - ("cost", models.FloatField(default=0.0)), + ("dollar_amt", models.DecimalField(decimal_places=8, max_digits=13)), + ("created_at", models.DateTimeField(auto_now_add=True)), ( - "run_id", + "saved_run", models.ForeignKey( blank=True, default=None, help_text="The run that was last saved by the user.", null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="cost", + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", to="bots.savedrun", ), ), diff --git a/costs/models.py b/costs/models.py index bc75b8bc2..92dcb5ce9 100644 --- a/costs/models.py +++ b/costs/models.py @@ -1,9 +1,6 @@ from django.db import models -# class CostQuerySet(models.QuerySet): - - class UsageCost(models.Model): saved_run = models.ForeignKey( "bots.SavedRun", @@ -24,9 +21,10 @@ class Provider(models.IntegerChoices): provider = models.IntegerField(choices=Provider.choices) model = models.TextField("model", blank=True) - param = models.TextField("param", blank=True) # combine with quantity as JSON obj + param = models.TextField( + "param", blank=True + ) # contains input/output tokens and quantity notes = models.TextField(default="", blank=True) - quantity = models.JSONField(default=dict, blank=True) calculation_notes = models.TextField(default="", blank=True) dollar_amt = models.DecimalField( max_digits=13, From 7cc9f676745c77a7142e5d906ba38e69edc55c6a Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 3 Jan 2024 20:01:35 -0800 Subject: [PATCH 05/58] Fixed migration errors --- .../0049_alter_savedrun_workflow.py | 52 ------------ celeryapp/tasks.py | 1 + costs/admin.py | 16 ++++ costs/migrations/0001_initial.py | 4 +- costs/migrations/0002_providerpricing.py | 79 +++++++++++++++++++ costs/models.py | 39 +++++++-- daras_ai_v2/language_model.py | 4 + 7 files changed, 134 insertions(+), 61 deletions(-) delete mode 100644 bots/migrations/0049_alter_savedrun_workflow.py create mode 100644 costs/migrations/0002_providerpricing.py diff --git a/bots/migrations/0049_alter_savedrun_workflow.py b/bots/migrations/0049_alter_savedrun_workflow.py deleted file mode 100644 index 82780828c..000000000 --- a/bots/migrations/0049_alter_savedrun_workflow.py +++ /dev/null @@ -1,52 +0,0 @@ -# Generated by Django 4.2.5 on 2023-12-12 08:26 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("bots", "0048_alter_messageattachment_url"), - ] - - operations = [ - migrations.AlterField( - model_name="savedrun", - name="workflow", - field=models.IntegerField( - choices=[ - (1, "Doc Search"), - (2, "Doc Summary"), - (3, "Google GPT"), - (4, "Copilot"), - (5, "Lipysnc + TTS"), - (6, "Text to Speech"), - (7, "Speech Recognition"), - (8, "Lipsync"), - (9, "Deforum Animation"), - (10, "Compare Text2Img"), - (11, "Text2Audio"), - (12, "Img2Img"), - (13, "Face Inpainting"), - (14, "Google Image Gen"), - (15, "Compare AI Upscalers"), - (16, "SEO Summary"), - (17, "Email Face Inpainting"), - (18, "Social Lookup Email"), - (19, "Object Inpainting"), - (20, "Image Segmentation"), - (21, "Compare LLM"), - (22, "Chyron Plant"), - (23, "Letter Writer"), - (24, "Smart GPT"), - (25, "AI QR Code"), - (26, "Doc Extract"), - (27, "Related QnA Maker"), - (28, "Related QnA Maker Doc"), - (29, "Embeddings"), - (30, "Bulk Runner"), - (31, "Bulk Evaluator"), - ], - default=4, - ), - ), - ] diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..7bd144091 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,6 +8,7 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun +from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/costs/admin.py b/costs/admin.py index db3485d79..12c106a14 100644 --- a/costs/admin.py +++ b/costs/admin.py @@ -16,3 +16,19 @@ class CostsAdmin(admin.ModelAdmin): "dollar_amt", "created_at", ] + + +@admin.register(models.ProviderPricing) +class ProviderAdmin(admin.ModelAdmin): + list_display = [ + "type", + "provider", + "product", + "param", + "cost", + "unit", + "created_at", + "last_updated", + "updated_by", + "pricing_url", + ] diff --git a/costs/migrations/0001_initial.py b/costs/migrations/0001_initial.py index 4051d1109..d3f6d1948 100644 --- a/costs/migrations/0001_initial.py +++ b/costs/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.5 on 2023-12-14 20:52 +# Generated by Django 4.2.5 on 2024-01-04 03:59 from django.db import migrations, models import django.db.models.deletion @@ -8,7 +8,7 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ("bots", "0049_alter_savedrun_workflow"), + ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), ] operations = [ diff --git a/costs/migrations/0002_providerpricing.py b/costs/migrations/0002_providerpricing.py new file mode 100644 index 000000000..6b0bb17a1 --- /dev/null +++ b/costs/migrations/0002_providerpricing.py @@ -0,0 +1,79 @@ +# Generated by Django 4.2.5 on 2024-01-04 04:00 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="ProviderPricing", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("type", models.TextField(choices=[("LLM", "Llm")])), + ( + "provider", + models.IntegerField( + choices=[ + (1, "OpenAI"), + (2, "GPT-4"), + (3, "dalle-e"), + (4, "whisper"), + (5, "GPT-3.5"), + ] + ), + ), + ( + "product", + models.TextField( + choices=[ + ("GPT-4 Vision (openai)", "gpt_4_vision"), + ("GPT-4 Turbo (openai)", "gpt_4_turbo"), + ("GPT-4 (openai)", "gpt_4"), + ("GPT-4 32K (openai)", "gpt_4_32k"), + ("ChatGPT (openai)", "gpt_3_5_turbo"), + ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), + ("Llama 2 (Meta AI)", "llama2_70b_chat"), + ("PaLM 2 Text (Google)", "palm2_chat"), + ("PaLM 2 Chat (Google)", "palm2_text"), + ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), + ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), + ("Curie (openai)", "text_curie_001"), + ("Babbage (openai)", "text_babbage_001"), + ("Ada (openai)", "text_ada_001"), + ("Codex [Deprecated] (openai)", "code_davinci_002"), + ] + ), + ), + ( + "param", + models.TextField( + choices=[ + ("input", "Input"), + ("output", "Output"), + ("input image", "Input Image"), + ("output image", "Output Image"), + ] + ), + ), + ("cost", models.DecimalField(decimal_places=8, max_digits=13)), + ("unit", models.TextField(default="")), + ("notes", models.TextField(default="")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("last_updated", models.DateTimeField(auto_now=True)), + ("updated_by", models.TextField(default="")), + ("pricing_url", models.TextField(default="")), + ], + ), + ] diff --git a/costs/models.py b/costs/models.py index 92dcb5ce9..ae04baf83 100644 --- a/costs/models.py +++ b/costs/models.py @@ -1,4 +1,13 @@ from django.db import models +from daras_ai_v2.language_model import LargeLanguageModels + + +class Provider(models.IntegerChoices): + OpenAI = 1, "OpenAI" + GPT_4 = 2, "GPT-4" + dalle_e = 3, "dalle-e" + whisper = 4, "whisper" + GPT_3_5 = 5, "GPT-3.5" class UsageCost(models.Model): @@ -12,13 +21,6 @@ class UsageCost(models.Model): help_text="The run that was last saved by the user.", ) - class Provider(models.IntegerChoices): - OpenAI = 1, "OpenAI" - GPT_4 = 2, "GPT-4" - dalle_e = 3, "dalle-e" - whisper = 4, "whisper" - GPT_3_5 = 5, "GPT-3.5" - provider = models.IntegerField(choices=Provider.choices) model = models.TextField("model", blank=True) param = models.TextField( @@ -31,3 +33,26 @@ class Provider(models.IntegerChoices): decimal_places=8, ) created_at = models.DateTimeField(auto_now_add=True) + + +class ProviderPricing(models.Model): + class Type(models.TextChoices): + LLM = "LLM" + + class Param(models.TextChoices): + input = "input" + output = "output" + input_image = "input image" + output_image = "output image" + + type = models.TextField(choices=Type.choices) + provider = models.IntegerField(choices=Provider.choices) + product = models.TextField(choices=LargeLanguageModels.choices()) + param = models.TextField(choices=Param.choices) + cost = models.DecimalField(max_digits=13, decimal_places=8) + unit = models.TextField(default="") + notes = models.TextField(default="") + created_at = models.DateTimeField(auto_now_add=True) + last_updated = models.DateTimeField(auto_now=True) + updated_by = models.TextField(default="") + pricing_url = models.TextField(default="") diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index f63c7e01d..77d04d01a 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -65,6 +65,10 @@ class LargeLanguageModels(Enum): code_davinci_002 = "Codex [Deprecated] (openai)" + @classmethod + def choices(cls): + return tuple((e.value, e.name) for e in cls) + @classmethod def _deprecated(cls): return {cls.code_davinci_002} From 9c25a857bbfe36b714a52d2ef92e46f8cdaae7ae Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Wed, 3 Jan 2024 22:40:55 -0800 Subject: [PATCH 06/58] Return token count for openai chat and palm chat --- celeryapp/tasks.py | 1 - daras_ai_v2/language_model.py | 56 ++++++++++++++++++++++++----------- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 7bd144091..1868c7a18 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,7 +8,6 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun -from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 77d04d01a..676b9b9b6 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -350,7 +350,7 @@ def run_language_model( format_chat_entry(role=entry["role"], content=get_entry_text(entry)) for entry in messages ] - result = _run_chat_model( + result, output_token, input_token = _run_chat_model( api=api, model=model_name, messages=messages, # type: ignore @@ -372,10 +372,14 @@ def run_language_model( else (entry.get("content") or "").strip() for entry in result ] + print("out_contentttt", out_content, input_token, output_token) if tools: - return out_content, [(entry.get("tool_calls") or []) for entry in result] + return ( + out_content, + [(entry.get("tool_calls") or []) for entry in result], + ) else: - return out_content + return out_content, input_token, output_token else: if tools: raise ValueError("Only OpenAI chat models support Tools") @@ -442,7 +446,7 @@ def _run_chat_model( avoid_repetition: bool, tools: list[LLMTools] | None, response_format_type: typing.Literal["text", "json_object"] | None, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: match api: case LLMApis.openai: return _run_openai_chat( @@ -493,7 +497,7 @@ def _run_openai_chat( avoid_repetition: bool, tools: list[LLMTools] | None, response_format_type: typing.Literal["text", "json_object"] | None, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: from openai._types import NOT_GIVEN if avoid_repetition: @@ -524,7 +528,11 @@ def _run_openai_chat( for model_str in model ], ) - return [choice.message.dict() for choice in r.choices] + return ( + [choice.message.dict() for choice in r.choices], + r.usage.completion_tokens, + r.usage.prompt_tokens, + ) @retry_if(openai_should_retry) @@ -630,7 +638,7 @@ def _run_palm_chat( max_output_tokens: int, candidate_count: int, temperature: float, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: """ Args: model_id: The model id to use for the request. See available models: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models @@ -670,14 +678,23 @@ def _run_palm_chat( ) r.raise_for_status() - return [ - { - "role": msg["author"], - "content": msg["content"], - } - for pred in r.json()["predictions"] - for msg in pred["candidates"] - ] + print( + "real r.json()", + r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + ) + + return ( + [ + { + "role": msg["author"], + "content": msg["content"], + } + for pred in r.json()["predictions"] + for msg in pred["candidates"] + ], + r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + r.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) @retry_if(vertex_ai_should_retry) @@ -688,7 +705,7 @@ def _run_palm_text( max_output_tokens: int, candidate_count: int, temperature: float, -) -> list[str]: +) -> tuple[list[str], int, int]: """ Args: model_id: The model id to use for the request. See available models: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models @@ -714,7 +731,12 @@ def _run_palm_text( }, ) res.raise_for_status() - return [prediction["content"] for prediction in res.json()["predictions"]] + print("res.json()", res.json()) + return ( + [prediction["content"] for prediction in res.json()["predictions"]], + res.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + res.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) def format_chatml_message(entry: ConversationEntry) -> str: From d28125ecbc7716efcc99ae422baffce4719618e4 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 4 Jan 2024 10:35:08 -0800 Subject: [PATCH 07/58] Get tokens from response --- daras_ai_v2/language_model.py | 38 +++++++++++++++++++++-------------- recipes/CompareLLM.py | 1 + 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 676b9b9b6..3e9400ff5 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -329,7 +329,11 @@ def run_language_model( avoid_repetition: bool = False, tools: list[LLMTools] = None, response_format_type: typing.Literal["text", "json_object"] = None, -) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]: +) -> ( + tuple[list[str], int, int] + | tuple[list[str], list[list[dict]], int, int] + | tuple[list[dict], int, int] +): assert bool(prompt) != bool( messages ), "Pleave provide exactly one of { prompt, messages }" @@ -372,19 +376,20 @@ def run_language_model( else (entry.get("content") or "").strip() for entry in result ] - print("out_contentttt", out_content, input_token, output_token) if tools: return ( out_content, [(entry.get("tool_calls") or []) for entry in result], + output_token, + input_token, ) else: - return out_content, input_token, output_token + return out_content, output_token, input_token else: if tools: raise ValueError("Only OpenAI chat models support Tools") logger.info(f"{model_name=}, {len(prompt)=}, {max_tokens=}, {temperature=}") - result = _run_text_model( + result, output_token, input_token = _run_text_model( api=api, model=model_name, prompt=prompt, @@ -395,7 +400,7 @@ def run_language_model( avoid_repetition=avoid_repetition, quality=quality, ) - return [msg.strip() for msg in result] + return [msg.strip() for msg in result], output_token, input_token def _run_text_model( @@ -409,7 +414,7 @@ def _run_text_model( stop: list[str] | None, avoid_repetition: bool, quality: float, -) -> list[str]: +) -> tuple[list[str], int, int]: match api: case LLMApis.openai: return _run_openai_text( @@ -528,6 +533,7 @@ def _run_openai_chat( for model_str in model ], ) + print("entire r", r) return ( [choice.message.dict() for choice in r.choices], r.usage.completion_tokens, @@ -557,7 +563,11 @@ def _run_openai_text( frequency_penalty=0.1 if avoid_repetition else 0, presence_penalty=0.25 if avoid_repetition else 0, ) - return [choice.text for choice in r.choices] + return ( + [choice.text for choice in r.choices], + r.usage.completion_tokens, + r.usage.prompt_tokens, + ) def _get_openai_client(model: str): @@ -586,7 +596,7 @@ def _run_together_chat( temperature: float, repetition_penalty: float, num_outputs: int, -) -> list[ConversationEntry]: +) -> tuple[list[ConversationEntry], int, int]: """ Args: model: The model version to use for the request. @@ -614,11 +624,15 @@ def _run_together_chat( range(num_outputs), ) ret = [] + total_out_tokens = 0 + total_in_tokens = 0 for r in results: r.raise_for_status() data = r.json() output = data["output"] error = output.get("error") + total_out_tokens += output.get("usage", {}).get("completion_tokens", 0) + total_in_tokens += output.get("usage", {}).get("prompt_tokens", 0) if error: raise ValueError(error) ret.append( @@ -627,7 +641,7 @@ def _run_together_chat( "content": output["choices"][0]["text"], } ) - return ret + return ret, total_out_tokens, total_in_tokens @retry_if(vertex_ai_should_retry) @@ -678,11 +692,6 @@ def _run_palm_chat( ) r.raise_for_status() - print( - "real r.json()", - r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], - ) - return ( [ { @@ -731,7 +740,6 @@ def _run_palm_text( }, ) res.raise_for_status() - print("res.json()", res.json()) return ( [prediction["content"] for prediction in res.json()["predictions"]], res.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 583317ddc..c9ade07d8 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -105,6 +105,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: max_tokens=request.max_tokens, avoid_repetition=request.avoid_repetition, ) + print("final output", output_text) def render_output(self): self._render_outputs(st.session_state, 450) From 20e4fba79a1e9d2ba97bb2bdbb0af237d05c3eb4 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Fri, 5 Jan 2024 19:11:00 -0800 Subject: [PATCH 08/58] Saving LLM works --- celeryapp/tasks.py | 1 + costs/admin.py | 7 +- costs/cost_utils.py | 36 +++++++++++ ...t_model_remove_usagecost_param_and_more.py | 64 +++++++++++++++++++ .../0004_remove_usagecost_saved_run.py | 16 +++++ .../0005_remove_usagecost_provider_pricing.py | 16 +++++ ...st_provider_pricing_usagecost_saved_run.py | 38 +++++++++++ .../0007_alter_providerpricing_provider.py | 25 ++++++++ ...08_alter_providerpricing_param_and_more.py | 33 ++++++++++ .../0009_alter_providerpricing_product.py | 19 ++++++ ...0010_remove_usagecost_calculation_notes.py | 16 +++++ costs/models.py | 46 +++++++------ daras_ai_v2/language_model.py | 45 +++++++++---- recipes/CompareLLM.py | 1 - 14 files changed, 326 insertions(+), 37 deletions(-) create mode 100644 costs/cost_utils.py create mode 100644 costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py create mode 100644 costs/migrations/0004_remove_usagecost_saved_run.py create mode 100644 costs/migrations/0005_remove_usagecost_provider_pricing.py create mode 100644 costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py create mode 100644 costs/migrations/0007_alter_providerpricing_provider.py create mode 100644 costs/migrations/0008_alter_providerpricing_param_and_more.py create mode 100644 costs/migrations/0009_alter_providerpricing_product.py create mode 100644 costs/migrations/0010_remove_usagecost_calculation_notes.py diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..7bd144091 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,6 +8,7 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun +from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/costs/admin.py b/costs/admin.py index 12c106a14..f3d99d472 100644 --- a/costs/admin.py +++ b/costs/admin.py @@ -8,11 +8,9 @@ class CostsAdmin(admin.ModelAdmin): list_display = [ "saved_run", - "provider", - "model", - "param", + "provider_pricing", + "quantity", "notes", - "calculation_notes", "dollar_amt", "created_at", ] @@ -27,6 +25,7 @@ class ProviderAdmin(admin.ModelAdmin): "param", "cost", "unit", + "notes", "created_at", "last_updated", "updated_by", diff --git a/costs/cost_utils.py b/costs/cost_utils.py new file mode 100644 index 000000000..c04a30c92 --- /dev/null +++ b/costs/cost_utils.py @@ -0,0 +1,36 @@ +from costs.models import UsageCost, ProviderPricing +from bots.models import SavedRun +from django.forms.models import model_to_dict + + +def get_provider_pricing( + type: str, + provider: str, + product: str, + param: str, +) -> ProviderPricing: + return ProviderPricing.objects.get( + type=type, + provider=provider, + product=product, + param=param, + ) + + +def record_cost( + run_id: str | None, + uid: str | None, + provider_pricing: ProviderPricing, + quantity: int, +) -> UsageCost: + saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) + cost = UsageCost( + saved_run=saved_run, + provider_pricing=provider_pricing, + quantity=quantity, + notes="", + dollar_amt=provider_pricing.cost * quantity, + created_at=saved_run.created_at, + ) + cost.save() + return cost diff --git a/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py b/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py new file mode 100644 index 000000000..4aa89cadd --- /dev/null +++ b/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py @@ -0,0 +1,64 @@ +# Generated by Django 4.2.5 on 2024-01-05 06:02 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0002_providerpricing"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="model", + ), + migrations.RemoveField( + model_name="usagecost", + name="param", + ), + migrations.RemoveField( + model_name="usagecost", + name="provider", + ), + migrations.AddField( + model_name="usagecost", + name="provider_pricing", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", + to="costs.providerpricing", + ), + ), + migrations.AddField( + model_name="usagecost", + name="quantity", + field=models.IntegerField(default=1), + ), + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("GPT-4 Vision (openai)", "gpt_4_vision"), + ("GPT-4 Turbo (openai)", "gpt_4_turbo"), + ("GPT-4 (openai)", "gpt_4"), + ("GPT-4 32K (openai)", "gpt_4_32k"), + ("ChatGPT (openai)", "gpt_3_5_turbo"), + ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), + ("Llama 2 (Meta AI)", "llama2_70b_chat"), + ("PaLM 2 Chat (Google)", "palm2_chat"), + ("PaLM 2 Text (Google)", "palm2_text"), + ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), + ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), + ("Curie (openai)", "text_curie_001"), + ("Babbage (openai)", "text_babbage_001"), + ("Ada (openai)", "text_ada_001"), + ("Codex [Deprecated] (openai)", "code_davinci_002"), + ] + ), + ), + ] diff --git a/costs/migrations/0004_remove_usagecost_saved_run.py b/costs/migrations/0004_remove_usagecost_saved_run.py new file mode 100644 index 000000000..2e68a519e --- /dev/null +++ b/costs/migrations/0004_remove_usagecost_saved_run.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.5 on 2024-01-05 07:33 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0003_remove_usagecost_model_remove_usagecost_param_and_more"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="saved_run", + ), + ] diff --git a/costs/migrations/0005_remove_usagecost_provider_pricing.py b/costs/migrations/0005_remove_usagecost_provider_pricing.py new file mode 100644 index 000000000..2ff682ab0 --- /dev/null +++ b/costs/migrations/0005_remove_usagecost_provider_pricing.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.5 on 2024-01-05 08:06 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0004_remove_usagecost_saved_run"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="provider_pricing", + ), + ] diff --git a/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py b/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py new file mode 100644 index 000000000..ba207c963 --- /dev/null +++ b/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2.5 on 2024-01-05 08:06 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), + ("costs", "0005_remove_usagecost_provider_pricing"), + ] + + operations = [ + migrations.AddField( + model_name="usagecost", + name="provider_pricing", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", + to="costs.providerpricing", + ), + ), + migrations.AddField( + model_name="usagecost", + name="saved_run", + field=models.ForeignKey( + blank=True, + default=None, + help_text="The run that was last saved by the user.", + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="usage_costs", + to="bots.savedrun", + ), + ), + ] diff --git a/costs/migrations/0007_alter_providerpricing_provider.py b/costs/migrations/0007_alter_providerpricing_provider.py new file mode 100644 index 000000000..e6220e391 --- /dev/null +++ b/costs/migrations/0007_alter_providerpricing_provider.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.5 on 2024-01-06 00:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0006_usagecost_provider_pricing_usagecost_saved_run"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="provider", + field=models.TextField( + choices=[ + ("OpenAI", "Openai"), + ("GPT-4", "Gpt 4"), + ("dalle-e", "Dalle E"), + ("whisper", "Whisper"), + ("GPT-3.5", "Gpt 3 5"), + ] + ), + ), + ] diff --git a/costs/migrations/0008_alter_providerpricing_param_and_more.py b/costs/migrations/0008_alter_providerpricing_param_and_more.py new file mode 100644 index 000000000..bcc9a0030 --- /dev/null +++ b/costs/migrations/0008_alter_providerpricing_param_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.5 on 2024-01-06 01:10 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0007_alter_providerpricing_provider"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="param", + field=models.TextField(choices=[("Input", "input"), ("Output", "output")]), + ), + migrations.AlterField( + model_name="providerpricing", + name="provider", + field=models.TextField( + choices=[ + ("vertex_ai", "Vertex AI"), + ("openai", "OpenAI"), + ("together", "Together"), + ] + ), + ), + migrations.AlterField( + model_name="providerpricing", + name="type", + field=models.TextField(choices=[("LLM", "LLM")]), + ), + ] diff --git a/costs/migrations/0009_alter_providerpricing_product.py b/costs/migrations/0009_alter_providerpricing_product.py new file mode 100644 index 000000000..5e5a51019 --- /dev/null +++ b/costs/migrations/0009_alter_providerpricing_product.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.5 on 2024-01-06 01:28 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0008_alter_providerpricing_param_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[("gpt-4-vision-preview", "gpt-4-vision-preview")] + ), + ), + ] diff --git a/costs/migrations/0010_remove_usagecost_calculation_notes.py b/costs/migrations/0010_remove_usagecost_calculation_notes.py new file mode 100644 index 000000000..3eba7ace1 --- /dev/null +++ b/costs/migrations/0010_remove_usagecost_calculation_notes.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.5 on 2024-01-06 02:02 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0009_alter_providerpricing_product"), + ] + + operations = [ + migrations.RemoveField( + model_name="usagecost", + name="calculation_notes", + ), + ] diff --git a/costs/models.py b/costs/models.py index ae04baf83..cc997eb83 100644 --- a/costs/models.py +++ b/costs/models.py @@ -2,12 +2,17 @@ from daras_ai_v2.language_model import LargeLanguageModels -class Provider(models.IntegerChoices): - OpenAI = 1, "OpenAI" - GPT_4 = 2, "GPT-4" - dalle_e = 3, "dalle-e" - whisper = 4, "whisper" - GPT_3_5 = 5, "GPT-3.5" +class Provider(models.TextChoices): + vertex_ai = "vertex_ai", "Vertex AI" + openai = "openai", "OpenAI" + together = "together", "Together" + + +class Product(models.TextChoices): + gpt_4_vision = ( + "gpt-4-vision-preview", + "gpt-4-vision-preview", + ) class UsageCost(models.Model): @@ -21,13 +26,15 @@ class UsageCost(models.Model): help_text="The run that was last saved by the user.", ) - provider = models.IntegerField(choices=Provider.choices) - model = models.TextField("model", blank=True) - param = models.TextField( - "param", blank=True - ) # contains input/output tokens and quantity + provider_pricing = models.ForeignKey( + "costs.ProviderPricing", + on_delete=models.CASCADE, + related_name="usage_costs", + null=True, + default=None, + ) + quantity = models.IntegerField(default=1) notes = models.TextField(default="", blank=True) - calculation_notes = models.TextField(default="", blank=True) dollar_amt = models.DecimalField( max_digits=13, decimal_places=8, @@ -37,17 +44,15 @@ class UsageCost(models.Model): class ProviderPricing(models.Model): class Type(models.TextChoices): - LLM = "LLM" + LLM = "LLM", "LLM" class Param(models.TextChoices): - input = "input" - output = "output" - input_image = "input image" - output_image = "output image" + input = "Input", "input" + output = "Output", "output" type = models.TextField(choices=Type.choices) - provider = models.IntegerField(choices=Provider.choices) - product = models.TextField(choices=LargeLanguageModels.choices()) + provider = models.TextField(choices=Provider.choices) + product = models.TextField(choices=Product.choices) param = models.TextField(choices=Param.choices) cost = models.DecimalField(max_digits=13, decimal_places=8) unit = models.TextField(default="") @@ -56,3 +61,6 @@ class Param(models.TextChoices): last_updated = models.DateTimeField(auto_now=True) updated_by = models.TextField(default="") pricing_url = models.TextField(default="") + + def __str__(self): + return self.type + " " + self.provider + " " + self.product + " " + self.param diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index d6e93da89..f78998de0 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -27,6 +27,8 @@ get_redis_cache, ) from daras_ai_v2.text_splitter import default_length_function +from daras_ai_v2.query_params_util import extract_query_params +from daras_ai_v2.query_params import gooey_get_query_params DEFAULT_SYSTEM_MSG = "You are an intelligent AI assistant. Follow the instructions as closely as possible." @@ -65,10 +67,6 @@ class LargeLanguageModels(Enum): code_davinci_002 = "Codex [Deprecated] (openai)" - @classmethod - def choices(cls): - return tuple((e.value, e.name) for e in cls) - @classmethod def _deprecated(cls): return {cls.code_davinci_002} @@ -329,11 +327,9 @@ def run_language_model( avoid_repetition: bool = False, tools: list[LLMTools] = None, response_format_type: typing.Literal["text", "json_object"] = None, -) -> ( - tuple[list[str], int, int] - | tuple[list[str], list[list[dict]], int, int] - | tuple[list[dict], int, int] -): +) -> list[str] | tuple[list[str], list[list[dict]]] | list[dict]: + from costs.cost_utils import record_cost, get_provider_pricing + assert bool(prompt) != bool( messages ), "Pleave provide exactly one of { prompt, messages }" @@ -380,11 +376,35 @@ def run_language_model( return ( out_content, [(entry.get("tool_calls") or []) for entry in result], - output_token, - input_token, ) else: - return out_content, output_token, input_token + provider_pricing_in = get_provider_pricing( + type="LLM", + provider=api.name, + product=model_name, + param="Input", + ) + + provider_pricing_out = get_provider_pricing( + type="LLM", + provider=api.name, + product=model_name, + param="Output", + ) + example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + record_cost( + run_id=run_id, + uid=uid, + provider_pricing=provider_pricing_in, + quantity=input_token, + ) + record_cost( + run_id=run_id, + uid=uid, + provider_pricing=provider_pricing_out, + quantity=output_token, + ) + return out_content else: if tools: raise ValueError("Only OpenAI chat models support Tools") @@ -533,7 +553,6 @@ def _run_openai_chat( for model_str in model ], ) - print("entire r", r) return ( [choice.message.dict() for choice in r.choices], r.usage.completion_tokens, diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index c9ade07d8..583317ddc 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -105,7 +105,6 @@ def run(self, state: dict) -> typing.Iterator[str | None]: max_tokens=request.max_tokens, avoid_repetition=request.avoid_repetition, ) - print("final output", output_text) def render_output(self): self._render_outputs(st.session_state, 450) From ee0e058c72a1aafced67c060285f44e95d827236 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Mon, 8 Jan 2024 09:02:21 -0800 Subject: [PATCH 09/58] Temp --- costs/cost_utils.py | 1 - costs/models.py | 12 ++++++------ daras_ai_v2/language_model.py | 4 ++++ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/costs/cost_utils.py b/costs/cost_utils.py index c04a30c92..f7799a970 100644 --- a/costs/cost_utils.py +++ b/costs/cost_utils.py @@ -1,6 +1,5 @@ from costs.models import UsageCost, ProviderPricing from bots.models import SavedRun -from django.forms.models import model_to_dict def get_provider_pricing( diff --git a/costs/models.py b/costs/models.py index cc997eb83..ece1500f8 100644 --- a/costs/models.py +++ b/costs/models.py @@ -1,11 +1,11 @@ from django.db import models -from daras_ai_v2.language_model import LargeLanguageModels +from daras_ai_v2.language_model import LLMApis -class Provider(models.TextChoices): - vertex_ai = "vertex_ai", "Vertex AI" - openai = "openai", "OpenAI" - together = "together", "Together" +# class Provider(models.TextChoices): +# vertex_ai = "vertex_ai", "Vertex AI" +# openai = "openai", "OpenAI" +# together = "together", "Together" class Product(models.TextChoices): @@ -51,7 +51,7 @@ class Param(models.TextChoices): output = "Output", "output" type = models.TextField(choices=Type.choices) - provider = models.TextField(choices=Provider.choices) + provider = models.TextField(choices=LLMApis.choices()) product = models.TextField(choices=Product.choices) param = models.TextField(choices=Param.choices) cost = models.DecimalField(max_digits=13, decimal_places=8) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index f78998de0..d1a4eb184 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -45,6 +45,10 @@ class LLMApis(Enum): openai = "OpenAI" together = "Together" + @classmethod + def choices(cls): + return tuple((api.name, api.value) for api in cls) + class LargeLanguageModels(Enum): gpt_4_vision = "GPT-4 Vision (openai)" From a5a64f24e8f1f6e238b2e36b3fc78cc89a714622 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 11 Jan 2024 01:26:36 -0800 Subject: [PATCH 10/58] Added choices --- costs/cost_utils.py | 1 + .../0011_alter_providerpricing_product.py | 35 +++++++++++ .../0012_alter_providerpricing_product.py | 37 +++++++++++ .../0013_alter_providerpricing_product.py | 37 +++++++++++ .../0014_alter_providerpricing_product.py | 40 ++++++++++++ .../0015_alter_providerpricing_product.py | 43 +++++++++++++ costs/models.py | 63 ++++++++++++++++--- 7 files changed, 248 insertions(+), 8 deletions(-) create mode 100644 costs/migrations/0011_alter_providerpricing_product.py create mode 100644 costs/migrations/0012_alter_providerpricing_product.py create mode 100644 costs/migrations/0013_alter_providerpricing_product.py create mode 100644 costs/migrations/0014_alter_providerpricing_product.py create mode 100644 costs/migrations/0015_alter_providerpricing_product.py diff --git a/costs/cost_utils.py b/costs/cost_utils.py index f7799a970..25f6ec28b 100644 --- a/costs/cost_utils.py +++ b/costs/cost_utils.py @@ -8,6 +8,7 @@ def get_provider_pricing( product: str, param: str, ) -> ProviderPricing: + print("get_provider_pricing", type, provider, product, param) return ProviderPricing.objects.get( type=type, provider=provider, diff --git a/costs/migrations/0011_alter_providerpricing_product.py b/costs/migrations/0011_alter_providerpricing_product.py new file mode 100644 index 000000000..b49d3f308 --- /dev/null +++ b/costs/migrations/0011_alter_providerpricing_product.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2.5 on 2024-01-10 08:50 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0010_remove_usagecost_calculation_notes"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ("gpt-4-1106-preview", "gpt-4-1106-preview"), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0012_alter_providerpricing_product.py b/costs/migrations/0012_alter_providerpricing_product.py new file mode 100644 index 000000000..3bc92c01c --- /dev/null +++ b/costs/migrations/0012_alter_providerpricing_product.py @@ -0,0 +1,37 @@ +# Generated by Django 4.2.5 on 2024-01-10 16:45 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0011_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ("openai-gpt-4-turbo-prod-ca-1", "openai-gpt-4-turbo-prod-ca-1"), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0013_alter_providerpricing_product.py b/costs/migrations/0013_alter_providerpricing_product.py new file mode 100644 index 000000000..9449fbf91 --- /dev/null +++ b/costs/migrations/0013_alter_providerpricing_product.py @@ -0,0 +1,37 @@ +# Generated by Django 4.2.5 on 2024-01-10 16:50 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0012_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ("openai-gpt-4-turbo-prod-ca-1", "gpt-4-1106-preview"), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0014_alter_providerpricing_product.py b/costs/migrations/0014_alter_providerpricing_product.py new file mode 100644 index 000000000..22b5439ed --- /dev/null +++ b/costs/migrations/0014_alter_providerpricing_product.py @@ -0,0 +1,40 @@ +# Generated by Django 4.2.5 on 2024-01-10 16:58 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0013_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ( + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + ), + ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/migrations/0015_alter_providerpricing_product.py b/costs/migrations/0015_alter_providerpricing_product.py new file mode 100644 index 000000000..2658de04f --- /dev/null +++ b/costs/migrations/0015_alter_providerpricing_product.py @@ -0,0 +1,43 @@ +# Generated by Django 4.2.5 on 2024-01-10 17:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("costs", "0014_alter_providerpricing_product"), + ] + + operations = [ + migrations.AlterField( + model_name="providerpricing", + name="product", + field=models.TextField( + choices=[ + ("gpt-4-vision-preview", "gpt-4-vision-preview"), + ( + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + ), + ( + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + ), + ("gpt-3.5-turbo", "gpt-3.5-turbo"), + ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), + ("text-davinci-003", "text-davinci-003"), + ("text-davinci-002", "text-davinci-002"), + ("code-davinci-002", "code-davinci-002"), + ("text-curie-001", "text-curie-001"), + ("text-babbage-001", "text-babbage-001"), + ("text-ada-001", "text-ada-001"), + ("text-bison", "text-bison"), + ("chat-bison", "chat-bison"), + ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ), + ] + ), + ), + ] diff --git a/costs/models.py b/costs/models.py index ece1500f8..18700537f 100644 --- a/costs/models.py +++ b/costs/models.py @@ -2,17 +2,64 @@ from daras_ai_v2.language_model import LLMApis -# class Provider(models.TextChoices): -# vertex_ai = "vertex_ai", "Vertex AI" -# openai = "openai", "OpenAI" -# together = "together", "Together" - - class Product(models.TextChoices): gpt_4_vision = ( "gpt-4-vision-preview", "gpt-4-vision-preview", ) + gpt_4_turbo = ( + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", + ) + + gpt_4_32k = ( + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + "('openai-gpt-4-prod-ca-1', 'gpt-4')", + ) + gpt_3_5_turbo = ( + "gpt-3.5-turbo", + "gpt-3.5-turbo", + ) + gpt_3_5_turbo_16k = ( + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-16k", + ) + text_davinci_003 = ( + "text-davinci-003", + "text-davinci-003", + ) + text_davinci_002 = ( + "text-davinci-002", + "text-davinci-002", + ) + code_davinci_002 = ( + "code-davinci-002", + "code-davinci-002", + ) + text_curie_001 = ( + "text-curie-001", + "text-curie-001", + ) + text_babbage_001 = ( + "text-babbage-001", + "text-babbage-001", + ) + text_ada_001 = ( + "text-ada-001", + "text-ada-001", + ) + palm2_text = ( + "text-bison", + "text-bison", + ) + palm2_chat = ( + "chat-bison", + "chat-bison", + ) + llama2_70b_chat = ( + "togethercomputer/llama-2-70b-chat", + "togethercomputer/llama-2-70b-chat", + ) class UsageCost(models.Model): @@ -43,14 +90,14 @@ class UsageCost(models.Model): class ProviderPricing(models.Model): - class Type(models.TextChoices): + class Group(models.TextChoices): # change to different name than type LLM = "LLM", "LLM" class Param(models.TextChoices): input = "Input", "input" output = "Output", "output" - type = models.TextField(choices=Type.choices) + type = models.TextField(choices=Group.choices) provider = models.TextField(choices=LLMApis.choices()) product = models.TextField(choices=Product.choices) param = models.TextField(choices=Param.choices) From ff3c86f98cd70f7dd44174cfb71013d0406014b5 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 25 Jan 2024 12:02:58 -0800 Subject: [PATCH 11/58] fixes keyerror with sorting options --- bots/models.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/bots/models.py b/bots/models.py index 3ad94cd31..30c294f3e 100644 --- a/bots/models.py +++ b/bots/models.py @@ -704,7 +704,26 @@ def to_df_format( "Bot": str(convo.bot_integration), } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Messages", + "Correct Answers", + "Thumbs up", + "Thumbs down", + "Last Sent", + "First Sent", + "A7", + "A30", + "R1", + "R7", + "R30", + "Delta Hours", + "Created At", + "Bot", + ], + ) return df @@ -900,7 +919,17 @@ def to_df_format( "Analysis JSON": message.analysis_result, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Role", + "Message (EN)", + "Sent", + "Feedback", + "Analysis JSON", + ], + ) return df def to_df_analysis_format( @@ -922,7 +951,10 @@ def to_df_analysis_format( "Analysis JSON": message.analysis_result, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=["Name", "Question (EN)", "Answer (EN)", "Sent", "Analysis JSON"], + ) return df def as_llm_context(self, limit: int = 100) -> list["ConversationEntry"]: @@ -1107,7 +1139,20 @@ def to_df_format( "Question Answered": feedback.message.question_answered, } rows.append(row) - df = pd.DataFrame.from_records(rows) + df = pd.DataFrame.from_records( + rows, + columns=[ + "Name", + "Question (EN)", + "Question Sent", + "Answer (EN)", + "Answer Sent", + "Rating", + "Feedback (EN)", + "Feedback Sent", + "Question Answered", + ], + ) return df From 4733be17e0d0cd9620cc45a043bea471841699ac Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 25 Jan 2024 12:07:43 -0800 Subject: [PATCH 12/58] make 100% sure sortby is valid --- recipes/VideoBotsStats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c94b847f2..b57c9f2ac 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -667,7 +667,7 @@ def get_tabular_data( df["Run URL"] = run_url df["Bot"] = bi.name - if sort_by: + if sort_by and sort_by in df.columns: df.sort_values(by=[sort_by], ascending=False, inplace=True) return df From dcc057a0e545306a1afc5bc7089e0baf125f416f Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 25 Jan 2024 12:45:19 -0800 Subject: [PATCH 13/58] added regression test --- bots/models.py | 6 ++--- bots/tests.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/bots/models.py b/bots/models.py index 30c294f3e..457f3bb6a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -937,14 +937,14 @@ def to_df_analysis_format( ) -> "pd.DataFrame": import pandas as pd - qs = self.filter(role=CHATML_ROLE_USER).prefetch_related("feedbacks") + qs = self.filter(role=CHATML_ROLE_ASSISSTANT).prefetch_related("feedbacks") rows = [] for message in qs[:row_limit]: message: Message row = { "Name": message.conversation.get_display_name(), - "Question (EN)": message.content, - "Answer (EN)": message.get_next_by_created_at().content, + "Question (EN)": message.get_previous_by_created_at().content, + "Answer (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) .strftime("%b %d, %Y %I:%M %p"), diff --git a/bots/tests.py b/bots/tests.py index 21eef78f8..30105e648 100644 --- a/bots/tests.py +++ b/bots/tests.py @@ -92,3 +92,66 @@ def test_create_bot_integration_conversation_message(transactional_db): assert message_b.role == CHATML_ROLE_ASSISSTANT assert message_b.content == "Red, green, and yellow grow the best." assert message_b.display_content == "Red, green, and yellow grow the best." + + +def test_stats_get_tabular_data_invalid_sorting_options(transactional_db): + from recipes.VideoBotsStats import VideoBotsStatsPage + + page = VideoBotsStatsPage() + + # setup + run_url = "https://my_run_url" + bi = BotIntegration.objects.create( + name="My Bot Integration", + saved_run=None, + billing_account_uid="fdnacsFSBQNKVW8z6tzhBLHKpAm1", # digital green's account id + user_language="en", + show_feedback_buttons=True, + platform=Platform.WHATSAPP, + wa_phone_number="my_whatsapp_number", + wa_phone_number_id="my_whatsapp_number_id", + ) + convos = Conversation.objects.filter(bot_integration=bi) + msgs = Message.objects.filter(conversation__in=convos) + + # valid option but no data + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Name" + ) + assert df.shape[0] == 0 + assert "Name" in df.columns + + # valid option and data + convo = Conversation.objects.create( + bot_integration=bi, + state=ConvoState.INITIAL, + wa_phone_number="+919876543210", + ) + Message.objects.create( + conversation=convo, + role=CHATML_ROLE_USER, + content="What types of chilies can be grown in Mumbai?", + display_content="What types of chilies can be grown in Mumbai?", + ) + Message.objects.create( + conversation=convo, + role=CHATML_ROLE_ASSISSTANT, + content="Red, green, and yellow grow the best.", + display_content="Red, green, and yellow grow the best.", + analysis_result={"Answered": True}, + ) + convos = Conversation.objects.filter(bot_integration=bi) + msgs = Message.objects.filter(conversation__in=convos) + assert msgs.count() == 2 + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Name" + ) + assert df.shape[0] == 1 + assert "Name" in df.columns + + # invalid sort option should be ignored + df = page.get_tabular_data( + bi, run_url, convos, msgs, "Answered Successfully", "Invalid" + ) + assert df.shape[0] == 1 + assert "Name" in df.columns From 08ed3deda4ce9540c928c2bdf4a6df6d7efaba7d Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Fri, 26 Jan 2024 00:34:29 -0800 Subject: [PATCH 14/58] uploading google folders to the document uploader now expands them --- daras_ai_v2/doc_search_settings_widgets.py | 14 +++++++++- daras_ai_v2/gdrive_downloader.py | 30 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index b240b482f..39cc152db 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -75,7 +75,19 @@ def document_uploader( accept=accept, accept_multiple_files=accept_multiple_files, ) - return st.session_state.get(key, []) + documents = st.session_state.get(key, []) + for document in documents: + if not document.startswith("https://drive.google.com/drive/folders"): + continue + from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder + from furl import furl + + folder_content_urls = gdrive_list_urls_of_files_in_folder(furl(document)) + documents.remove(document) + documents.extend(folder_content_urls) + st.session_state[key] = documents + st.session_state[custom_key] = "\n".join(documents) + return documents def doc_search_settings( diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 89bcadb5b..c8d3bb7ba 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -25,6 +25,36 @@ def url_to_gdrive_file_id(f: furl) -> str: return file_id +def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: + # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) + folder_id = f.path.segments[-1] + # get metadata + service = discovery.build("drive", "v3") + # get files in drive directly + if f.host == "drive.google.com": + request = service.files().list( + supportsAllDrives=True, + includeItemsFromAllDrives=True, + q=f"'{folder_id}' in parents", + fields="files(mimeType,webViewLink)", + ) + # export google docs to appropriate type + else: + raise ValueError(f"Can't list files non google folder url: {str(f)!r}") + # download + response = request.execute() + files = response.get("files", []) + urls = [] + for file in files: + mime_type = file.get("mimeType") + url = file.get("webViewLink") + if mime_type == "application/vnd.google-apps.folder": + continue + elif url: + urls.append(url) + return urls + + def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]: # get drive file id file_id = url_to_gdrive_file_id(f) From 12a60c27ad7c661b5ad005980691fcd9d56396ee Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Fri, 26 Jan 2024 00:41:19 -0800 Subject: [PATCH 15/58] remove unessecary comments --- daras_ai_v2/gdrive_downloader.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index c8d3bb7ba..339a634a3 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -28,9 +28,7 @@ def url_to_gdrive_file_id(f: furl) -> str: def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] - # get metadata service = discovery.build("drive", "v3") - # get files in drive directly if f.host == "drive.google.com": request = service.files().list( supportsAllDrives=True, @@ -38,10 +36,8 @@ def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: q=f"'{folder_id}' in parents", fields="files(mimeType,webViewLink)", ) - # export google docs to appropriate type else: raise ValueError(f"Can't list files non google folder url: {str(f)!r}") - # download response = request.execute() files = response.get("files", []) urls = [] From dd525156920c7a4511207fb4a9fd6270be1b95a3 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Fri, 26 Jan 2024 01:03:52 -0800 Subject: [PATCH 16/58] linting fix --- bots/models.py | 8 +++-- daras_ai_v2/bot_integration_widgets.py | 6 ++-- daras_ai_v2/bots.py | 12 ++++--- daras_ai_v2/language_model.py | 28 +++++++++------ .../language_model_settings_widgets.py | 6 ++-- daras_ai_v2/stable_diffusion.py | 12 +++---- gooeysite/urls.py | 1 + pyproject.toml | 3 ++ recipes/CompareLLM.py | 6 ++-- recipes/CompareText2Img.py | 6 ++-- recipes/CompareUpscaler.py | 6 ++-- recipes/DocExtract.py | 6 ++-- recipes/DocSearch.py | 6 ++-- recipes/DocSummary.py | 6 ++-- recipes/GoogleGPT.py | 6 ++-- recipes/ImageSegmentation.py | 6 ++-- recipes/Img2Img.py | 8 +++-- recipes/QRCodeGenerator.py | 12 +++---- recipes/SEOSummary.py | 6 ++-- recipes/SmartGPT.py | 6 ++-- recipes/SocialLookupEmail.py | 6 ++-- recipes/Text2Audio.py | 12 +++---- recipes/TextToSpeech.py | 6 ++-- recipes/VideoBots.py | 34 ++++++++++--------- recipes/VideoBotsStats.py | 10 +++--- 25 files changed, 121 insertions(+), 103 deletions(-) diff --git a/bots/models.py b/bots/models.py index 3ad94cd31..c464156e4 100644 --- a/bots/models.py +++ b/bots/models.py @@ -894,9 +894,11 @@ def to_df_format( "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) .strftime("%b %d, %Y %I:%M %p"), - "Feedback": message.feedbacks.first().get_display_text() - if message.feedbacks.first() - else None, # only show first feedback as per Sean's request + "Feedback": ( + message.feedbacks.first().get_display_text() + if message.feedbacks.first() + else None + ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, } rows.append(row) diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 3880f9fe7..2686b2298 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -19,9 +19,9 @@ def general_integration_settings(bi: BotIntegration): st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( "user_language" ).default - st.session_state[ - f"_bi_show_feedback_buttons_{bi.id}" - ] = BotIntegration._meta.get_field("show_feedback_buttons").default + st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( + BotIntegration._meta.get_field("show_feedback_buttons").default + ) st.session_state[f"_bi_analysis_url_{bi.id}"] = None bi.show_feedback_buttons = st.checkbox( diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index dfc00c794..8cc1d69e1 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -389,11 +389,13 @@ def _save_msgs( role=CHATML_ROLE_USER, content=response.raw_input_text, display_content=input_text, - saved_run=SavedRun.objects.get_or_create( - workflow=Workflow.ASR, **furl(speech_run).query.params - )[0] - if speech_run - else None, + saved_run=( + SavedRun.objects.get_or_create( + workflow=Workflow.ASR, **furl(speech_run).query.params + )[0] + if speech_run + else None + ), ) attachments = [] for f_url in (input_images or []) + (input_documents or []): diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 80136cbe8..070aa1c22 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -198,12 +198,14 @@ def calc_gpt_tokens( for entry in messages if ( content := ( - format_chatml_message(entry) + "\n" - if is_chat_model - else entry.get("content", "") + ( + format_chatml_message(entry) + "\n" + if is_chat_model + else entry.get("content", "") + ) + if isinstance(entry, dict) + else str(entry) ) - if isinstance(entry, dict) - else str(entry) ) ) return default_length_function(combined) @@ -364,9 +366,11 @@ def run_language_model( else: out_content = [ # return messages back as either chatml or json messages - format_chatml_message(entry) - if is_chatml - else (entry.get("content") or "").strip() + ( + format_chatml_message(entry) + if is_chatml + else (entry.get("content") or "").strip() + ) for entry in result ] if tools: @@ -514,9 +518,11 @@ def _run_openai_chat( frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, tools=[tool.spec for tool in tools] if tools else NOT_GIVEN, - response_format={"type": response_format_type} - if response_format_type - else NOT_GIVEN, + response_format=( + {"type": response_format_type} + if response_format_type + else NOT_GIVEN + ), ) for model_str in model ], diff --git a/daras_ai_v2/language_model_settings_widgets.py b/daras_ai_v2/language_model_settings_widgets.py index e5ab27a59..80f785bd4 100644 --- a/daras_ai_v2/language_model_settings_widgets.py +++ b/daras_ai_v2/language_model_settings_widgets.py @@ -24,9 +24,9 @@ def language_model_settings(show_selector=True, show_document_model=False): f"###### {field_title_desc(VideoBotsPage.RequestModel, 'document_model')}", key="document_model", options=[None, *doc_model_descriptions], - format_func=lambda x: f"{doc_model_descriptions[x]} ({x})" - if x - else "β€”β€”β€”", + format_func=lambda x: ( + f"{doc_model_descriptions[x]} ({x})" if x else "β€”β€”β€”" + ), ) st.checkbox("Avoid Repetition", key="avoid_repetition") diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index 22c7adc8e..ce6333068 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -247,9 +247,9 @@ def instruct_pix2pix( }, inputs={ "prompt": [prompt] * len(images), - "negative_prompt": [negative_prompt] * len(images) - if negative_prompt - else None, + "negative_prompt": ( + [negative_prompt] * len(images) if negative_prompt else None + ), "num_images_per_prompt": num_outputs, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, @@ -440,9 +440,9 @@ def controlnet( pipeline={ "model_id": text2img_model_ids[Text2ImgModels[selected_model]], "seed": seed, - "scheduler": Schedulers[scheduler].label - if scheduler - else "UniPCMultistepScheduler", + "scheduler": ( + Schedulers[scheduler].label if scheduler else "UniPCMultistepScheduler" + ), "disable_safety_checker": True, "controlnet_model_id": [ controlnet_model_ids[ControlNetModels[model]] diff --git a/gooeysite/urls.py b/gooeysite/urls.py index 767660f21..6c3809436 100644 --- a/gooeysite/urls.py +++ b/gooeysite/urls.py @@ -14,6 +14,7 @@ 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.contrib import admin from django.urls import path diff --git a/pyproject.toml b/pyproject.toml index c6a699024..6726c1fce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,3 +93,6 @@ pre-commit = "^3.5.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.black] +--force-exclude = "migrations|node_modules|\\.git|\\.venv|\\.env|\\.pytest_cache|\\.vscode|\\.github|\\.to" diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 583317ddc..f1be5a937 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -36,9 +36,9 @@ class CompareLLMPage(BasePage): class RequestModel(BaseModel): input_prompt: str | None - selected_models: list[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in LargeLanguageModels)]] | None + ) avoid_repetition: bool | None num_outputs: int | None diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index e79ef5d54..5f44ca34b 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -64,9 +64,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2ImgModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2ImgModels)]] | None + ) scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None edit_instruction: str | None diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index f6f4e988b..e4ed8865f 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -24,9 +24,9 @@ class RequestModel(BaseModel): scale: int - selected_models: list[ - typing.Literal[tuple(e.name for e in UpscalerModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in UpscalerModels)]] | None + ) class ResponseModel(BaseModel): output_images: dict[typing.Literal[tuple(e.name for e in UpscalerModels)], str] diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 6c9489fff..2672fffdb 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -77,9 +77,9 @@ class RequestModel(BaseModel): task_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 2a509d8fd..d18b5a54c 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -60,9 +60,9 @@ class RequestModel(DocSearchRequest): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 4b9283cde..498f2d405 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -60,9 +60,9 @@ class RequestModel(BaseModel): task_instructions: str | None merge_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 57b1078ba..7a2d2591f 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -79,9 +79,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 543664065..751a0936b 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -49,9 +49,9 @@ class ImageSegmentationPage(BasePage): class RequestModel(BaseModel): input_image: str - selected_model: typing.Literal[ - tuple(e.name for e in ImageSegmentationModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in ImageSegmentationModels)] | None + ) mask_threshold: float | None rect_persepective_transform: bool | None diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index a97bc7c62..68e696c08 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -46,9 +46,11 @@ class RequestModel(BaseModel): text_prompt: str | None selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)] - ] | typing.Literal[tuple(e.name for e in ControlNetModels)] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)]] + | typing.Literal[tuple(e.name for e in ControlNetModels)] + | None + ) negative_prompt: str | None num_outputs: int | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 6b7260368..34b615225 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -88,18 +88,18 @@ class RequestModel(BaseModel): text_prompt: str negative_prompt: str | None image_prompt: str | None - image_prompt_controlnet_models: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + image_prompt_controlnet_models: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) image_prompt_strength: float | None image_prompt_scale: float | None image_prompt_pos_x: float | None image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) output_width: int | None output_height: int | None diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index a82424ae1..78328d822 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -98,9 +98,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): enable_html: bool | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None num_outputs: int | None diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 554a74940..0aabcfefb 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -34,9 +34,9 @@ class RequestModel(BaseModel): reflexion_prompt: str | None dera_prompt: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 082fdcdaa..1094cea53 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -40,9 +40,9 @@ class RequestModel(BaseModel): domain: str | None key_words: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 77776ddf5..18270cc8c 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -49,9 +49,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2AudioModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2AudioModels)]] | None + ) class ResponseModel(BaseModel): output_audios: dict[ @@ -114,9 +114,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ), inputs=dict( prompt=[request.text_prompt], - negative_prompt=[request.negative_prompt] - if request.negative_prompt - else None, + negative_prompt=( + [request.negative_prompt] if request.negative_prompt else None + ), num_waveforms_per_prompt=request.num_outputs, num_inference_steps=request.quality, guidance_scale=request.guidance_scale, diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index c37a8eb5f..721b2cb7a 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -53,9 +53,9 @@ class TextToSpeechPage(BasePage): class RequestModel(BaseModel): text_prompt: str - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 149931ffb..faa88d533 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -157,9 +157,9 @@ class RequestModel(BaseModel): messages: list[ConversationEntry] | None # tts settings - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None google_voice_name: str | None @@ -174,9 +174,9 @@ class RequestModel(BaseModel): elevenlabs_similarity_boost: float | None # llm settings - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) document_model: str | None = Field( title="🩻 Photo / Document Intelligence", description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " @@ -1028,10 +1028,12 @@ def messenger_bot_integration(self): favicon = Platform(bi.platform).get_favicon() with st.div(className="mt-2"): st.markdown( - f'  ' - f'{bi}' - if bi.saved_run - else f"{bi}", + ( + f'  ' + f'{bi}' + if bi.saved_run + else f"{bi}" + ), unsafe_allow_html=True, ) with col2: @@ -1082,9 +1084,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[ - f"_bi_slack_read_receipt_msg_{bi.id}" - ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default + st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + BotIntegration._meta.get_field("slack_read_receipt_msg").default + ) bi.slack_read_receipt_msg = st.text_input( """ @@ -1283,9 +1285,9 @@ def msg_container_widget(role: str): return st.div( className="px-3 py-1 pt-2", style=dict( - background="rgba(239, 239, 239, 0.6)" - if role == CHATML_ROLE_USER - else "#fff", + background=( + "rgba(239, 239, 239, 0.6)" if role == CHATML_ROLE_USER else "#fff" + ), ), ) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c94b847f2..8d58dc272 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -310,11 +310,11 @@ def parse_run_info(self, bi): run_title = ( bi.published_run.title if bi.published_run - else saved_run.page_title - if saved_run and saved_run.page_title - else "This Copilot Run" - if saved_run - else "No Run Connected" + else ( + saved_run.page_title + if saved_run and saved_run.page_title + else "This Copilot Run" if saved_run else "No Run Connected" + ) ) run_url = furl(saved_run.get_app_url()).tostr() if saved_run else "" return run_title, run_url From ffb0a619efb118c77335d1bfad6a99fc9e6a275b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 01:49:09 -0800 Subject: [PATCH 17/58] rename tables to remove "all" --- recipes/VideoBotsStats.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index b57c9f2ac..e625fc977 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -154,8 +154,8 @@ def render(self): details = st.horizontal_radio( "### Details", options=[ - "All Conversations", - "All Messages", + "Conversations", + "Messages", "Feedback Positive", "Feedback Negative", "Answered Successfully", @@ -164,7 +164,7 @@ def render(self): key="details", ) - if details == "All Conversations": + if details == "Conversations": options = [ "Messages", "Correct Answers", @@ -310,11 +310,11 @@ def parse_run_info(self, bi): run_title = ( bi.published_run.title if bi.published_run - else saved_run.page_title - if saved_run and saved_run.page_title - else "This Copilot Run" - if saved_run - else "No Run Connected" + else ( + saved_run.page_title + if saved_run and saved_run.page_title + else "This Copilot Run" if saved_run else "No Run Connected" + ) ) run_url = furl(saved_run.get_app_url()).tostr() if saved_run else "" return run_title, run_url @@ -626,9 +626,9 @@ def get_tabular_data( self, bi, run_url, conversations, messages, details, sort_by, rows=10000 ): df = pd.DataFrame() - if details == "All Conversations": + if details == "Conversations": df = conversations.to_df_format(row_limit=rows) - elif details == "All Messages": + elif details == "Messages": df = messages.order_by("-created_at", "conversation__id").to_df_format( row_limit=rows ) From d0b28f25283c590d5ea6d4c1c27c6ab94871b15b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 03:42:21 -0800 Subject: [PATCH 18/58] average response time --- recipes/VideoBotsStats.py | 45 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index e625fc977..4939944cd 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -31,7 +31,8 @@ TruncYear, Concat, ) -from django.db.models import Count +from django.db.models import Count, F, Window +from django.db.models.functions import Lag ID_COLUMNS = [ "conversation__fb_page_id", @@ -388,6 +389,33 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): def calculate_stats_binned_by_time( self, bi, start_date, end_date, factor, trunc_fn ): + average_response_time = ( + Message.objects.filter( + created_at__date__gte=start_date, + created_at__date__lte=end_date, + conversation__bot_integration=bi, + ) + .values("conversation_id") + .order_by("created_at") + .annotate( + response_time=F("created_at") - Window(expression=Lag("created_at")), + ) + .annotate(date=trunc_fn("created_at")) + .values("date", "response_time", "role") + ) + average_response_time = ( + pd.DataFrame( + average_response_time, + columns=["date", "response_time", "role"], + ) + .loc[lambda df: df["role"] == CHATML_ROLE_ASSISTANT] + .groupby("date") + .agg({"response_time": "median"}) + .apply(lambda x: x.clip(lower=timedelta(0))) + .rename(columns={"response_time": "Average_response_time"}) + .reset_index() + ) + messages_received = ( Message.objects.filter( created_at__date__gte=start_date, @@ -468,6 +496,12 @@ def calculate_stats_binned_by_time( left_on="date", right_on="date", ) + df = df.merge( + average_response_time, + how="outer", + left_on="date", + right_on="date", + ) df["Messages_Sent"] = df["Messages_Sent"] * factor df["Convos"] = df["Convos"] * factor df["Senders"] = df["Senders"] * factor @@ -476,6 +510,7 @@ def calculate_stats_binned_by_time( df["Neg_feedback"] = df["Neg_feedback"] * factor df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] + df["Average_response_time"] = df["Average_response_time"] * factor df.fillna(0, inplace=True) df = df.round(0).astype("int32", errors="ignore") return df @@ -576,6 +611,14 @@ def plot_graphs(self, view, df): text=list(df["Msgs_per_user"]), hovertemplate="Messages per User: %{y:.0f}", ), + go.Scatter( + name="Average Response Time", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Average_response_time"]), + text=list(df["Average_response_time"]), + hovertemplate="Average Response Time: %{y:.0f}", + ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), From b359f65adf93f73ac359c7e4683ccea1a36426f8 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 03:42:26 -0800 Subject: [PATCH 19/58] linting --- bots/models.py | 12 +++++-- daras_ai_v2/bot_integration_widgets.py | 6 ++-- daras_ai_v2/bots.py | 12 ++++--- daras_ai_v2/language_model.py | 28 +++++++++------ .../language_model_settings_widgets.py | 6 ++-- daras_ai_v2/stable_diffusion.py | 12 +++---- gooeysite/urls.py | 1 + pyproject.toml | 3 ++ recipes/CompareLLM.py | 6 ++-- recipes/CompareText2Img.py | 6 ++-- recipes/CompareUpscaler.py | 6 ++-- recipes/DocExtract.py | 6 ++-- recipes/DocSearch.py | 6 ++-- recipes/DocSummary.py | 6 ++-- recipes/GoogleGPT.py | 6 ++-- recipes/ImageSegmentation.py | 6 ++-- recipes/Img2Img.py | 8 +++-- recipes/QRCodeGenerator.py | 12 +++---- recipes/SEOSummary.py | 6 ++-- recipes/SmartGPT.py | 6 ++-- recipes/SocialLookupEmail.py | 6 ++-- recipes/Text2Audio.py | 12 +++---- recipes/TextToSpeech.py | 6 ++-- recipes/VideoBots.py | 34 ++++++++++--------- 24 files changed, 120 insertions(+), 98 deletions(-) diff --git a/bots/models.py b/bots/models.py index 457f3bb6a..5148a531e 100644 --- a/bots/models.py +++ b/bots/models.py @@ -913,9 +913,11 @@ def to_df_format( "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) .strftime("%b %d, %Y %I:%M %p"), - "Feedback": message.feedbacks.first().get_display_text() - if message.feedbacks.first() - else None, # only show first feedback as per Sean's request + "Feedback": ( + message.feedbacks.first().get_display_text() + if message.feedbacks.first() + else None + ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, } rows.append(row) @@ -1051,6 +1053,10 @@ def __str__(self): def local_lang(self): return Truncator(self.display_content).words(30) + @property + def response_time(self): + return self.created_at - self.get_previous_by_created_at().created_at + class MessageAttachment(models.Model): message = models.ForeignKey( diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 3880f9fe7..2686b2298 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -19,9 +19,9 @@ def general_integration_settings(bi: BotIntegration): st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( "user_language" ).default - st.session_state[ - f"_bi_show_feedback_buttons_{bi.id}" - ] = BotIntegration._meta.get_field("show_feedback_buttons").default + st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( + BotIntegration._meta.get_field("show_feedback_buttons").default + ) st.session_state[f"_bi_analysis_url_{bi.id}"] = None bi.show_feedback_buttons = st.checkbox( diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index dfc00c794..8cc1d69e1 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -389,11 +389,13 @@ def _save_msgs( role=CHATML_ROLE_USER, content=response.raw_input_text, display_content=input_text, - saved_run=SavedRun.objects.get_or_create( - workflow=Workflow.ASR, **furl(speech_run).query.params - )[0] - if speech_run - else None, + saved_run=( + SavedRun.objects.get_or_create( + workflow=Workflow.ASR, **furl(speech_run).query.params + )[0] + if speech_run + else None + ), ) attachments = [] for f_url in (input_images or []) + (input_documents or []): diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 80136cbe8..070aa1c22 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -198,12 +198,14 @@ def calc_gpt_tokens( for entry in messages if ( content := ( - format_chatml_message(entry) + "\n" - if is_chat_model - else entry.get("content", "") + ( + format_chatml_message(entry) + "\n" + if is_chat_model + else entry.get("content", "") + ) + if isinstance(entry, dict) + else str(entry) ) - if isinstance(entry, dict) - else str(entry) ) ) return default_length_function(combined) @@ -364,9 +366,11 @@ def run_language_model( else: out_content = [ # return messages back as either chatml or json messages - format_chatml_message(entry) - if is_chatml - else (entry.get("content") or "").strip() + ( + format_chatml_message(entry) + if is_chatml + else (entry.get("content") or "").strip() + ) for entry in result ] if tools: @@ -514,9 +518,11 @@ def _run_openai_chat( frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, tools=[tool.spec for tool in tools] if tools else NOT_GIVEN, - response_format={"type": response_format_type} - if response_format_type - else NOT_GIVEN, + response_format=( + {"type": response_format_type} + if response_format_type + else NOT_GIVEN + ), ) for model_str in model ], diff --git a/daras_ai_v2/language_model_settings_widgets.py b/daras_ai_v2/language_model_settings_widgets.py index e5ab27a59..80f785bd4 100644 --- a/daras_ai_v2/language_model_settings_widgets.py +++ b/daras_ai_v2/language_model_settings_widgets.py @@ -24,9 +24,9 @@ def language_model_settings(show_selector=True, show_document_model=False): f"###### {field_title_desc(VideoBotsPage.RequestModel, 'document_model')}", key="document_model", options=[None, *doc_model_descriptions], - format_func=lambda x: f"{doc_model_descriptions[x]} ({x})" - if x - else "β€”β€”β€”", + format_func=lambda x: ( + f"{doc_model_descriptions[x]} ({x})" if x else "β€”β€”β€”" + ), ) st.checkbox("Avoid Repetition", key="avoid_repetition") diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index 22c7adc8e..ce6333068 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -247,9 +247,9 @@ def instruct_pix2pix( }, inputs={ "prompt": [prompt] * len(images), - "negative_prompt": [negative_prompt] * len(images) - if negative_prompt - else None, + "negative_prompt": ( + [negative_prompt] * len(images) if negative_prompt else None + ), "num_images_per_prompt": num_outputs, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, @@ -440,9 +440,9 @@ def controlnet( pipeline={ "model_id": text2img_model_ids[Text2ImgModels[selected_model]], "seed": seed, - "scheduler": Schedulers[scheduler].label - if scheduler - else "UniPCMultistepScheduler", + "scheduler": ( + Schedulers[scheduler].label if scheduler else "UniPCMultistepScheduler" + ), "disable_safety_checker": True, "controlnet_model_id": [ controlnet_model_ids[ControlNetModels[model]] diff --git a/gooeysite/urls.py b/gooeysite/urls.py index 767660f21..6c3809436 100644 --- a/gooeysite/urls.py +++ b/gooeysite/urls.py @@ -14,6 +14,7 @@ 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ + from django.contrib import admin from django.urls import path diff --git a/pyproject.toml b/pyproject.toml index c6a699024..01ccdd04e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,3 +93,6 @@ pre-commit = "^3.5.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.black] +force-exclude = "migrations" diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py index 583317ddc..f1be5a937 100644 --- a/recipes/CompareLLM.py +++ b/recipes/CompareLLM.py @@ -36,9 +36,9 @@ class CompareLLMPage(BasePage): class RequestModel(BaseModel): input_prompt: str | None - selected_models: list[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in LargeLanguageModels)]] | None + ) avoid_repetition: bool | None num_outputs: int | None diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index e79ef5d54..5f44ca34b 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -64,9 +64,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2ImgModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2ImgModels)]] | None + ) scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None edit_instruction: str | None diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index f6f4e988b..e4ed8865f 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -24,9 +24,9 @@ class RequestModel(BaseModel): scale: int - selected_models: list[ - typing.Literal[tuple(e.name for e in UpscalerModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in UpscalerModels)]] | None + ) class ResponseModel(BaseModel): output_images: dict[typing.Literal[tuple(e.name for e in UpscalerModels)], str] diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 6c9489fff..2672fffdb 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -77,9 +77,9 @@ class RequestModel(BaseModel): task_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 2a509d8fd..d18b5a54c 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -60,9 +60,9 @@ class RequestModel(DocSearchRequest): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 4b9283cde..498f2d405 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -60,9 +60,9 @@ class RequestModel(BaseModel): task_instructions: str | None merge_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 57b1078ba..7a2d2591f 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -79,9 +79,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): task_instructions: str | None query_instructions: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 543664065..751a0936b 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -49,9 +49,9 @@ class ImageSegmentationPage(BasePage): class RequestModel(BaseModel): input_image: str - selected_model: typing.Literal[ - tuple(e.name for e in ImageSegmentationModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in ImageSegmentationModels)] | None + ) mask_threshold: float | None rect_persepective_transform: bool | None diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index a97bc7c62..68e696c08 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -46,9 +46,11 @@ class RequestModel(BaseModel): text_prompt: str | None selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)] - ] | typing.Literal[tuple(e.name for e in ControlNetModels)] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)]] + | typing.Literal[tuple(e.name for e in ControlNetModels)] + | None + ) negative_prompt: str | None num_outputs: int | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 6b7260368..34b615225 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -88,18 +88,18 @@ class RequestModel(BaseModel): text_prompt: str negative_prompt: str | None image_prompt: str | None - image_prompt_controlnet_models: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + image_prompt_controlnet_models: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) image_prompt_strength: float | None image_prompt_scale: float | None image_prompt_pos_x: float | None image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None - selected_controlnet_model: list[ - typing.Literal[tuple(e.name for e in ControlNetModels)], ... - ] | None + selected_controlnet_model: ( + list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None + ) output_width: int | None output_height: int | None diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index a82424ae1..78328d822 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -98,9 +98,9 @@ class RequestModel(GoogleSearchMixin, BaseModel): enable_html: bool | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None num_outputs: int | None diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 554a74940..0aabcfefb 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -34,9 +34,9 @@ class RequestModel(BaseModel): reflexion_prompt: str | None dera_prompt: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) avoid_repetition: bool | None num_outputs: int | None quality: float | None diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 082fdcdaa..1094cea53 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -40,9 +40,9 @@ class RequestModel(BaseModel): domain: str | None key_words: str | None - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) sampling_temperature: float | None max_tokens: int | None diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 77776ddf5..18270cc8c 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -49,9 +49,9 @@ class RequestModel(BaseModel): seed: int | None sd_2_upscaling: bool | None - selected_models: list[ - typing.Literal[tuple(e.name for e in Text2AudioModels)] - ] | None + selected_models: ( + list[typing.Literal[tuple(e.name for e in Text2AudioModels)]] | None + ) class ResponseModel(BaseModel): output_audios: dict[ @@ -114,9 +114,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ), inputs=dict( prompt=[request.text_prompt], - negative_prompt=[request.negative_prompt] - if request.negative_prompt - else None, + negative_prompt=( + [request.negative_prompt] if request.negative_prompt else None + ), num_waveforms_per_prompt=request.num_outputs, num_inference_steps=request.quality, guidance_scale=request.guidance_scale, diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index c37a8eb5f..721b2cb7a 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -53,9 +53,9 @@ class TextToSpeechPage(BasePage): class RequestModel(BaseModel): text_prompt: str - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 149931ffb..faa88d533 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -157,9 +157,9 @@ class RequestModel(BaseModel): messages: list[ConversationEntry] | None # tts settings - tts_provider: typing.Literal[ - tuple(e.name for e in TextToSpeechProviders) - ] | None + tts_provider: ( + typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None + ) uberduck_voice_name: str | None uberduck_speaking_rate: float | None google_voice_name: str | None @@ -174,9 +174,9 @@ class RequestModel(BaseModel): elevenlabs_similarity_boost: float | None # llm settings - selected_model: typing.Literal[ - tuple(e.name for e in LargeLanguageModels) - ] | None + selected_model: ( + typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None + ) document_model: str | None = Field( title="🩻 Photo / Document Intelligence", description="When your copilot users upload a photo or pdf, what kind of document are they mostly likely to upload? " @@ -1028,10 +1028,12 @@ def messenger_bot_integration(self): favicon = Platform(bi.platform).get_favicon() with st.div(className="mt-2"): st.markdown( - f'  ' - f'{bi}' - if bi.saved_run - else f"{bi}", + ( + f'  ' + f'{bi}' + if bi.saved_run + else f"{bi}" + ), unsafe_allow_html=True, ) with col2: @@ -1082,9 +1084,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[ - f"_bi_slack_read_receipt_msg_{bi.id}" - ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default + st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + BotIntegration._meta.get_field("slack_read_receipt_msg").default + ) bi.slack_read_receipt_msg = st.text_input( """ @@ -1283,9 +1285,9 @@ def msg_container_widget(role: str): return st.div( className="px-3 py-1 pt-2", style=dict( - background="rgba(239, 239, 239, 0.6)" - if role == CHATML_ROLE_USER - else "#fff", + background=( + "rgba(239, 239, 239, 0.6)" if role == CHATML_ROLE_USER else "#fff" + ), ), ) From e00411135e7aeaaf5b9897c3ef1e614a1b174dd4 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 03:51:42 -0800 Subject: [PATCH 20/58] add response time to table --- bots/models.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/bots/models.py b/bots/models.py index 5148a531e..61c56abc3 100644 --- a/bots/models.py +++ b/bots/models.py @@ -919,6 +919,7 @@ def to_df_format( else None ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, + "Response Time": message.response_time.total_seconds(), } rows.append(row) df = pd.DataFrame.from_records( @@ -930,6 +931,7 @@ def to_df_format( "Sent", "Feedback", "Analysis JSON", + "Response Time", ], ) return df @@ -1055,7 +1057,20 @@ def local_lang(self): @property def response_time(self): - return self.created_at - self.get_previous_by_created_at().created_at + import pandas as pd + + if self.role == CHATML_ROLE_USER: + return pd.NaT + return ( + self.created_at + - Message.objects.filter( + conversation=self.conversation, + role=CHATML_ROLE_USER, + created_at__lt=self.created_at, + ) + .latest() + .created_at + ) class MessageAttachment(models.Model): From 38badb99343156997dcff01725fd2e6fde16df38 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 10:36:33 -0800 Subject: [PATCH 21/58] recursive folder expansion --- daras_ai_v2/gdrive_downloader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 339a634a3..ee1f7453c 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -25,7 +25,7 @@ def url_to_gdrive_file_id(f: furl) -> str: return file_id -def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: +def gdrive_list_urls_of_files_in_folder(f: furl, max_depth=10) -> list[str]: # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] service = discovery.build("drive", "v3") @@ -37,7 +37,7 @@ def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: fields="files(mimeType,webViewLink)", ) else: - raise ValueError(f"Can't list files non google folder url: {str(f)!r}") + raise ValueError(f"Can't list files from non google folder url: {str(f)!r}") response = request.execute() files = response.get("files", []) urls = [] @@ -45,7 +45,9 @@ def gdrive_list_urls_of_files_in_folder(f: furl) -> list[str]: mime_type = file.get("mimeType") url = file.get("webViewLink") if mime_type == "application/vnd.google-apps.folder": - continue + urls += gdrive_list_urls_of_files_in_folder( + furl(url), max_depth=max_depth - 1 + ) elif url: urls.append(url) return urls From 1c8179cc09247a3437b692127f9103766d5d1b26 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Mon, 29 Jan 2024 10:37:10 -0800 Subject: [PATCH 22/58] stop recursing at max depth lol --- daras_ai_v2/gdrive_downloader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index ee1f7453c..345099aba 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -26,6 +26,8 @@ def url_to_gdrive_file_id(f: furl) -> str: def gdrive_list_urls_of_files_in_folder(f: furl, max_depth=10) -> list[str]: + if max_depth <= 0: + return [] # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] service = discovery.build("drive", "v3") From 91e7810a8803c500a152d694c5687ff4d37f8e34 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 12:07:40 -0800 Subject: [PATCH 23/58] SavedRun run_time attempt --- bots/models.py | 29 +++++++++++++++-------------- recipes/VideoBotsStats.py | 39 ++++----------------------------------- 2 files changed, 19 insertions(+), 49 deletions(-) diff --git a/bots/models.py b/bots/models.py index 0714f187b..204a625bd 100644 --- a/bots/models.py +++ b/bots/models.py @@ -1065,20 +1065,21 @@ def local_lang(self): @property def response_time(self): - import pandas as pd - - if self.role == CHATML_ROLE_USER: - return pd.NaT - return ( - self.created_at - - Message.objects.filter( - conversation=self.conversation, - role=CHATML_ROLE_USER, - created_at__lt=self.created_at, - ) - .latest() - .created_at - ) + return self.saved_run.run_time + # import pandas as pd + + # if self.role == CHATML_ROLE_USER: + # return pd.NaT + # return ( + # self.created_at + # - Message.objects.filter( + # conversation=self.conversation, + # role=CHATML_ROLE_USER, + # created_at__lt=self.created_at, + # ) + # .latest() + # .created_at + # ) class MessageAttachment(models.Model): diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 4939944cd..c82989c02 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -31,8 +31,7 @@ TruncYear, Concat, ) -from django.db.models import Count, F, Window -from django.db.models.functions import Lag +from django.db.models import Count, Avg ID_COLUMNS = [ "conversation__fb_page_id", @@ -389,33 +388,6 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): def calculate_stats_binned_by_time( self, bi, start_date, end_date, factor, trunc_fn ): - average_response_time = ( - Message.objects.filter( - created_at__date__gte=start_date, - created_at__date__lte=end_date, - conversation__bot_integration=bi, - ) - .values("conversation_id") - .order_by("created_at") - .annotate( - response_time=F("created_at") - Window(expression=Lag("created_at")), - ) - .annotate(date=trunc_fn("created_at")) - .values("date", "response_time", "role") - ) - average_response_time = ( - pd.DataFrame( - average_response_time, - columns=["date", "response_time", "role"], - ) - .loc[lambda df: df["role"] == CHATML_ROLE_ASSISTANT] - .groupby("date") - .agg({"response_time": "median"}) - .apply(lambda x: x.clip(lower=timedelta(0))) - .rename(columns={"response_time": "Average_response_time"}) - .reset_index() - ) - messages_received = ( Message.objects.filter( created_at__date__gte=start_date, @@ -436,6 +408,7 @@ def calculate_stats_binned_by_time( distinct=True, ) ) + .annotate(Average_response_time=Avg("saved_run__run_time")) .annotate(Unique_feedback_givers=Count("feedbacks", distinct=True)) .values( "date", @@ -443,6 +416,7 @@ def calculate_stats_binned_by_time( "Convos", "Senders", "Unique_feedback_givers", + "Average_response_time", ) ) @@ -482,6 +456,7 @@ def calculate_stats_binned_by_time( "Convos", "Senders", "Unique_feedback_givers", + "Average_response_time", ], ) df = df.merge( @@ -496,12 +471,6 @@ def calculate_stats_binned_by_time( left_on="date", right_on="date", ) - df = df.merge( - average_response_time, - how="outer", - left_on="date", - right_on="date", - ) df["Messages_Sent"] = df["Messages_Sent"] * factor df["Convos"] = df["Convos"] * factor df["Senders"] = df["Senders"] * factor From 344f6ba493e79a5ee999572186e9f87a12d4c0fc Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 13:01:09 -0800 Subject: [PATCH 24/58] response time collection --- bots/migrations/0056_message_response_time.py | 19 +++++++++++++++ bots/models.py | 23 ++++--------------- daras_ai_v2/bots.py | 12 ++++++++++ recipes/VideoBotsStats.py | 2 +- 4 files changed, 37 insertions(+), 19 deletions(-) create mode 100644 bots/migrations/0056_message_response_time.py diff --git a/bots/migrations/0056_message_response_time.py b/bots/migrations/0056_message_response_time.py new file mode 100644 index 000000000..06730e931 --- /dev/null +++ b/bots/migrations/0056_message_response_time.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.7 on 2024-02-01 20:15 + +import datetime +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0055_workflowmetadata'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='response_time', + field=models.DurationField(default=datetime.timedelta(days=-1, seconds=86399), help_text='The time it took for the bot to respond to the corresponding user message'), + ), + ] diff --git a/bots/models.py b/bots/models.py index 204a625bd..22008019f 100644 --- a/bots/models.py +++ b/bots/models.py @@ -1048,6 +1048,11 @@ class Message(models.Model): help_text="Subject of given question (DEPRECATED)", ) + response_time = models.DurationField( + default=datetime.timedelta(seconds=-1), + help_text="The time it took for the bot to respond to the corresponding user message", + ) + _analysis_started = False objects = MessageQuerySet.as_manager() @@ -1063,24 +1068,6 @@ def __str__(self): def local_lang(self): return Truncator(self.display_content).words(30) - @property - def response_time(self): - return self.saved_run.run_time - # import pandas as pd - - # if self.role == CHATML_ROLE_USER: - # return pd.NaT - # return ( - # self.created_at - # - Message.objects.filter( - # conversation=self.conversation, - # role=CHATML_ROLE_USER, - # created_at__lt=self.created_at, - # ) - # .latest() - # .created_at - # ) - class MessageAttachment(models.Model): message = models.ForeignKey( diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 8cc1d69e1..324affbb7 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -3,11 +3,14 @@ import typing from urllib.parse import parse_qs +import pytz +from datetime import datetime from django.db import transaction from fastapi import HTTPException, Request from furl import furl from sentry_sdk import capture_exception +from daras_ai_v2 import settings from app_users.models import AppUser from bots.models import ( Platform, @@ -202,6 +205,7 @@ def _on_msg(bot: BotInterface): speech_run = None input_images = None input_documents = None + recieved_time: datetime = datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) if not bot.page_cls: bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR) return @@ -286,6 +290,7 @@ def _on_msg(bot: BotInterface): input_documents=input_documents, input_text=input_text, speech_run=speech_run, + recieved_time=recieved_time, ) @@ -324,6 +329,7 @@ def _process_and_send_msg( input_images: list[str] | None, input_documents: list[str] | None, input_text: str, + recieved_time: datetime, speech_run: str | None, ): try: @@ -369,6 +375,7 @@ def _process_and_send_msg( platform_msg_id=msg_id, response=response, url=url, + received_time=recieved_time, ) @@ -381,6 +388,7 @@ def _save_msgs( platform_msg_id: str | None, response: VideoBotsPage.ResponseModel, url: str, + received_time: datetime, ): # create messages for future context user_msg = Message( @@ -396,6 +404,8 @@ def _save_msgs( if speech_run else None ), + response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) + - received_time, ) attachments = [] for f_url in (input_images or []) + (input_documents or []): @@ -412,6 +422,8 @@ def _save_msgs( saved_run=SavedRun.objects.get_or_create( workflow=Workflow.VIDEO_BOTS, **furl(url).query.params )[0], + response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) + - received_time, ) # save the messages & attachments with transaction.atomic(): diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c82989c02..a265f5a92 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -408,7 +408,7 @@ def calculate_stats_binned_by_time( distinct=True, ) ) - .annotate(Average_response_time=Avg("saved_run__run_time")) + .annotate(Average_response_time=Avg("response_time")) .annotate(Unique_feedback_givers=Count("feedbacks", distinct=True)) .values( "date", From fb9d32cebe0bf3d441acc966fdcd99088e108320 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 13:26:34 -0800 Subject: [PATCH 25/58] enable show all --- recipes/VideoBotsStats.py | 51 ++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index a265f5a92..12eabf51f 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -136,7 +136,7 @@ def render(self): view, factor, trunc_fn, - ) = self.render_date_view_inputs() + ) = self.render_date_view_inputs(bi) df = self.calculate_stats_binned_by_time( bi, start_date, end_date, factor, trunc_fn @@ -260,27 +260,34 @@ def render(self): ).tostr() st.change_url(new_url, self.request) - def render_date_view_inputs(self): - start_of_year_date = datetime.now().replace(month=1, day=1) - st.session_state.setdefault( - "start_date", - self.request.query_params.get( - "start_date", start_of_year_date.strftime("%Y-%m-%d") - ), - ) - start_date: datetime = ( - st.date_input("Start date", key="start_date") or start_of_year_date - ) - st.session_state.setdefault( - "end_date", - self.request.query_params.get( - "end_date", datetime.now().strftime("%Y-%m-%d") - ), - ) - end_date: datetime = st.date_input("End date", key="end_date") or datetime.now() - st.session_state.setdefault( - "view", self.request.query_params.get("view", "Weekly") - ) + def render_date_view_inputs(self, bi): + if st.checkbox("Show All"): + start_date = bi.created_at + end_date = datetime.now() + else: + start_of_year_date = datetime.now().replace(month=1, day=1) + st.session_state.setdefault( + "start_date", + self.request.query_params.get( + "start_date", start_of_year_date.strftime("%Y-%m-%d") + ), + ) + start_date: datetime = ( + st.date_input("Start date", key="start_date") or start_of_year_date + ) + st.session_state.setdefault( + "end_date", + self.request.query_params.get( + "end_date", datetime.now().strftime("%Y-%m-%d") + ), + ) + end_date: datetime = ( + st.date_input("End date", key="end_date") or datetime.now() + ) + st.session_state.setdefault( + "view", self.request.query_params.get("view", "Weekly") + ) + st.write("---") view = st.horizontal_radio( "### View", options=["Daily", "Weekly", "Monthly"], From 570963c1e99dfedbb6a90192be4c1d601450bb9a Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 13:53:49 -0800 Subject: [PATCH 26/58] tables adhere to data filters --- recipes/VideoBotsStats.py | 60 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 12eabf51f..f41add98f 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -210,7 +210,15 @@ def render(self): sort_by = st.session_state["sort_by"] df = self.get_tabular_data( - bi, run_url, conversations, messages, details, sort_by, rows=500 + bi, + run_url, + conversations, + messages, + details, + sort_by, + rows=500, + start_date=start_date, + end_date=end_date, ) if not df.empty: @@ -233,7 +241,14 @@ def render(self): st.html("
") if st.checkbox("Export"): df = self.get_tabular_data( - bi, run_url, conversations, messages, details, sort_by + bi, + run_url, + conversations, + messages, + details, + sort_by, + start_date=start_date, + end_date=end_date, ) csv = df.to_csv() b64 = base64.b64encode(csv.encode()).decode() @@ -642,12 +657,29 @@ def plot_graphs(self, view, df): st.plotly_chart(fig) def get_tabular_data( - self, bi, run_url, conversations, messages, details, sort_by, rows=10000 + self, + bi, + run_url, + conversations, + messages, + details, + sort_by, + rows=10000, + start_date=None, + end_date=None, ): df = pd.DataFrame() if details == "Conversations": + if start_date and end_date: + conversations = conversations.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = conversations.to_df_format(row_limit=rows) elif details == "Messages": + if start_date and end_date: + messages = messages.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = messages.order_by("-created_at", "conversation__id").to_df_format( row_limit=rows ) @@ -658,6 +690,10 @@ def get_tabular_data( message__conversation__bot_integration=bi, rating=Feedback.Rating.RATING_THUMBS_UP, ) # type: ignore + if start_date and end_date: + pos_feedbacks = pos_feedbacks.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = pos_feedbacks.to_df_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name @@ -665,7 +701,13 @@ def get_tabular_data( neg_feedbacks: FeedbackQuerySet = Feedback.objects.filter( message__conversation__bot_integration=bi, rating=Feedback.Rating.RATING_THUMBS_DOWN, + created_at__date__gte=start_date, + created_at__date__lte=end_date, ) # type: ignore + if start_date and end_date: + neg_feedbacks = neg_feedbacks.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = neg_feedbacks.to_df_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name @@ -673,7 +715,13 @@ def get_tabular_data( successful_messages: MessageQuerySet = Message.objects.filter( conversation__bot_integration=bi, analysis_result__contains={"Answered": True}, + created_at__date__gte=start_date, + created_at__date__lte=end_date, ) # type: ignore + if start_date and end_date: + successful_messages = successful_messages.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = successful_messages.to_df_analysis_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name @@ -681,7 +729,13 @@ def get_tabular_data( unsuccessful_messages: MessageQuerySet = Message.objects.filter( conversation__bot_integration=bi, analysis_result__contains={"Answered": False}, + created_at__date__gte=start_date, + created_at__date__lte=end_date, ) # type: ignore + if start_date and end_date: + unsuccessful_messages = unsuccessful_messages.filter( + created_at__date__gte=start_date, created_at__date__lte=end_date + ) df = unsuccessful_messages.to_df_analysis_format(row_limit=rows) df["Run URL"] = run_url df["Bot"] = bi.name From 747cead27e41e1f683c4b9c658d5e304208c0bf9 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 14:00:15 -0800 Subject: [PATCH 27/58] round response time --- bots/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bots/models.py b/bots/models.py index 22008019f..edb7e183a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -927,7 +927,7 @@ def to_df_format( else None ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, - "Response Time": message.response_time.total_seconds(), + "Response Time": round(message.response_time.total_seconds(), 1), } rows.append(row) df = pd.DataFrame.from_records( From 819cf1f80fc7fe8fd5ddd8bdaa91051a4224b39b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 22:02:03 -0800 Subject: [PATCH 28/58] split of graphs --- recipes/VideoBotsStats.py | 80 +++++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 8 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index f41add98f..08787f1b3 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -499,6 +499,12 @@ def calculate_stats_binned_by_time( df["Unique_feedback_givers"] = df["Unique_feedback_givers"] * factor df["Pos_feedback"] = df["Pos_feedback"] * factor df["Neg_feedback"] = df["Neg_feedback"] * factor + df["Percentage_positive_feedback"] = ( + df["Pos_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + ) * 100 + df["Percentage_negative_feedback"] = ( + df["Neg_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] df["Average_response_time"] = df["Average_response_time"] * factor @@ -602,14 +608,6 @@ def plot_graphs(self, view, df): text=list(df["Msgs_per_user"]), hovertemplate="Messages per User: %{y:.0f}", ), - go.Scatter( - name="Average Response Time", - mode="lines+markers", - x=list(df["date"]), - y=list(df["Average_response_time"]), - text=list(df["Average_response_time"]), - hovertemplate="Average Response Time: %{y:.0f}", - ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), @@ -655,6 +653,72 @@ def plot_graphs(self, view, df): ], ) st.plotly_chart(fig) + st.markdown("
", unsafe_allow_html=True) + fig = go.Figure( + data=[ + go.Scatter( + name="Positive Feedback", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_positive_feedback"]), + text=list(df["Percentage_positive_feedback"]), + hovertemplate="Positive Feedback: %{y:.0f}\\%", + ), + go.Scatter( + name="Negative Feedback", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_negative_feedback"]), + text=list(df["Percentage_negative_feedback"]), + hovertemplate="Negative Feedback: %{y:.0f}\\%", + ), + ], + layout=dict( + margin=dict(l=0, r=0, t=28, b=0), + yaxis=dict( + title="Percentage", + range=[0, 100], + tickvals=[ + *range( + 0, + 101, + 10, + ) + ], + ), + title=dict( + text=f"{view} Feedback Distribution", + ), + height=300, + template="plotly_white", + ), + ) + st.plotly_chart(fig) + st.write("---") + fig = go.Figure( + data=[ + go.Scatter( + name="Average Response Time", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Average_response_time"]), + text=list(df["Average_response_time"]), + hovertemplate="Average Response Time: %{y:.0f}", + ), + ], + layout=dict( + margin=dict(l=0, r=0, t=28, b=0), + yaxis=dict( + title="Seconds", + ), + title=dict( + text=f"{view} Performance Metrics", + ), + height=300, + template="plotly_white", + ), + ) + st.plotly_chart(fig) def get_tabular_data( self, From 2a5d31969c86a07e0a6180c5c75829c268253e8d Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Thu, 1 Feb 2024 22:50:30 -0800 Subject: [PATCH 29/58] sean tweaks and reordering of graphs --- recipes/VideoBotsStats.py | 134 ++++++++++++++++++++++++++++---------- 1 file changed, 99 insertions(+), 35 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 08787f1b3..718b580f1 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -119,6 +119,7 @@ def render(self): if int(bid) not in allowed_bids: bid = allowed_bids[0] bi = BotIntegration.objects.get(id=bid) + has_analysis_run = bi.analysis_run is not None run_title, run_url = self.parse_run_info(bi) self.show_title_breadcrumb_share(run_title, run_url, bi) @@ -153,14 +154,22 @@ def render(self): st.session_state.setdefault("details", self.request.query_params.get("details")) details = st.horizontal_radio( "### Details", - options=[ - "Conversations", - "Messages", - "Feedback Positive", - "Feedback Negative", - "Answered Successfully", - "Answered Unsuccessfully", - ], + options=( + [ + "Conversations", + "Messages", + "Feedback Positive", + "Feedback Negative", + ] + + ( + [ + "Answered Successfully", + "Answered Unsuccessfully", + ] + if has_analysis_run + else [] + ) + ), key="details", ) @@ -430,7 +439,9 @@ def calculate_stats_binned_by_time( distinct=True, ) ) + .annotate(Average_runtime=Avg("saved_run__run_time")) .annotate(Average_response_time=Avg("response_time")) + .annotate(Average_analysis_time=Avg("analysis_run__run_time")) .annotate(Unique_feedback_givers=Count("feedbacks", distinct=True)) .values( "date", @@ -439,6 +450,8 @@ def calculate_stats_binned_by_time( "Senders", "Unique_feedback_givers", "Average_response_time", + "Average_runtime", + "Average_analysis_time", ) ) @@ -470,6 +483,20 @@ def calculate_stats_binned_by_time( .values("date", "Neg_feedback") ) + successfully_answered = ( + Message.objects.filter( + conversation__bot_integration=bi, + analysis_result__contains={"Answered": True}, + created_at__date__gte=start_date, + created_at__date__lte=end_date, + ) + .order_by() + .annotate(date=trunc_fn("created_at")) + .values("date") + .annotate(Successfully_Answered=Count("id")) + .values("date", "Successfully_Answered") + ) + df = pd.DataFrame( messages_received, columns=[ @@ -479,6 +506,8 @@ def calculate_stats_binned_by_time( "Senders", "Unique_feedback_givers", "Average_response_time", + "Average_runtime", + "Average_analysis_time", ], ) df = df.merge( @@ -493,6 +522,14 @@ def calculate_stats_binned_by_time( left_on="date", right_on="date", ) + df = df.merge( + pd.DataFrame( + successfully_answered, columns=["date", "Successfully_Answered"] + ), + how="outer", + left_on="date", + right_on="date", + ) df["Messages_Sent"] = df["Messages_Sent"] * factor df["Convos"] = df["Convos"] * factor df["Senders"] = df["Senders"] * factor @@ -500,10 +537,13 @@ def calculate_stats_binned_by_time( df["Pos_feedback"] = df["Pos_feedback"] * factor df["Neg_feedback"] = df["Neg_feedback"] * factor df["Percentage_positive_feedback"] = ( - df["Pos_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + df["Pos_feedback"] / df["Messages_Sent"] ) * 100 df["Percentage_negative_feedback"] = ( - df["Neg_feedback"] / (df["Pos_feedback"] + df["Neg_feedback"]) + df["Neg_feedback"] / df["Messages_Sent"] + ) * 100 + df["Percentage_successfully_answered"] = ( + df["Successfully_Answered"] / df["Messages_Sent"] ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] @@ -653,41 +693,41 @@ def plot_graphs(self, view, df): ], ) st.plotly_chart(fig) - st.markdown("
", unsafe_allow_html=True) + st.write("---") fig = go.Figure( data=[ go.Scatter( - name="Positive Feedback", + name="Average Response Time", mode="lines+markers", x=list(df["date"]), - y=list(df["Percentage_positive_feedback"]), - text=list(df["Percentage_positive_feedback"]), - hovertemplate="Positive Feedback: %{y:.0f}\\%", + y=list(df["Average_response_time"]), + text=list(df["Average_response_time"]), + hovertemplate="Average Response Time: %{y:.0f}", ), go.Scatter( - name="Negative Feedback", + name="Average Run Time", mode="lines+markers", x=list(df["date"]), - y=list(df["Percentage_negative_feedback"]), - text=list(df["Percentage_negative_feedback"]), - hovertemplate="Negative Feedback: %{y:.0f}\\%", + y=list(df["Average_runtime"]), + text=list(df["Average_runtime"]), + hovertemplate="Average Runtime: %{y:.0f}", + ), + go.Scatter( + name="Average Analysis Time", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Average_analysis_time"]), + text=list(df["Average_analysis_time"]), + hovertemplate="Average Analysis Time: %{y:.0f}", ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), yaxis=dict( - title="Percentage", - range=[0, 100], - tickvals=[ - *range( - 0, - 101, - 10, - ) - ], + title="Seconds", ), title=dict( - text=f"{view} Feedback Distribution", + text=f"{view} Performance Metrics", ), height=300, template="plotly_white", @@ -698,21 +738,45 @@ def plot_graphs(self, view, df): fig = go.Figure( data=[ go.Scatter( - name="Average Response Time", + name="Positive Feedback", mode="lines+markers", x=list(df["date"]), - y=list(df["Average_response_time"]), - text=list(df["Average_response_time"]), - hovertemplate="Average Response Time: %{y:.0f}", + y=list(df["Percentage_positive_feedback"]), + text=list(df["Percentage_positive_feedback"]), + hovertemplate="Positive Feedback: %{y:.0f}%", + ), + go.Scatter( + name="Negative Feedback", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_negative_feedback"]), + text=list(df["Percentage_negative_feedback"]), + hovertemplate="Negative Feedback: %{y:.0f}%", + ), + go.Scatter( + name="Successfully Answered", + mode="lines+markers", + x=list(df["date"]), + y=list(df["Percentage_successfully_answered"]), + text=list(df["Percentage_successfully_answered"]), + hovertemplate="Successfully Answered: %{y:.0f}%", ), ], layout=dict( margin=dict(l=0, r=0, t=28, b=0), yaxis=dict( - title="Seconds", + title="Percentage", + range=[0, 100], + tickvals=[ + *range( + 0, + 101, + 10, + ) + ], ), title=dict( - text=f"{view} Performance Metrics", + text=f"{view} Feedback Distribution", ), height=300, template="plotly_white", From d2da2d10d84c6ca7e07e964286333cb7847f39dc Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 20:33:57 +0530 Subject: [PATCH 30/58] use timezone.now() --- daras_ai_v2/bots.py | 13 +++++-------- recipes/VideoBots.py | 6 +++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index edaf0ed32..9f2ebd555 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -1,16 +1,15 @@ import mimetypes import traceback import typing +from datetime import datetime from urllib.parse import parse_qs -import pytz -from datetime import datetime from django.db import transaction +from django.utils import timezone from fastapi import HTTPException, Request from furl import furl from sentry_sdk import capture_exception -from daras_ai_v2 import settings from app_users.models import AppUser from bots.models import ( Platform, @@ -191,7 +190,7 @@ def _on_msg(bot: BotInterface): speech_run = None input_images = None input_documents = None - recieved_time: datetime = datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) + recieved_time: datetime = timezone.now() if not bot.page_cls: bot.send_msg(text=PAGE_NOT_CONNECTED_ERROR) return @@ -463,8 +462,7 @@ def _save_msgs( if speech_run else None ), - response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) - - received_time, + response_time=timezone.now() - received_time, ) attachments = [] for f_url in (input_images or []) + (input_documents or []): @@ -481,8 +479,7 @@ def _save_msgs( saved_run=SavedRun.objects.get_or_create( workflow=Workflow.VIDEO_BOTS, **furl(url).query.params )[0], - response_time=datetime.now(tz=pytz.timezone(settings.TIME_ZONE)) - - received_time, + response_time=timezone.now() - received_time, ) # save the messages & attachments with transaction.atomic(): diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 5d23502ff..75615e6a7 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1124,9 +1124,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( - BotIntegration._meta.get_field("slack_read_receipt_msg").default - ) + st.session_state[ + f"_bi_slack_read_receipt_msg_{bi.id}" + ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default bi.slack_read_receipt_msg = st.text_input( """ From b063439037904c706e4101db916f1b68e9fbc366 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 20:42:30 +0530 Subject: [PATCH 31/58] make response_time null by default --- bots/migrations/0056_message_response_time.py | 19 --------------- .../0057_message_response_time_and_more.py | 23 +++++++++++++++++++ bots/models.py | 8 +++++-- 3 files changed, 29 insertions(+), 21 deletions(-) delete mode 100644 bots/migrations/0056_message_response_time.py create mode 100644 bots/migrations/0057_message_response_time_and_more.py diff --git a/bots/migrations/0056_message_response_time.py b/bots/migrations/0056_message_response_time.py deleted file mode 100644 index 06730e931..000000000 --- a/bots/migrations/0056_message_response_time.py +++ /dev/null @@ -1,19 +0,0 @@ -# Generated by Django 4.2.7 on 2024-02-01 20:15 - -import datetime -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('bots', '0055_workflowmetadata'), - ] - - operations = [ - migrations.AddField( - model_name='message', - name='response_time', - field=models.DurationField(default=datetime.timedelta(days=-1, seconds=86399), help_text='The time it took for the bot to respond to the corresponding user message'), - ), - ] diff --git a/bots/migrations/0057_message_response_time_and_more.py b/bots/migrations/0057_message_response_time_and_more.py new file mode 100644 index 000000000..5bdecd39c --- /dev/null +++ b/bots/migrations/0057_message_response_time_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.7 on 2024-02-05 15:11 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0056_botintegration_streaming_enabled'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='response_time', + field=models.DurationField(default=None, help_text='The time it took for the bot to respond to the corresponding user message', null=True), + ), + migrations.AlterField( + model_name='botintegration', + name='streaming_enabled', + field=models.BooleanField(default=False, help_text='If set, the bot will stream messages to the frontend (Slack only)'), + ), + ] diff --git a/bots/models.py b/bots/models.py index 4aa2c3d4a..8e2a41dc8 100644 --- a/bots/models.py +++ b/bots/models.py @@ -932,7 +932,10 @@ def to_df_format( else None ), # only show first feedback as per Sean's request "Analysis JSON": message.analysis_result, - "Response Time": round(message.response_time.total_seconds(), 1), + "Response Time": ( + message.response_time + and round(message.response_time.total_seconds(), 1) + ), } rows.append(row) df = pd.DataFrame.from_records( @@ -1054,7 +1057,8 @@ class Message(models.Model): ) response_time = models.DurationField( - default=datetime.timedelta(seconds=-1), + default=None, + null=True, help_text="The time it took for the bot to respond to the corresponding user message", ) From f5466022b9eb943cf5d89bbb72f503b9a22b8df3 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 20:56:38 +0530 Subject: [PATCH 32/58] fix run_time display in copilot stats add resposne_time to admin --- bots/admin.py | 2 ++ recipes/VideoBotsStats.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index 4f69a1081..0b6b475cc 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -492,6 +492,7 @@ class MessageAdmin(admin.ModelAdmin): "prev_msg_content", "prev_msg_display_content", "prev_msg_saved_run", + "response_time", ] ordering = ["created_at"] actions = [export_to_csv, export_to_excel] @@ -551,6 +552,7 @@ def get_fieldsets(self, request, msg: Message = None): "Analysis", { "fields": [ + "response_time", "analysis_result", "analysis_run", "question_answered", diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 718b580f1..db88f4dde 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -338,15 +338,12 @@ def render_date_view_inputs(self, bi): def parse_run_info(self, bi): saved_run = bi.get_active_saved_run() - run_title = ( - bi.published_run.title - if bi.published_run - else ( - saved_run.page_title - if saved_run and saved_run.page_title - else "This Copilot Run" if saved_run else "No Run Connected" - ) - ) + if bi.published_run: + run_title = bi.published_run.title + elif saved_run: + run_title = "This Copilot Run" + else: + run_title = "No Run Connected" run_url = furl(saved_run.get_app_url()).tostr() if saved_run else "" return run_title, run_url @@ -424,7 +421,7 @@ def calculate_stats_binned_by_time( created_at__date__gte=start_date, created_at__date__lte=end_date, conversation__bot_integration=bi, - role=CHATML_ROLE_USER, + role=CHATML_ROLE_ASSISTANT, ) .order_by() .annotate(date=trunc_fn("created_at")) From 2be2378a5e1142c4b8b8f503850d76154d954882 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:02:33 +0530 Subject: [PATCH 33/58] use timezone.now() instead of datetime.now() --- recipes/VideoBotsStats.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index db88f4dde..c18683b96 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -1,3 +1,5 @@ +from django.utils import timezone + from daras_ai_v2.base import BasePage, MenuTabs import gooey_ui as st from furl import furl @@ -287,9 +289,9 @@ def render(self): def render_date_view_inputs(self, bi): if st.checkbox("Show All"): start_date = bi.created_at - end_date = datetime.now() + end_date = timezone.now() else: - start_of_year_date = datetime.now().replace(month=1, day=1) + start_of_year_date = timezone.now().replace(month=1, day=1) st.session_state.setdefault( "start_date", self.request.query_params.get( @@ -302,11 +304,11 @@ def render_date_view_inputs(self, bi): st.session_state.setdefault( "end_date", self.request.query_params.get( - "end_date", datetime.now().strftime("%Y-%m-%d") + "end_date", timezone.now().strftime("%Y-%m-%d") ), ) end_date: datetime = ( - st.date_input("End date", key="end_date") or datetime.now() + st.date_input("End date", key="end_date") or timezone.now() ) st.session_state.setdefault( "view", self.request.query_params.get("view", "Weekly") @@ -359,7 +361,7 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): num_active_users_last_7_days = ( user_messages.filter( conversation__in=users, - created_at__gte=datetime.now() - timedelta(days=7), + created_at__gte=timezone.now() - timedelta(days=7), ) .distinct( *ID_COLUMNS, @@ -369,7 +371,7 @@ def calculate_overall_stats(self, bid, bi, run_title, run_url): num_active_users_last_30_days = ( user_messages.filter( conversation__in=users, - created_at__gte=datetime.now() - timedelta(days=30), + created_at__gte=timezone.now() - timedelta(days=30), ) .distinct( *ID_COLUMNS, @@ -544,7 +546,9 @@ def calculate_stats_binned_by_time( ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] - df["Average_response_time"] = df["Average_response_time"] * factor + df["Average_response_time"] = ( + df["Average_response_time"].dt.total_seconds() * factor + ) df.fillna(0, inplace=True) df = df.round(0).astype("int32", errors="ignore") return df From 3ac7299c9ed18ac9948ac88768d9d8f74daae0ec Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:06:32 +0530 Subject: [PATCH 34/58] silence black --- recipes/VideoBots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 75615e6a7..5d23502ff 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1124,9 +1124,9 @@ def slack_specific_settings(self, bi: BotIntegration): st.session_state[f"_bi_name_{bi.id}"] = ( pr and pr.title ) or self.get_recipe_title() - st.session_state[ - f"_bi_slack_read_receipt_msg_{bi.id}" - ] = BotIntegration._meta.get_field("slack_read_receipt_msg").default + st.session_state[f"_bi_slack_read_receipt_msg_{bi.id}"] = ( + BotIntegration._meta.get_field("slack_read_receipt_msg").default + ) bi.slack_read_receipt_msg = st.text_input( """ From 4aa53e70bfae3f0b1ee650f4cf945566135ebee9 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:27:43 +0530 Subject: [PATCH 35/58] fix AttributeError: Can only use .dt accessor with datetimelike values --- recipes/VideoBotsStats.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index c18683b96..00c7f07dd 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -546,9 +546,11 @@ def calculate_stats_binned_by_time( ) * 100 df["Msgs_per_convo"] = df["Messages_Sent"] / df["Convos"] df["Msgs_per_user"] = df["Messages_Sent"] / df["Senders"] - df["Average_response_time"] = ( - df["Average_response_time"].dt.total_seconds() * factor - ) + try: + df["Average_response_time"] = df["Average_response_time"].dt.total_seconds() + except AttributeError: + pass + df["Average_response_time"] = df["Average_response_time"] * factor df.fillna(0, inplace=True) df = df.round(0).astype("int32", errors="ignore") return df From 84a77a4f1efb945a4aa7e30caec9bf3d73b45b35 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:13:16 +0530 Subject: [PATCH 36/58] update black --- poetry.lock | 44 ++++++++++++++++++++++++-------------------- pyproject.toml | 2 +- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/poetry.lock b/poetry.lock index a5c39450e..0f225fd4e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -323,29 +323,33 @@ files = [ [[package]] name = "black" -version = "23.10.1" +version = "23.12.1" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.10.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:ec3f8e6234c4e46ff9e16d9ae96f4ef69fa328bb4ad08198c8cee45bb1f08c69"}, - {file = "black-23.10.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1b917a2aa020ca600483a7b340c165970b26e9029067f019e3755b56e8dd5916"}, - {file = "black-23.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c74de4c77b849e6359c6f01987e94873c707098322b91490d24296f66d067dc"}, - {file = "black-23.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b4d10b0f016616a0d93d24a448100adf1699712fb7a4efd0e2c32bbb219b173"}, - {file = "black-23.10.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b15b75fc53a2fbcac8a87d3e20f69874d161beef13954747e053bca7a1ce53a0"}, - {file = "black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace"}, - {file = "black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb"}, - {file = "black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce"}, - {file = "black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a"}, - {file = "black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1"}, - {file = "black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad"}, - {file = "black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884"}, - {file = "black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9"}, - {file = "black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7"}, - {file = "black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d"}, - {file = "black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982"}, - {file = "black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe"}, - {file = "black-23.10.1.tar.gz", hash = "sha256:1f8ce316753428ff68749c65a5f7844631aa18c8679dfd3ca9dc1a289979c258"}, + {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, + {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, + {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, + {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, + {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, + {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, + {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, + {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, + {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, + {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, + {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, + {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, + {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, + {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, + {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, + {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, + {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, + {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, + {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, + {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, + {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, + {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, ] [package.dependencies] @@ -359,7 +363,7 @@ typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} [package.extras] colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] diff --git a/pyproject.toml b/pyproject.toml index aac74b09c..44dbff5da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ django-phonenumber-field = { extras = ["phonenumberslite"], version = "^7.0.2" } gunicorn = "^20.1.0" psycopg2-binary = "^2.9.6" whitenoise = "^6.4.0" -black = "^23.3.0" +black = "^24.1.1" django-extensions = "^3.2.1" pytest-django = "^4.5.2" celery = "^5.3.1" From 3b321f5e1e974538121cb148be9ace1bede53190 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 5 Feb 2024 21:26:43 +0530 Subject: [PATCH 37/58] fix test urls --- conftest.py | 5 +++-- tests/test_apis.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/conftest.py b/conftest.py index e8534d5f2..dba333885 100644 --- a/conftest.py +++ b/conftest.py @@ -47,8 +47,9 @@ def _mock_gui_runner( def threadpool_subtest(subtests, max_workers: int = 8): ts = [] - def submit(fn, *args, **kwargs): - msg = "--".join(map(str, [*args, *kwargs.values()])) + def submit(fn, *args, msg=None, **kwargs): + if not msg: + msg = "--".join(map(str, [*args, *kwargs.values()])) @wraps(fn) def runner(*args, **kwargs): diff --git a/tests/test_apis.py b/tests/test_apis.py index 6412e9eda..96c27c7ef 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -73,7 +73,7 @@ def test_apis_examples(mock_gui_runner, force_authentication, threadpool_subtest hidden=False, example_id__isnull=False, ): - threadpool_subtest(_test_apis_examples, sr) + threadpool_subtest(_test_apis_examples, sr, msg=sr.get_app_url()) def _test_apis_examples(sr: SavedRun): From 1c0c990bb450934b767bc972e66df9f12f41293c Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 6 Feb 2024 19:35:11 +0530 Subject: [PATCH 38/58] Copilot: Turn off collecting detailed feedback off thumbs up or down --- daras_ai_v2/bots.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 9f2ebd555..30acf07a3 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -509,13 +509,25 @@ def _handle_interactive_msg(bot: BotInterface): return if button_id == ButtonIds.feedback_thumbs_up: rating = Feedback.Rating.RATING_THUMBS_UP - bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_UP - response_text = FEEDBACK_THUMBS_UP_MSG + # bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_UP + # response_text = FEEDBACK_THUMBS_UP_MSG else: rating = Feedback.Rating.RATING_THUMBS_DOWN - bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_DOWN - response_text = FEEDBACK_THUMBS_DOWN_MSG + # bot.convo.state = ConvoState.ASK_FOR_FEEDBACK_THUMBS_DOWN + # response_text = FEEDBACK_THUMBS_DOWN_MSG + response_text = FEEDBACK_CONFIRMED_MSG.format( + bot_name=str(bot.convo.bot_integration.name) + ) bot.convo.save() + # save the feedback + Feedback.objects.create(message=context_msg, rating=rating) + # send a confirmation msg + post click buttons + bot.send_msg( + text=response_text, + # buttons=_feedback_post_click_buttons(), + should_translate=True, + ) + # handle skip case ButtonIds.action_skip: bot.send_msg(text=TAPPED_SKIP_MSG, should_translate=True) @@ -523,6 +535,7 @@ def _handle_interactive_msg(bot: BotInterface): bot.convo.state = ConvoState.INITIAL bot.convo.save() return + # not sure what button was pressed, ignore case _: bot_name = str(bot.convo.bot_integration.name) @@ -534,14 +547,6 @@ def _handle_interactive_msg(bot: BotInterface): bot.convo.state = ConvoState.INITIAL bot.convo.save() return - # save the feedback - Feedback.objects.create(message=context_msg, rating=rating) - # send a confirmation msg + post click buttons - bot.send_msg( - text=response_text, - buttons=_feedback_post_click_buttons(), - should_translate=True, - ) def _handle_audio_msg(billing_account_user, bot: BotInterface): From 5d82e5db8ca5c702d16f1118705e80b0625e4e66 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 6 Feb 2024 22:46:08 +0530 Subject: [PATCH 39/58] lockfile --- poetry.lock | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0f225fd4e..587185ab8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -323,33 +323,33 @@ files = [ [[package]] name = "black" -version = "23.12.1" +version = "24.1.1" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, - {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, - {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, - {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, - {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, - {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, - {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, - {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, - {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, - {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, - {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, - {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, - {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, - {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, - {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, - {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, - {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, - {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, - {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, - {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, - {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, - {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, + {file = "black-24.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2588021038bd5ada078de606f2a804cadd0a3cc6a79cb3e9bb3a8bf581325a4c"}, + {file = "black-24.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a95915c98d6e32ca43809d46d932e2abc5f1f7d582ffbe65a5b4d1588af7445"}, + {file = "black-24.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fa6a0e965779c8f2afb286f9ef798df770ba2b6cee063c650b96adec22c056a"}, + {file = "black-24.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:5242ecd9e990aeb995b6d03dc3b2d112d4a78f2083e5a8e86d566340ae80fec4"}, + {file = "black-24.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fc1ec9aa6f4d98d022101e015261c056ddebe3da6a8ccfc2c792cbe0349d48b7"}, + {file = "black-24.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0269dfdea12442022e88043d2910429bed717b2d04523867a85dacce535916b8"}, + {file = "black-24.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3d64db762eae4a5ce04b6e3dd745dcca0fb9560eb931a5be97472e38652a161"}, + {file = "black-24.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5d7b06ea8816cbd4becfe5f70accae953c53c0e53aa98730ceccb0395520ee5d"}, + {file = "black-24.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e2c8dfa14677f90d976f68e0c923947ae68fa3961d61ee30976c388adc0b02c8"}, + {file = "black-24.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a21725862d0e855ae05da1dd25e3825ed712eaaccef6b03017fe0853a01aa45e"}, + {file = "black-24.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07204d078e25327aad9ed2c64790d681238686bce254c910de640c7cc4fc3aa6"}, + {file = "black-24.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:a83fe522d9698d8f9a101b860b1ee154c1d25f8a82ceb807d319f085b2627c5b"}, + {file = "black-24.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08b34e85170d368c37ca7bf81cf67ac863c9d1963b2c1780c39102187ec8dd62"}, + {file = "black-24.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7258c27115c1e3b5de9ac6c4f9957e3ee2c02c0b39222a24dc7aa03ba0e986f5"}, + {file = "black-24.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40657e1b78212d582a0edecafef133cf1dd02e6677f539b669db4746150d38f6"}, + {file = "black-24.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e298d588744efda02379521a19639ebcd314fba7a49be22136204d7ed1782717"}, + {file = "black-24.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:34afe9da5056aa123b8bfda1664bfe6fb4e9c6f311d8e4a6eb089da9a9173bf9"}, + {file = "black-24.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:854c06fb86fd854140f37fb24dbf10621f5dab9e3b0c29a690ba595e3d543024"}, + {file = "black-24.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3897ae5a21ca132efa219c029cce5e6bfc9c3d34ed7e892113d199c0b1b444a2"}, + {file = "black-24.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:ecba2a15dfb2d97105be74bbfe5128bc5e9fa8477d8c46766505c1dda5883aac"}, + {file = "black-24.1.1-py3-none-any.whl", hash = "sha256:5cdc2e2195212208fbcae579b931407c1fa9997584f0a415421748aeafff1168"}, + {file = "black-24.1.1.tar.gz", hash = "sha256:48b5760dcbfe5cf97fd4fba23946681f3a81514c6ab8a45b50da67ac8fbc6c7b"}, ] [package.dependencies] @@ -6514,4 +6514,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a0d934b2a9b3b5c54d3b14ad5e524ff4b163eaa5eaeb794fe38bc3a8d8d60e04" +content-hash = "798d4cfff5a447e83bd2aa675525c89d6258c927839260a6f20aa71a633a7e81" From 570da24520935bc18457a740ac715740ceeaa7d1 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 6 Feb 2024 22:24:43 +0530 Subject: [PATCH 40/58] support openai streaming squash migrations better lookup for model when saving cost rename costs -> usage_costs, provider pricing -> model pricing add init script for llm pricing --- ...alter_savedrun_unique_together_and_more.py | 21 ++ bots/models.py | 3 +- celeryapp/tasks.py | 1 - costs/admin.py | 33 -- costs/cost_utils.py | 36 -- costs/migrations/0001_initial.py | 59 --- costs/migrations/0002_providerpricing.py | 79 ---- ...t_model_remove_usagecost_param_and_more.py | 64 ---- .../0004_remove_usagecost_saved_run.py | 16 - .../0005_remove_usagecost_provider_pricing.py | 16 - ...st_provider_pricing_usagecost_saved_run.py | 38 -- .../0007_alter_providerpricing_provider.py | 25 -- ...08_alter_providerpricing_param_and_more.py | 33 -- .../0009_alter_providerpricing_product.py | 19 - ...0010_remove_usagecost_calculation_notes.py | 16 - .../0011_alter_providerpricing_product.py | 35 -- .../0012_alter_providerpricing_product.py | 37 -- .../0013_alter_providerpricing_product.py | 37 -- .../0014_alter_providerpricing_product.py | 40 --- .../0015_alter_providerpricing_product.py | 43 --- costs/models.py | 113 ------ daras_ai_v2/language_model.py | 159 +++++--- daras_ai_v2/settings.py | 2 +- scripts/init_llm_pricing.py | 338 ++++++++++++++++++ {costs => usage_costs}/__init__.py | 0 usage_costs/admin.py | 75 ++++ {costs => usage_costs}/apps.py | 4 +- usage_costs/cost_utils.py | 34 ++ usage_costs/migrations/0001_initial.py | 56 +++ {costs => usage_costs}/migrations/__init__.py | 0 usage_costs/models.py | 114 ++++++ {costs => usage_costs}/tests.py | 0 {costs => usage_costs}/views.py | 0 33 files changed, 752 insertions(+), 794 deletions(-) create mode 100644 bots/migrations/0058_alter_savedrun_unique_together_and_more.py delete mode 100644 costs/admin.py delete mode 100644 costs/cost_utils.py delete mode 100644 costs/migrations/0001_initial.py delete mode 100644 costs/migrations/0002_providerpricing.py delete mode 100644 costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py delete mode 100644 costs/migrations/0004_remove_usagecost_saved_run.py delete mode 100644 costs/migrations/0005_remove_usagecost_provider_pricing.py delete mode 100644 costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py delete mode 100644 costs/migrations/0007_alter_providerpricing_provider.py delete mode 100644 costs/migrations/0008_alter_providerpricing_param_and_more.py delete mode 100644 costs/migrations/0009_alter_providerpricing_product.py delete mode 100644 costs/migrations/0010_remove_usagecost_calculation_notes.py delete mode 100644 costs/migrations/0011_alter_providerpricing_product.py delete mode 100644 costs/migrations/0012_alter_providerpricing_product.py delete mode 100644 costs/migrations/0013_alter_providerpricing_product.py delete mode 100644 costs/migrations/0014_alter_providerpricing_product.py delete mode 100644 costs/migrations/0015_alter_providerpricing_product.py delete mode 100644 costs/models.py create mode 100644 scripts/init_llm_pricing.py rename {costs => usage_costs}/__init__.py (100%) create mode 100644 usage_costs/admin.py rename {costs => usage_costs}/apps.py (67%) create mode 100644 usage_costs/cost_utils.py create mode 100644 usage_costs/migrations/0001_initial.py rename {costs => usage_costs}/migrations/__init__.py (100%) create mode 100644 usage_costs/models.py rename {costs => usage_costs}/tests.py (100%) rename {costs => usage_costs}/views.py (100%) diff --git a/bots/migrations/0058_alter_savedrun_unique_together_and_more.py b/bots/migrations/0058_alter_savedrun_unique_together_and_more.py new file mode 100644 index 000000000..a4e137127 --- /dev/null +++ b/bots/migrations/0058_alter_savedrun_unique_together_and_more.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.7 on 2024-02-06 18:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0057_message_response_time_and_more'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='savedrun', + unique_together={('run_id', 'uid'), ('workflow', 'example_id')}, + ), + migrations.AddIndex( + model_name='savedrun', + index=models.Index(fields=['run_id', 'uid'], name='bots_savedr_run_id_7b0b34_idx'), + ), + ] diff --git a/bots/models.py b/bots/models.py index 8e2a41dc8..205dfb5ff 100644 --- a/bots/models.py +++ b/bots/models.py @@ -260,7 +260,7 @@ class Meta: ordering = ["-updated_at"] unique_together = [ ["workflow", "example_id"], - ["workflow", "run_id", "uid"], + ["run_id", "uid"], ] constraints = [ models.CheckConstraint( @@ -273,6 +273,7 @@ class Meta: models.Index(fields=["-created_at"]), models.Index(fields=["-updated_at"]), models.Index(fields=["workflow"]), + models.Index(fields=["run_id", "uid"]), models.Index(fields=["workflow", "run_id", "uid"]), models.Index(fields=["workflow", "example_id", "run_id", "uid"]), models.Index(fields=["workflow", "example_id", "hidden"]), diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 7bd144091..1868c7a18 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -8,7 +8,6 @@ import gooey_ui as st from app_users.models import AppUser from bots.models import SavedRun -from costs.models import UsageCost from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings diff --git a/costs/admin.py b/costs/admin.py deleted file mode 100644 index f3d99d472..000000000 --- a/costs/admin.py +++ /dev/null @@ -1,33 +0,0 @@ -from django.contrib import admin -from costs import models - -# Register your models here. - - -@admin.register(models.UsageCost) -class CostsAdmin(admin.ModelAdmin): - list_display = [ - "saved_run", - "provider_pricing", - "quantity", - "notes", - "dollar_amt", - "created_at", - ] - - -@admin.register(models.ProviderPricing) -class ProviderAdmin(admin.ModelAdmin): - list_display = [ - "type", - "provider", - "product", - "param", - "cost", - "unit", - "notes", - "created_at", - "last_updated", - "updated_by", - "pricing_url", - ] diff --git a/costs/cost_utils.py b/costs/cost_utils.py deleted file mode 100644 index 25f6ec28b..000000000 --- a/costs/cost_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -from costs.models import UsageCost, ProviderPricing -from bots.models import SavedRun - - -def get_provider_pricing( - type: str, - provider: str, - product: str, - param: str, -) -> ProviderPricing: - print("get_provider_pricing", type, provider, product, param) - return ProviderPricing.objects.get( - type=type, - provider=provider, - product=product, - param=param, - ) - - -def record_cost( - run_id: str | None, - uid: str | None, - provider_pricing: ProviderPricing, - quantity: int, -) -> UsageCost: - saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) - cost = UsageCost( - saved_run=saved_run, - provider_pricing=provider_pricing, - quantity=quantity, - notes="", - dollar_amt=provider_pricing.cost * quantity, - created_at=saved_run.created_at, - ) - cost.save() - return cost diff --git a/costs/migrations/0001_initial.py b/costs/migrations/0001_initial.py deleted file mode 100644 index d3f6d1948..000000000 --- a/costs/migrations/0001_initial.py +++ /dev/null @@ -1,59 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-04 03:59 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), - ] - - operations = [ - migrations.CreateModel( - name="UsageCost", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "provider", - models.IntegerField( - choices=[ - (1, "OpenAI"), - (2, "GPT-4"), - (3, "dalle-e"), - (4, "whisper"), - (5, "GPT-3.5"), - ] - ), - ), - ("model", models.TextField(blank=True, verbose_name="model")), - ("param", models.TextField(blank=True, verbose_name="param")), - ("notes", models.TextField(blank=True, default="")), - ("calculation_notes", models.TextField(blank=True, default="")), - ("dollar_amt", models.DecimalField(decimal_places=8, max_digits=13)), - ("created_at", models.DateTimeField(auto_now_add=True)), - ( - "saved_run", - models.ForeignKey( - blank=True, - default=None, - help_text="The run that was last saved by the user.", - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="bots.savedrun", - ), - ), - ], - ), - ] diff --git a/costs/migrations/0002_providerpricing.py b/costs/migrations/0002_providerpricing.py deleted file mode 100644 index 6b0bb17a1..000000000 --- a/costs/migrations/0002_providerpricing.py +++ /dev/null @@ -1,79 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-04 04:00 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="ProviderPricing", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("type", models.TextField(choices=[("LLM", "Llm")])), - ( - "provider", - models.IntegerField( - choices=[ - (1, "OpenAI"), - (2, "GPT-4"), - (3, "dalle-e"), - (4, "whisper"), - (5, "GPT-3.5"), - ] - ), - ), - ( - "product", - models.TextField( - choices=[ - ("GPT-4 Vision (openai)", "gpt_4_vision"), - ("GPT-4 Turbo (openai)", "gpt_4_turbo"), - ("GPT-4 (openai)", "gpt_4"), - ("GPT-4 32K (openai)", "gpt_4_32k"), - ("ChatGPT (openai)", "gpt_3_5_turbo"), - ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), - ("Llama 2 (Meta AI)", "llama2_70b_chat"), - ("PaLM 2 Text (Google)", "palm2_chat"), - ("PaLM 2 Chat (Google)", "palm2_text"), - ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), - ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), - ("Curie (openai)", "text_curie_001"), - ("Babbage (openai)", "text_babbage_001"), - ("Ada (openai)", "text_ada_001"), - ("Codex [Deprecated] (openai)", "code_davinci_002"), - ] - ), - ), - ( - "param", - models.TextField( - choices=[ - ("input", "Input"), - ("output", "Output"), - ("input image", "Input Image"), - ("output image", "Output Image"), - ] - ), - ), - ("cost", models.DecimalField(decimal_places=8, max_digits=13)), - ("unit", models.TextField(default="")), - ("notes", models.TextField(default="")), - ("created_at", models.DateTimeField(auto_now_add=True)), - ("last_updated", models.DateTimeField(auto_now=True)), - ("updated_by", models.TextField(default="")), - ("pricing_url", models.TextField(default="")), - ], - ), - ] diff --git a/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py b/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py deleted file mode 100644 index 4aa89cadd..000000000 --- a/costs/migrations/0003_remove_usagecost_model_remove_usagecost_param_and_more.py +++ /dev/null @@ -1,64 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 06:02 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0002_providerpricing"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="model", - ), - migrations.RemoveField( - model_name="usagecost", - name="param", - ), - migrations.RemoveField( - model_name="usagecost", - name="provider", - ), - migrations.AddField( - model_name="usagecost", - name="provider_pricing", - field=models.ForeignKey( - default=None, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="costs.providerpricing", - ), - ), - migrations.AddField( - model_name="usagecost", - name="quantity", - field=models.IntegerField(default=1), - ), - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("GPT-4 Vision (openai)", "gpt_4_vision"), - ("GPT-4 Turbo (openai)", "gpt_4_turbo"), - ("GPT-4 (openai)", "gpt_4"), - ("GPT-4 32K (openai)", "gpt_4_32k"), - ("ChatGPT (openai)", "gpt_3_5_turbo"), - ("ChatGPT 16k (openai)", "gpt_3_5_turbo_16k"), - ("Llama 2 (Meta AI)", "llama2_70b_chat"), - ("PaLM 2 Chat (Google)", "palm2_chat"), - ("PaLM 2 Text (Google)", "palm2_text"), - ("GPT-3.5 Davinci-3 (openai)", "text_davinci_003"), - ("GPT-3.5 Davinci-2 (openai)", "text_davinci_002"), - ("Curie (openai)", "text_curie_001"), - ("Babbage (openai)", "text_babbage_001"), - ("Ada (openai)", "text_ada_001"), - ("Codex [Deprecated] (openai)", "code_davinci_002"), - ] - ), - ), - ] diff --git a/costs/migrations/0004_remove_usagecost_saved_run.py b/costs/migrations/0004_remove_usagecost_saved_run.py deleted file mode 100644 index 2e68a519e..000000000 --- a/costs/migrations/0004_remove_usagecost_saved_run.py +++ /dev/null @@ -1,16 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 07:33 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0003_remove_usagecost_model_remove_usagecost_param_and_more"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="saved_run", - ), - ] diff --git a/costs/migrations/0005_remove_usagecost_provider_pricing.py b/costs/migrations/0005_remove_usagecost_provider_pricing.py deleted file mode 100644 index 2ff682ab0..000000000 --- a/costs/migrations/0005_remove_usagecost_provider_pricing.py +++ /dev/null @@ -1,16 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 08:06 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0004_remove_usagecost_saved_run"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="provider_pricing", - ), - ] diff --git a/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py b/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py deleted file mode 100644 index ba207c963..000000000 --- a/costs/migrations/0006_usagecost_provider_pricing_usagecost_saved_run.py +++ /dev/null @@ -1,38 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-05 08:06 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - dependencies = [ - ("bots", "0054_alter_savedrun_example_id_alter_savedrun_page_notes_and_more"), - ("costs", "0005_remove_usagecost_provider_pricing"), - ] - - operations = [ - migrations.AddField( - model_name="usagecost", - name="provider_pricing", - field=models.ForeignKey( - default=None, - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="costs.providerpricing", - ), - ), - migrations.AddField( - model_name="usagecost", - name="saved_run", - field=models.ForeignKey( - blank=True, - default=None, - help_text="The run that was last saved by the user.", - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="usage_costs", - to="bots.savedrun", - ), - ), - ] diff --git a/costs/migrations/0007_alter_providerpricing_provider.py b/costs/migrations/0007_alter_providerpricing_provider.py deleted file mode 100644 index e6220e391..000000000 --- a/costs/migrations/0007_alter_providerpricing_provider.py +++ /dev/null @@ -1,25 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 00:51 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0006_usagecost_provider_pricing_usagecost_saved_run"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="provider", - field=models.TextField( - choices=[ - ("OpenAI", "Openai"), - ("GPT-4", "Gpt 4"), - ("dalle-e", "Dalle E"), - ("whisper", "Whisper"), - ("GPT-3.5", "Gpt 3 5"), - ] - ), - ), - ] diff --git a/costs/migrations/0008_alter_providerpricing_param_and_more.py b/costs/migrations/0008_alter_providerpricing_param_and_more.py deleted file mode 100644 index bcc9a0030..000000000 --- a/costs/migrations/0008_alter_providerpricing_param_and_more.py +++ /dev/null @@ -1,33 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 01:10 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0007_alter_providerpricing_provider"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="param", - field=models.TextField(choices=[("Input", "input"), ("Output", "output")]), - ), - migrations.AlterField( - model_name="providerpricing", - name="provider", - field=models.TextField( - choices=[ - ("vertex_ai", "Vertex AI"), - ("openai", "OpenAI"), - ("together", "Together"), - ] - ), - ), - migrations.AlterField( - model_name="providerpricing", - name="type", - field=models.TextField(choices=[("LLM", "LLM")]), - ), - ] diff --git a/costs/migrations/0009_alter_providerpricing_product.py b/costs/migrations/0009_alter_providerpricing_product.py deleted file mode 100644 index 5e5a51019..000000000 --- a/costs/migrations/0009_alter_providerpricing_product.py +++ /dev/null @@ -1,19 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 01:28 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0008_alter_providerpricing_param_and_more"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[("gpt-4-vision-preview", "gpt-4-vision-preview")] - ), - ), - ] diff --git a/costs/migrations/0010_remove_usagecost_calculation_notes.py b/costs/migrations/0010_remove_usagecost_calculation_notes.py deleted file mode 100644 index 3eba7ace1..000000000 --- a/costs/migrations/0010_remove_usagecost_calculation_notes.py +++ /dev/null @@ -1,16 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-06 02:02 - -from django.db import migrations - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0009_alter_providerpricing_product"), - ] - - operations = [ - migrations.RemoveField( - model_name="usagecost", - name="calculation_notes", - ), - ] diff --git a/costs/migrations/0011_alter_providerpricing_product.py b/costs/migrations/0011_alter_providerpricing_product.py deleted file mode 100644 index b49d3f308..000000000 --- a/costs/migrations/0011_alter_providerpricing_product.py +++ /dev/null @@ -1,35 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 08:50 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0010_remove_usagecost_calculation_notes"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ("gpt-4-1106-preview", "gpt-4-1106-preview"), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0012_alter_providerpricing_product.py b/costs/migrations/0012_alter_providerpricing_product.py deleted file mode 100644 index 3bc92c01c..000000000 --- a/costs/migrations/0012_alter_providerpricing_product.py +++ /dev/null @@ -1,37 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 16:45 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0011_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ("openai-gpt-4-turbo-prod-ca-1", "openai-gpt-4-turbo-prod-ca-1"), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0013_alter_providerpricing_product.py b/costs/migrations/0013_alter_providerpricing_product.py deleted file mode 100644 index 9449fbf91..000000000 --- a/costs/migrations/0013_alter_providerpricing_product.py +++ /dev/null @@ -1,37 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 16:50 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0012_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ("openai-gpt-4-turbo-prod-ca-1", "gpt-4-1106-preview"), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0014_alter_providerpricing_product.py b/costs/migrations/0014_alter_providerpricing_product.py deleted file mode 100644 index 22b5439ed..000000000 --- a/costs/migrations/0014_alter_providerpricing_product.py +++ /dev/null @@ -1,40 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 16:58 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0013_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ( - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - ), - ("openai-gpt-4-32k-prod-ca-1", "openai-gpt-4-32k-prod-ca-1"), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/migrations/0015_alter_providerpricing_product.py b/costs/migrations/0015_alter_providerpricing_product.py deleted file mode 100644 index 2658de04f..000000000 --- a/costs/migrations/0015_alter_providerpricing_product.py +++ /dev/null @@ -1,43 +0,0 @@ -# Generated by Django 4.2.5 on 2024-01-10 17:02 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("costs", "0014_alter_providerpricing_product"), - ] - - operations = [ - migrations.AlterField( - model_name="providerpricing", - name="product", - field=models.TextField( - choices=[ - ("gpt-4-vision-preview", "gpt-4-vision-preview"), - ( - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - ), - ( - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - ), - ("gpt-3.5-turbo", "gpt-3.5-turbo"), - ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k"), - ("text-davinci-003", "text-davinci-003"), - ("text-davinci-002", "text-davinci-002"), - ("code-davinci-002", "code-davinci-002"), - ("text-curie-001", "text-curie-001"), - ("text-babbage-001", "text-babbage-001"), - ("text-ada-001", "text-ada-001"), - ("text-bison", "text-bison"), - ("chat-bison", "chat-bison"), - ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ), - ] - ), - ), - ] diff --git a/costs/models.py b/costs/models.py deleted file mode 100644 index 18700537f..000000000 --- a/costs/models.py +++ /dev/null @@ -1,113 +0,0 @@ -from django.db import models -from daras_ai_v2.language_model import LLMApis - - -class Product(models.TextChoices): - gpt_4_vision = ( - "gpt-4-vision-preview", - "gpt-4-vision-preview", - ) - gpt_4_turbo = ( - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - "('openai-gpt-4-turbo-prod-ca-1', 'gpt-4-1106-preview')", - ) - - gpt_4_32k = ( - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - "('openai-gpt-4-prod-ca-1', 'gpt-4')", - ) - gpt_3_5_turbo = ( - "gpt-3.5-turbo", - "gpt-3.5-turbo", - ) - gpt_3_5_turbo_16k = ( - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k", - ) - text_davinci_003 = ( - "text-davinci-003", - "text-davinci-003", - ) - text_davinci_002 = ( - "text-davinci-002", - "text-davinci-002", - ) - code_davinci_002 = ( - "code-davinci-002", - "code-davinci-002", - ) - text_curie_001 = ( - "text-curie-001", - "text-curie-001", - ) - text_babbage_001 = ( - "text-babbage-001", - "text-babbage-001", - ) - text_ada_001 = ( - "text-ada-001", - "text-ada-001", - ) - palm2_text = ( - "text-bison", - "text-bison", - ) - palm2_chat = ( - "chat-bison", - "chat-bison", - ) - llama2_70b_chat = ( - "togethercomputer/llama-2-70b-chat", - "togethercomputer/llama-2-70b-chat", - ) - - -class UsageCost(models.Model): - saved_run = models.ForeignKey( - "bots.SavedRun", - on_delete=models.CASCADE, - related_name="usage_costs", - null=True, - default=None, - blank=True, - help_text="The run that was last saved by the user.", - ) - - provider_pricing = models.ForeignKey( - "costs.ProviderPricing", - on_delete=models.CASCADE, - related_name="usage_costs", - null=True, - default=None, - ) - quantity = models.IntegerField(default=1) - notes = models.TextField(default="", blank=True) - dollar_amt = models.DecimalField( - max_digits=13, - decimal_places=8, - ) - created_at = models.DateTimeField(auto_now_add=True) - - -class ProviderPricing(models.Model): - class Group(models.TextChoices): # change to different name than type - LLM = "LLM", "LLM" - - class Param(models.TextChoices): - input = "Input", "input" - output = "Output", "output" - - type = models.TextField(choices=Group.choices) - provider = models.TextField(choices=LLMApis.choices()) - product = models.TextField(choices=Product.choices) - param = models.TextField(choices=Param.choices) - cost = models.DecimalField(max_digits=13, decimal_places=8) - unit = models.TextField(default="") - notes = models.TextField(default="") - created_at = models.DateTimeField(auto_now_add=True) - last_updated = models.DateTimeField(auto_now=True) - updated_by = models.TextField(default="") - pricing_url = models.TextField(default="") - - def __str__(self): - return self.type + " " + self.provider + " " + self.product + " " + self.param diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 58a26f9b6..fa8b48d86 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -4,7 +4,7 @@ import re import typing from enum import Enum -from functools import partial +from functools import partial, wraps import numpy as np import requests @@ -112,11 +112,11 @@ def is_chat_model(self) -> bool: LargeLanguageModels.gpt_4_32k: "openai-gpt-4-32k-prod-ca-1", LargeLanguageModels.gpt_3_5_turbo: ( "openai-gpt-35-turbo-prod-ca-1", - "gpt-3.5-turbo", + "gpt-3.5-turbo-0613", ), LargeLanguageModels.gpt_3_5_turbo_16k: ( "openai-gpt-35-turbo-16k-prod-ca-1", - "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-16k-0613", ), LargeLanguageModels.text_davinci_003: "text-davinci-003", LargeLanguageModels.text_davinci_002: "text-davinci-002", @@ -378,34 +378,6 @@ def run_language_model( stream=stream and not (tools or response_format_type), ) - ## incorportate down the call chain - # provider_pricing_in = get_provider_pricing( - # type="LLM", - # provider=api.name, - # product=model_name, - # param="Input", - # ) - # - # provider_pricing_out = get_provider_pricing( - # type="LLM", - # provider=api.name, - # product=model_name, - # param="Output", - # ) - # example_id, run_id, uid = extract_query_params(gooey_get_query_params()) - # record_cost( - # run_id=run_id, - # uid=uid, - # provider_pricing=provider_pricing_in, - # quantity=input_token, - # ) - # record_cost( - # run_id=run_id, - # uid=uid, - # provider_pricing=provider_pricing_out, - # quantity=output_token, - # ) - if stream: return _stream_llm_outputs(entries, response_format_type) else: @@ -588,10 +560,9 @@ def _run_openai_chat( presence_penalty = 0 if isinstance(model, str): model = [model] - r = try_all( + r, used_model = try_all( *[ - partial( - _get_openai_client(model_str).chat.completions.create, + _get_chat_completions_create( model=model_str, messages=messages, max_tokens=max_tokens, @@ -612,15 +583,39 @@ def _run_openai_chat( ], ) if stream: - return _stream_openai_chunked(r) + return _stream_openai_chunked(r, used_model, messages) else: - r.usage.completion_tokens - r.usage.prompt_tokens + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=used_model, + sku=ModelSku.llm_prompt, + quantity=r.usage.prompt_tokens, + ) + record_cost_auto( + model=used_model, + sku=ModelSku.llm_completion, + quantity=r.usage.completion_tokens, + ) return [choice.message.dict() for choice in r.choices] +def _get_chat_completions_create(model: str, **kwargs): + client = _get_openai_client(model) + + @wraps(client.chat.completions.create) + def wrapper(): + return client.chat.completions.create(model=model, **kwargs), model + + return wrapper + + def _stream_openai_chunked( r: Stream[ChatCompletionChunk], + used_model: str, + messages: list[ConversationEntry], + *, start_chunk_size: int = 50, stop_chunk_size: int = 400, step_chunk_size: int = 150, @@ -679,6 +674,20 @@ def _stream_openai_chunked( entry["content"] += entry["chunk"] yield ret + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=used_model, + sku=ModelSku.llm_prompt, + quantity=sum(default_length_function(entry["content"]) for entry in messages), + ) + record_cost_auto( + model=used_model, + sku=ModelSku.llm_completion, + quantity=sum(default_length_function(entry["content"]) for entry in ret), + ) + @retry_if(openai_should_retry) def _run_openai_text( @@ -702,8 +711,21 @@ def _run_openai_text( frequency_penalty=0.1 if avoid_repetition else 0, presence_penalty=0.25 if avoid_repetition else 0, ) - r.usage.completion_tokens, - r.usage.prompt_tokens, + + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=model, + sku=ModelSku.llm_prompt, + quantity=r.usage.prompt_tokens, + ) + record_cost_auto( + model=model, + sku=ModelSku.llm_completion, + quantity=r.usage.completion_tokens, + ) + return [choice.text for choice in r.choices] @@ -761,15 +783,13 @@ def _run_together_chat( range(num_outputs), ) ret = [] - total_out_tokens = 0 - total_in_tokens = 0 + prompt_tokens = 0 + completion_tokens = 0 for r in results: raise_for_status(r) data = r.json() output = data["output"] error = output.get("error") - total_out_tokens += output.get("usage", {}).get("completion_tokens", 0) - total_in_tokens += output.get("usage", {}).get("prompt_tokens", 0) if error: raise ValueError(error) ret.append( @@ -778,6 +798,21 @@ def _run_together_chat( "content": output["choices"][0]["text"], } ) + prompt_tokens += output.get("usage", {}).get("prompt_tokens", 0) + completion_tokens += output.get("usage", {}).get("completion_tokens", 0) + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=model, + sku=ModelSku.llm_prompt, + quantity=prompt_tokens, + ) + record_cost_auto( + model=model, + sku=ModelSku.llm_completion, + quantity=completion_tokens, + ) return ret @@ -828,16 +863,28 @@ def _run_palm_chat( }, ) raise_for_status(r) - - r.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"] - r.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"] + out = r.json() + + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + + record_cost_auto( + model=model_id, + sku=ModelSku.llm_prompt, + quantity=out["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) + record_cost_auto( + model=model_id, + sku=ModelSku.llm_completion, + quantity=out["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + ) return [ { "role": msg["author"], "content": msg["content"], } - for pred in r.json()["predictions"] + for pred in out["predictions"] for msg in pred["candidates"] ] @@ -876,11 +923,23 @@ def _run_palm_text( }, ) raise_for_status(res) + out = res.json() + + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku - res.json()["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"] - res.json()["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"] + record_cost_auto( + model=model_id, + sku=ModelSku.llm_prompt, + quantity=out["metadata"]["tokenMetadata"]["inputTokenCount"]["totalTokens"], + ) + record_cost_auto( + model=model_id, + sku=ModelSku.llm_completion, + quantity=out["metadata"]["tokenMetadata"]["outputTokenCount"]["totalTokens"], + ) - return [prediction["content"] for prediction in res.json()["predictions"]] + return [prediction["content"] for prediction in out["predictions"]] def format_chatml_message(entry: ConversationEntry) -> str: diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 9df4d7673..deda81500 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -55,7 +55,7 @@ "files", "url_shortener", "glossary_resources", - "costs", + "usage_costs", ] MIDDLEWARE = [ diff --git a/scripts/init_llm_pricing.py b/scripts/init_llm_pricing.py new file mode 100644 index 000000000..a1f55e964 --- /dev/null +++ b/scripts/init_llm_pricing.py @@ -0,0 +1,338 @@ +from daras_ai_v2.language_model import LargeLanguageModels +from usage_costs.models import ModelSku, ModelCategory, ModelProvider, ModelPricing + + +def run(): + category = ModelCategory.LLM + + # GPT-4-Turbo + + for model in ["gpt-4-0125-preview", "gpt-4-1106-preview"]: + ModelPricing.objects.create( + model_id=model, + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id=model, + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-4-Turbo-Vision + + ModelPricing.objects.create( + model_id="gpt-4-vision-preview", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-4-vision-preview", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-vision-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_prompt, + unit_cost=0.01, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-turbo-vision-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_vision.name, + sku=ModelSku.llm_completion, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-4 + + ModelPricing.objects.create( + model_id="gpt-4", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_prompt, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-4", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_completion, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-prod-ca-1", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_prompt, + unit_cost=0.03, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-prod-ca-1", + model_name=LargeLanguageModels.gpt_4.name, + sku=ModelSku.llm_completion, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-4-32k + + ModelPricing.objects.create( + model_id="gpt-4-32k", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-4-32k", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_completion, + unit_cost=0.12, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-4-32k-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.06, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-4-32k-prod-ca-1", + model_name=LargeLanguageModels.gpt_4_32k.name, + sku=ModelSku.llm_completion, + unit_cost=0.12, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-3.5-Turbo + + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.0015, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.002, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_prompt, + unit_cost=0.0015, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo.name, + sku=ModelSku.llm_completion, + unit_cost=0.002, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # GPT-3.5-Turbo-16k + + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-16k-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.003, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + ModelPricing.objects.create( + model_id="gpt-3.5-turbo-16k-0613", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_completion, + unit_cost=0.004, + unit_quantity=1000, + category=category, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-16k-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_prompt, + unit_cost=0.003, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + ModelPricing.objects.create( + model_id="openai-gpt-35-turbo-16k-prod-ca-1", + model_name=LargeLanguageModels.gpt_3_5_turbo_16k.name, + sku=ModelSku.llm_completion, + unit_cost=0.004, + unit_quantity=1000, + category=category, + provider=ModelProvider.azure_openai, + pricing_url="https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/", + ) + + # Palm2 + + ModelPricing.objects.create( + model_id="text-bison", + model_name=LargeLanguageModels.palm2_text.name, + sku=ModelSku.llm_prompt, + unit_cost=0.00025, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + ModelPricing.objects.create( + model_id="text-bison", + model_name=LargeLanguageModels.palm2_text.name, + sku=ModelSku.llm_completion, + unit_cost=0.0005, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + + ModelPricing.objects.create( + model_id="chat-bison", + model_name=LargeLanguageModels.palm2_chat.name, + sku=ModelSku.llm_prompt, + unit_cost=0.00025, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + ModelPricing.objects.create( + model_id="chat-bison", + model_name=LargeLanguageModels.palm2_chat.name, + sku=ModelSku.llm_completion, + unit_cost=0.0005, + unit_quantity=1000, + category=category, + provider=ModelProvider.google, + pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + ) + + # Llama2 + + ModelPricing.objects.create( + model_id="togethercomputer/llama-2-70b-chat", + model_name=LargeLanguageModels.llama2_70b_chat.name, + sku=ModelSku.llm_prompt, + unit_cost=0.9, + unit_quantity=10**6, + category=category, + provider=ModelProvider.together_ai, + pricing_url="https://www.together.ai/pricing", + ) + ModelPricing.objects.create( + model_id="togethercomputer/llama-2-70b-chat", + model_name=LargeLanguageModels.llama2_70b_chat.name, + sku=ModelSku.llm_completion, + unit_cost=0.9, + unit_quantity=10**6, + category=category, + provider=ModelProvider.together_ai, + pricing_url="https://www.together.ai/pricing", + ) diff --git a/costs/__init__.py b/usage_costs/__init__.py similarity index 100% rename from costs/__init__.py rename to usage_costs/__init__.py diff --git a/usage_costs/admin.py b/usage_costs/admin.py new file mode 100644 index 000000000..611624ab1 --- /dev/null +++ b/usage_costs/admin.py @@ -0,0 +1,75 @@ +from decimal import Decimal + +from django.contrib import admin + +from bots.admin_links import open_in_new_tab, change_obj_url +from usage_costs import models + + +class CostQtyMixin: + @admin.display(description="Cost / Qty", ordering="unit_cost") + def cost_qty(self, obj): + return f"${obj.unit_cost.normalize()} / {obj.unit_quantity}" + + +@admin.register(models.UsageCost) +class UsageCostAdmin(admin.ModelAdmin, CostQtyMixin): + list_display = [ + "__str__", + "cost_qty", + "quantity", + "display_dollar_amount", + "view_pricing", + "view_saved_run", + "notes", + "created_at", + ] + autocomplete_fields = ["saved_run", "pricing"] + list_filter = [ + "saved_run__workflow", + "pricing__category", + "pricing__provider", + "pricing__model_name", + "pricing__sku", + "created_at", + ] + search_fields = ["saved_run", "pricing"] + readonly_fields = ["created_at"] + ordering = ["-created_at"] + + @admin.display(description="Amount", ordering="dollar_amount") + def display_dollar_amount(self, obj): + return f"${obj.dollar_amount.normalize()}" + + @admin.display(description="Saved Run", ordering="saved_run") + def view_saved_run(self, obj): + return change_obj_url( + obj.saved_run, + label=f"{obj.saved_run.get_workflow_display()} - {obj.saved_run} ({obj.saved_run.run_id})", + ) + + @admin.display(description="Pricing", ordering="pricing") + def view_pricing(self, obj): + return change_obj_url(obj.pricing) + + +@admin.register(models.ModelPricing) +class ModelPricingAdmin(admin.ModelAdmin, CostQtyMixin): + list_display = [ + "__str__", + "cost_qty", + "model_id", + "category", + "provider", + "model_name", + "sku", + "view_pricing_url", + "notes", + ] + list_filter = ["category", "provider", "model_name", "sku"] + search_fields = ["category", "provider", "model_name", "sku", "model_id"] + readonly_fields = ["created_at", "updated_at"] + + @admin.display(description="Pricing URL", ordering="pricing_url") + def view_pricing_url(self, obj): + return open_in_new_tab(obj.pricing_url, label=obj.pricing_url) diff --git a/costs/apps.py b/usage_costs/apps.py similarity index 67% rename from costs/apps.py rename to usage_costs/apps.py index 55bed0ca1..b807e363b 100644 --- a/costs/apps.py +++ b/usage_costs/apps.py @@ -3,5 +3,5 @@ class CostsConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "costs" - verbose_name = "Costs" + name = "usage_costs" + verbose_name = "Usage Costs" diff --git a/usage_costs/cost_utils.py b/usage_costs/cost_utils.py new file mode 100644 index 000000000..ea622e06a --- /dev/null +++ b/usage_costs/cost_utils.py @@ -0,0 +1,34 @@ +from loguru import logger + +from daras_ai_v2.query_params import gooey_get_query_params +from daras_ai_v2.query_params_util import extract_query_params +from usage_costs.models import ( + UsageCost, + ModelSku, + ModelPricing, +) + + +def record_cost_auto(model: str, sku: ModelSku, quantity: int): + from bots.models import SavedRun + + _, run_id, uid = extract_query_params(gooey_get_query_params()) + if not run_id or not uid: + return + + try: + pricing = ModelPricing.objects.get(model_id=model, sku=sku) + except ModelPricing.DoesNotExist as e: + logger.warning(f"Cant find pricing for {model=} {sku=}: {e=}") + return + + saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) + + UsageCost.objects.create( + saved_run=saved_run, + pricing=pricing, + quantity=quantity, + unit_cost=pricing.unit_cost, + unit_quantity=pricing.unit_quantity, + dollar_amount=pricing.unit_cost * quantity / pricing.unit_quantity, + ) diff --git a/usage_costs/migrations/0001_initial.py b/usage_costs/migrations/0001_initial.py new file mode 100644 index 000000000..bb02723f7 --- /dev/null +++ b/usage_costs/migrations/0001_initial.py @@ -0,0 +1,56 @@ +# Generated by Django 4.2.7 on 2024-02-06 18:22 + +import bots.custom_fields +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('bots', '0057_message_response_time_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='ModelPricing', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('model_id', models.TextField(help_text='The model ID. Model ID + SKU should be unique together.')), + ('sku', models.IntegerField(choices=[(1, 'LLM Prompt'), (2, 'LLM Completion')], help_text="The model's SKU. Model ID + SKU should be unique together.")), + ('unit_cost', models.DecimalField(decimal_places=10, help_text='The cost per unit.', max_digits=15)), + ('unit_quantity', models.PositiveIntegerField(help_text='The quantity of the unit. (e.g. 1000 tokens)')), + ('category', models.IntegerField(choices=[(1, 'LLM')], help_text='The category of the model. Only used for Display purposes.')), + ('provider', models.IntegerField(choices=[(1, 'OpenAI'), (2, 'Google'), (3, 'TogetherAI'), (4, 'Azure OpenAI')], help_text='The provider of the model. Only used for Display purposes.')), + ('model_name', models.CharField(choices=[('gpt_4_vision', 'GPT-4 Vision (openai)'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai)'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('llama2_70b_chat', 'Llama 2 (Meta AI)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 (openai)'), ('text_curie_001', 'Curie (openai)'), ('text_babbage_001', 'Babbage (openai)'), ('text_ada_001', 'Ada (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)')], help_text='The name of the model. Only used for Display purposes.', max_length=255)), + ('notes', models.TextField(blank=True, default='', help_text='Any notes about the pricing. (e.g. how the pricing was calculated)')), + ('pricing_url', bots.custom_fields.CustomURLField(blank=True, default='', help_text='The URL of the pricing.', max_length=2048)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ], + ), + migrations.CreateModel( + name='UsageCost', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('quantity', models.PositiveIntegerField()), + ('unit_cost', models.DecimalField(decimal_places=10, help_text='The cost per unit, recorded at the time of the usage.', max_digits=15)), + ('unit_quantity', models.PositiveIntegerField(help_text='The quantity of the unit (e.g. 1000 tokens), recorded at the time of the usage.')), + ('dollar_amount', models.DecimalField(decimal_places=10, help_text='The dollar amount, calculated as unit_cost x quantity / unit_quantity.', max_digits=15)), + ('notes', models.TextField(blank=True, default='')), + ('created_at', models.DateTimeField(auto_now_add=True, db_index=True)), + ('pricing', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='usage_costs', to='usage_costs.modelpricing')), + ('saved_run', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='usage_costs', to='bots.savedrun')), + ], + ), + migrations.AddIndex( + model_name='modelpricing', + index=models.Index(fields=['model_id', 'sku'], name='usage_costs_model_i_d7b80a_idx'), + ), + migrations.AlterUniqueTogether( + name='modelpricing', + unique_together={('model_id', 'sku')}, + ), + ] diff --git a/costs/migrations/__init__.py b/usage_costs/migrations/__init__.py similarity index 100% rename from costs/migrations/__init__.py rename to usage_costs/migrations/__init__.py diff --git a/usage_costs/models.py b/usage_costs/models.py new file mode 100644 index 000000000..67b5c69e0 --- /dev/null +++ b/usage_costs/models.py @@ -0,0 +1,114 @@ +from django.db import models + +from bots.custom_fields import CustomURLField + +max_digits = 15 +decimal_places = 10 + + +class UsageCost(models.Model): + saved_run = models.ForeignKey( + "bots.SavedRun", on_delete=models.CASCADE, related_name="usage_costs" + ) + pricing = models.ForeignKey( + "usage_costs.ModelPricing", on_delete=models.CASCADE, related_name="usage_costs" + ) + + quantity = models.PositiveIntegerField() + unit_cost = models.DecimalField( + max_digits=max_digits, + decimal_places=decimal_places, + help_text="The cost per unit, recorded at the time of the usage.", + ) + unit_quantity = models.PositiveIntegerField( + help_text="The quantity of the unit (e.g. 1000 tokens), recorded at the time of the usage." + ) + + dollar_amount = models.DecimalField( + max_digits=max_digits, + decimal_places=decimal_places, + help_text="The dollar amount, calculated as unit_cost x quantity / unit_quantity.", + ) + + notes = models.TextField(default="", blank=True) + + created_at = models.DateTimeField(auto_now_add=True, db_index=True) + + def __str__(self): + return f"{self.saved_run} - {self.pricing} - {self.quantity}" + + +class ModelCategory(models.IntegerChoices): + LLM = 1, "LLM" + + +class ModelProvider(models.IntegerChoices): + openai = 1, "OpenAI" + google = 2, "Google" + together_ai = 3, "TogetherAI" + azure_openai = 4, "Azure OpenAI" + + +def get_model_choices(): + from daras_ai_v2.language_model import LargeLanguageModels + + return [(api.name, api.value) for api in LargeLanguageModels] + + +class ModelSku(models.IntegerChoices): + llm_prompt = 1, "LLM Prompt" + llm_completion = 2, "LLM Completion" + + +class ModelPricing(models.Model): + model_id = models.TextField( + help_text="The model ID. Model ID + SKU should be unique together." + ) + sku = models.IntegerField( + choices=ModelSku.choices, + help_text="The model's SKU. Model ID + SKU should be unique together.", + ) + + unit_cost = models.DecimalField( + max_digits=max_digits, + decimal_places=decimal_places, + help_text="The cost per unit.", + ) + unit_quantity = models.PositiveIntegerField( + help_text="The quantity of the unit. (e.g. 1000 tokens)" + ) + + category = models.IntegerField( + choices=ModelCategory.choices, + help_text="The category of the model. Only used for Display purposes.", + ) + provider = models.IntegerField( + choices=ModelProvider.choices, + help_text="The provider of the model. Only used for Display purposes.", + ) + model_name = models.CharField( + max_length=255, + choices=get_model_choices(), + help_text="The name of the model. Only used for Display purposes.", + ) + + notes = models.TextField( + default="", + blank=True, + help_text="Any notes about the pricing. (e.g. how the pricing was calculated)", + ) + pricing_url = CustomURLField( + default="", blank=True, help_text="The URL of the pricing." + ) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + unique_together = ["model_id", "sku"] + indexes = [ + models.Index(fields=["model_id", "sku"]), + ] + + def __str__(self): + return f"{self.get_provider_display()} / {self.get_model_name_display()} / {self.get_sku_display()}" diff --git a/costs/tests.py b/usage_costs/tests.py similarity index 100% rename from costs/tests.py rename to usage_costs/tests.py diff --git a/costs/views.py b/usage_costs/views.py similarity index 100% rename from costs/views.py rename to usage_costs/views.py From 08d46ae8bd91f7efa45b5ea4d1ff6edfe0585f9a Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 00:06:33 +0530 Subject: [PATCH 41/58] fix gpt4-v crash --- daras_ai_v2/language_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index fa8b48d86..3d48b1068 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -680,7 +680,9 @@ def _stream_openai_chunked( record_cost_auto( model=used_model, sku=ModelSku.llm_prompt, - quantity=sum(default_length_function(entry["content"]) for entry in messages), + quantity=sum( + default_length_function(get_entry_text(entry)) for entry in messages + ), ) record_cost_auto( model=used_model, From 61b88c11253c6d46e3585477c0bb09c034fd2bb4 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 00:17:44 +0530 Subject: [PATCH 42/58] show usage cost in saved run admin --- bots/admin.py | 14 +++++++++++--- bots/admin_links.py | 3 +++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index 0b6b475cc..38424dc79 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -5,7 +5,7 @@ from django import forms from django.conf import settings from django.contrib import admin -from django.db.models import Max, Count, F +from django.db.models import Max, Count, F, Sum from django.template import loader from django.utils import dateformat from django.utils.safestring import mark_safe @@ -28,8 +28,6 @@ WorkflowMetadata, ) from bots.tasks import create_personal_channels_for_all_members -from daras_ai.image_input import truncate_text_words -from daras_ai_v2.base import BasePage from gooeysite.custom_actions import export_to_excel, export_to_csv from gooeysite.custom_filters import ( related_json_field_summary, @@ -278,6 +276,7 @@ class SavedRunAdmin(admin.ModelAdmin): "parent", "view_bots", "price", + "view_usage_cost", "transaction", "created_at", "updated_at", @@ -313,6 +312,15 @@ def view_parent_published_run(self, saved_run: SavedRun): pr = saved_run.parent_published_run() return pr and change_obj_url(pr) + @admin.display(description="Usage Costs") + def view_usage_cost(self, saved_run: SavedRun): + total_cost = saved_run.usage_costs.aggregate(total_cost=Sum("dollar_amount"))[ + "total_cost" + ] + return list_related_html_url( + saved_run.usage_costs, extra_label=f"${total_cost.normalize()}" + ) + @admin.register(PublishedRunVersion) class PublishedRunVersionAdmin(admin.ModelAdmin): diff --git a/bots/admin_links.py b/bots/admin_links.py index 94086e348..c06601953 100644 --- a/bots/admin_links.py +++ b/bots/admin_links.py @@ -35,6 +35,7 @@ def list_related_html_url( query_param: str = None, instance_id: int = None, show_add: bool = True, + extra_label: str = None, ) -> typing.Optional[str]: num = manager.all().count() @@ -60,6 +61,8 @@ def list_related_html_url( ).url label = f"{num} {meta.verbose_name if num == 1 else meta.verbose_name_plural}" + if extra_label: + label = f"{label} ({extra_label})" if show_add: add_related_url = furl( From 9899741ddbb425fee7424dbcc95e0e45aa086a86 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 00:22:17 +0530 Subject: [PATCH 43/58] increase run id size --- daras_ai_v2/crypto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daras_ai_v2/crypto.py b/daras_ai_v2/crypto.py index 30bef7f4f..ec0047bbd 100644 --- a/daras_ai_v2/crypto.py +++ b/daras_ai_v2/crypto.py @@ -65,7 +65,7 @@ def safe_preview(password: str) -> str: def get_random_doc_id() -> str: return get_random_string( - length=8, allowed_chars=string.ascii_lowercase + string.digits + length=12, allowed_chars=string.ascii_lowercase + string.digits ) From cc8c24ba00b0f8c125b9891dc10dea6d0f940ff1 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 21:54:08 +0530 Subject: [PATCH 44/58] record costs for all gpu tasks --- daras_ai_v2/gpu_server.py | 12 +++++- scripts/init_self_hosted_pricing.py | 2 + usage_costs/admin.py | 8 +++- ...02_alter_modelpricing_category_and_more.py | 38 +++++++++++++++++++ usage_costs/models.py | 12 +++++- 5 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 scripts/init_self_hosted_pricing.py create mode 100644 usage_costs/migrations/0002_alter_modelpricing_category_and_more.py diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index fa638da89..3050cd98b 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -2,6 +2,7 @@ import datetime import os import typing +from time import time import requests from furl import furl @@ -138,6 +139,7 @@ def get_celery(): _app = Celery() _app.conf.broker_url = settings.GPU_CELERY_BROKER_URL _app.conf.result_backend = settings.GPU_CELERY_RESULT_BACKEND + _app.conf.result_extended = True return _app @@ -149,8 +151,16 @@ def call_celery_task( inputs: dict, queue_prefix: str = "gooey-gpu", ): + from usage_costs.cost_utils import record_cost_auto + from usage_costs.models import ModelSku + queue = os.path.join(queue_prefix, pipeline["model_id"].strip()).strip("/") result = get_celery().send_task( task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue ) - return result.get(disable_sync_subtasks=False) + s = time() + ret = result.get(disable_sync_subtasks=False) + record_cost_auto( + model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000) + ) + return ret diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py new file mode 100644 index 000000000..d04d42367 --- /dev/null +++ b/scripts/init_self_hosted_pricing.py @@ -0,0 +1,2 @@ +def run(): + pass diff --git a/usage_costs/admin.py b/usage_costs/admin.py index 611624ab1..fbe367243 100644 --- a/usage_costs/admin.py +++ b/usage_costs/admin.py @@ -1,5 +1,3 @@ -from decimal import Decimal - from django.contrib import admin from bots.admin_links import open_in_new_tab, change_obj_url @@ -21,6 +19,7 @@ class UsageCostAdmin(admin.ModelAdmin, CostQtyMixin): "display_dollar_amount", "view_pricing", "view_saved_run", + "view_parent_published_run", "notes", "created_at", ] @@ -41,6 +40,11 @@ class UsageCostAdmin(admin.ModelAdmin, CostQtyMixin): def display_dollar_amount(self, obj): return f"${obj.dollar_amount.normalize()}" + @admin.display(description="Published Run") + def view_parent_published_run(self, obj): + pr = obj.saved_run.parent_published_run() + return pr and change_obj_url(pr) + @admin.display(description="Saved Run", ordering="saved_run") def view_saved_run(self, obj): return change_obj_url( diff --git a/usage_costs/migrations/0002_alter_modelpricing_category_and_more.py b/usage_costs/migrations/0002_alter_modelpricing_category_and_more.py new file mode 100644 index 000000000..6097175f6 --- /dev/null +++ b/usage_costs/migrations/0002_alter_modelpricing_category_and_more.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2.7 on 2024-02-07 15:44 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='category', + field=models.IntegerField(choices=[(1, 'LLM'), (2, 'Self-Hosted')], help_text='The category of the model. Only used for Display purposes.'), + ), + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_vision', 'GPT-4 Vision (openai)'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai)'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('llama2_70b_chat', 'Llama 2 (Meta AI)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 (openai)'), ('text_curie_001', 'Curie (openai)'), ('text_babbage_001', 'Babbage (openai)'), ('text_ada_001', 'Ada (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + migrations.AlterField( + model_name='modelpricing', + name='provider', + field=models.IntegerField(choices=[(1, 'OpenAI'), (2, 'Google'), (3, 'TogetherAI'), (4, 'Azure OpenAI'), (5, 'Azure Kubernetes Service')], help_text='The provider of the model. Only used for Display purposes.'), + ), + migrations.AlterField( + model_name='modelpricing', + name='sku', + field=models.IntegerField(choices=[(1, 'LLM Prompt'), (2, 'LLM Completion'), (3, 'GPU Milliseconds')], help_text="The model's SKU. Model ID + SKU should be unique together."), + ), + migrations.AlterField( + model_name='modelpricing', + name='unit_quantity', + field=models.PositiveIntegerField(default=1, help_text='The quantity of the unit. (e.g. 1000 tokens)'), + ), + ] diff --git a/usage_costs/models.py b/usage_costs/models.py index 67b5c69e0..3cdee247f 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -40,6 +40,7 @@ def __str__(self): class ModelCategory(models.IntegerChoices): LLM = 1, "LLM" + SELF_HOSTED = 2, "Self-Hosted" class ModelProvider(models.IntegerChoices): @@ -48,17 +49,24 @@ class ModelProvider(models.IntegerChoices): together_ai = 3, "TogetherAI" azure_openai = 4, "Azure OpenAI" + aks = 5, "Azure Kubernetes Service" + def get_model_choices(): from daras_ai_v2.language_model import LargeLanguageModels + from recipes.DeforumSD import AnimationModels - return [(api.name, api.value) for api in LargeLanguageModels] + return [(api.name, api.value) for api in LargeLanguageModels] + [ + (model.name, model.label) for model in AnimationModels + ] class ModelSku(models.IntegerChoices): llm_prompt = 1, "LLM Prompt" llm_completion = 2, "LLM Completion" + gpu_ms = 3, "GPU Milliseconds" + class ModelPricing(models.Model): model_id = models.TextField( @@ -75,7 +83,7 @@ class ModelPricing(models.Model): help_text="The cost per unit.", ) unit_quantity = models.PositiveIntegerField( - help_text="The quantity of the unit. (e.g. 1000 tokens)" + help_text="The quantity of the unit. (e.g. 1000 tokens)", default=1 ) category = models.IntegerField( From b89905ba0912abec84233f1be09d9f77fa0bb280 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 22:07:29 +0530 Subject: [PATCH 45/58] init self hosted pricing --- daras_ai_v2/gpu_server.py | 6 +++++- scripts/init_self_hosted_pricing.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index 3050cd98b..752879f8d 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -154,7 +154,7 @@ def call_celery_task( from usage_costs.cost_utils import record_cost_auto from usage_costs.models import ModelSku - queue = os.path.join(queue_prefix, pipeline["model_id"].strip()).strip("/") + queue = build_queue_name(queue_prefix, pipeline["model_id"]) result = get_celery().send_task( task_name, kwargs=dict(pipeline=pipeline, inputs=inputs), queue=queue ) @@ -164,3 +164,7 @@ def call_celery_task( model=queue, sku=ModelSku.gpu_ms, quantity=int((time() - s) * 1000) ) return ret + + +def build_queue_name(queue_prefix: str, model_id: str) -> str: + return os.path.join(queue_prefix, model_id.strip()).strip("/") diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index d04d42367..05947eb47 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -1,2 +1,29 @@ +from decimal import Decimal + +from daras_ai_v2.gpu_server import build_queue_name +from recipes.DeforumSD import AnimationModels +from usage_costs.models import ModelPricing +from usage_costs.models import ModelSku, ModelCategory, ModelProvider + +category = ModelCategory.SELF_HOSTED + + def run(): - pass + for model in AnimationModels: + add_model(model.value, model.name) + + +def add_model(model_id, model_name): + ModelPricing.objects.get_or_create( + model_id=build_queue_name("gooey-gpu", model_id), + sku=ModelSku.gpu_ms, + defaults=dict( + model_name=model_name, + unit_cost=Decimal("3.673"), + unit_quantity=3600000, + category=category, + provider=ModelProvider.aks, + notes="NC24ads A100 v4 - 1 X A100 - Pay as you go - $3.6730/hour", + pricing_url="https://azure.microsoft.com/en-in/pricing/details/virtual-machines/linux/#pricing", + ), + ) From f4f4c1fcb2833b966e87d7eb5413309060df058f Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 22:19:09 +0530 Subject: [PATCH 46/58] img2img/text2img/inpaint --- scripts/init_self_hosted_pricing.py | 20 +++++++++++++++++++ .../0003_alter_modelpricing_model_name.py | 18 +++++++++++++++++ usage_costs/models.py | 14 +++++++++---- 3 files changed, 48 insertions(+), 4 deletions(-) create mode 100644 usage_costs/migrations/0003_alter_modelpricing_model_name.py diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index 05947eb47..7bc16376a 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -1,6 +1,14 @@ from decimal import Decimal from daras_ai_v2.gpu_server import build_queue_name +from daras_ai_v2.stable_diffusion import ( + Text2ImgModels, + Img2ImgModels, + InpaintingModels, + text2img_model_ids, + img2img_model_ids, + inpaint_model_ids, +) from recipes.DeforumSD import AnimationModels from usage_costs.models import ModelPricing from usage_costs.models import ModelSku, ModelCategory, ModelProvider @@ -11,6 +19,18 @@ def run(): for model in AnimationModels: add_model(model.value, model.name) + for model_enum, model_ids in [ + (Text2ImgModels, text2img_model_ids), + (Img2ImgModels, img2img_model_ids), + (InpaintingModels, inpaint_model_ids), + ]: + for m in model_enum: + if "dall_e" in m.name: + continue + try: + add_model(model_ids[m], m.name) + except KeyError: + pass def add_model(model_id, model_name): diff --git a/usage_costs/migrations/0003_alter_modelpricing_model_name.py b/usage_costs/migrations/0003_alter_modelpricing_model_name.py new file mode 100644 index 000000000..a4cbe05fa --- /dev/null +++ b/usage_costs/migrations/0003_alter_modelpricing_model_name.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-07 16:48 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0002_alter_modelpricing_category_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_vision', 'GPT-4 Vision (openai)'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai)'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('llama2_70b_chat', 'Llama 2 (Meta AI)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 (openai)'), ('text_curie_001', 'Curie (openai)'), ('text_babbage_001', 'Babbage (openai)'), ('text_ada_001', 'Ada (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALLΒ·E 2 (OpenAI)'), ('dall_e_3', 'DALLΒ·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐒'), ('openjourney', 'Open Journey (PromptHero) 🐒'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐒'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐒'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐒'), ('openjourney', 'Open Journey (PromptHero) 🐒'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐒'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐒'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + ] diff --git a/usage_costs/models.py b/usage_costs/models.py index 3cdee247f..5a3bf5171 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -1,6 +1,7 @@ from django.db import models from bots.custom_fields import CustomURLField +from daras_ai_v2.stable_diffusion import InpaintingModels max_digits = 15 decimal_places = 10 @@ -55,10 +56,15 @@ class ModelProvider(models.IntegerChoices): def get_model_choices(): from daras_ai_v2.language_model import LargeLanguageModels from recipes.DeforumSD import AnimationModels - - return [(api.name, api.value) for api in LargeLanguageModels] + [ - (model.name, model.label) for model in AnimationModels - ] + from daras_ai_v2.stable_diffusion import Text2ImgModels, Img2ImgModels + + return ( + [(api.name, api.value) for api in LargeLanguageModels] + + [(model.name, model.label) for model in AnimationModels] + + [(model.name, model.value) for model in Text2ImgModels] + + [(model.name, model.value) for model in Img2ImgModels] + + [(model.name, model.value) for model in InpaintingModels] + ) class ModelSku(models.IntegerChoices): From b0d9125d4caa576e3626ee2a8f969b3b703bb7c0 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 7 Feb 2024 22:26:08 +0530 Subject: [PATCH 47/58] wav2lip pricing --- scripts/init_self_hosted_pricing.py | 1 + usage_costs/models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index 7bc16376a..6ea45c7d3 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -31,6 +31,7 @@ def run(): add_model(model_ids[m], m.name) except KeyError: pass + add_model("gooey-gpu/wav2lip_gan.pth", "wav2lip") def add_model(model_id, model_name): diff --git a/usage_costs/models.py b/usage_costs/models.py index 5a3bf5171..e3fce075f 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -64,6 +64,7 @@ def get_model_choices(): + [(model.name, model.value) for model in Text2ImgModels] + [(model.name, model.value) for model in Img2ImgModels] + [(model.name, model.value) for model in InpaintingModels] + + [("wav2lip", "LipSync (wav2lip)")] ) From 93d786da35d83a6b25ea97bf7fce3dd09b99a8ec Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 8 Feb 2024 18:39:52 +0530 Subject: [PATCH 48/58] expand gdrive folders in parallel avoid mutating the list that is being iterated --- daras_ai_v2/doc_search_settings_widgets.py | 26 +++++++++----- daras_ai_v2/gdrive_downloader.py | 41 +++++++++++----------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 39cc152db..f175cd911 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -1,10 +1,14 @@ import os import typing +from furl import furl +from sentry_sdk import capture_exception + import gooey_ui as st from daras_ai_v2 import settings from daras_ai_v2.asr import AsrModels, google_translate_language_selector from daras_ai_v2.enum_selector_widget import enum_selector +from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder from daras_ai_v2.search_ref import CitationStyles _user_media_url_prefix = os.path.join( @@ -76,20 +80,24 @@ def document_uploader( accept_multiple_files=accept_multiple_files, ) documents = st.session_state.get(key, []) - for document in documents: - if not document.startswith("https://drive.google.com/drive/folders"): - continue - from daras_ai_v2.gdrive_downloader import gdrive_list_urls_of_files_in_folder - from furl import furl - - folder_content_urls = gdrive_list_urls_of_files_in_folder(furl(document)) - documents.remove(document) - documents.extend(folder_content_urls) + try: + documents = list(_expand_gdrive_folders(documents)) + except Exception as e: + capture_exception(e) + st.error(f"Error expanding gdrive folders: {e}") st.session_state[key] = documents st.session_state[custom_key] = "\n".join(documents) return documents +def _expand_gdrive_folders(documents: list[str]) -> list[str]: + for url in documents: + if url.startswith("https://drive.google.com/drive/folders"): + yield from gdrive_list_urls_of_files_in_folder(furl(url)) + else: + yield url + + def doc_search_settings( asr_allowed: bool = False, keyword_instructions_allowed: bool = False, diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 345099aba..5ae9fa176 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -5,6 +5,8 @@ from googleapiclient import discovery from googleapiclient.http import MediaIoBaseDownload +from daras_ai_v2.functional import flatmap_parallel + def is_gdrive_url(f: furl) -> bool: return f.host in ["drive.google.com", "docs.google.com"] @@ -25,34 +27,33 @@ def url_to_gdrive_file_id(f: furl) -> str: return file_id -def gdrive_list_urls_of_files_in_folder(f: furl, max_depth=10) -> list[str]: +def gdrive_list_urls_of_files_in_folder(f: furl, max_depth: int = 4) -> list[str]: if max_depth <= 0: return [] + assert f.host == "drive.google.com", f"Bad google drive folder url: {f}" # get drive folder id from url (e.g. https://drive.google.com/drive/folders/1Xijcsj7oBvDn1OWx4UmNAT8POVKG4W73?usp=drive_link) folder_id = f.path.segments[-1] service = discovery.build("drive", "v3") - if f.host == "drive.google.com": - request = service.files().list( - supportsAllDrives=True, - includeItemsFromAllDrives=True, - q=f"'{folder_id}' in parents", - fields="files(mimeType,webViewLink)", - ) - else: - raise ValueError(f"Can't list files from non google folder url: {str(f)!r}") + request = service.files().list( + supportsAllDrives=True, + includeItemsFromAllDrives=True, + q=f"'{folder_id}' in parents", + fields="files(mimeType,webViewLink)", + ) response = request.execute() files = response.get("files", []) - urls = [] - for file in files: - mime_type = file.get("mimeType") - url = file.get("webViewLink") - if mime_type == "application/vnd.google-apps.folder": - urls += gdrive_list_urls_of_files_in_folder( - furl(url), max_depth=max_depth - 1 + urls = flatmap_parallel( + lambda file: ( + gdrive_list_urls_of_files_in_folder(furl(url), max_depth=max_depth - 1) + if ( + (url := file.get("webViewLink")) + and file.get("mimeType") == "application/vnd.google-apps.folder" ) - elif url: - urls.append(url) - return urls + else [url] + ), + files, + ) + return filter(None, urls) def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]: From f9723feef76382ee2713170e91b690112e41a864 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 9 Feb 2024 20:47:12 +0530 Subject: [PATCH 49/58] fix openai usage record for non streaming --- daras_ai_v2/language_model.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 3d48b1068..fd6142758 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -585,20 +585,9 @@ def _run_openai_chat( if stream: return _stream_openai_chunked(r, used_model, messages) else: - from usage_costs.cost_utils import record_cost_auto - from usage_costs.models import ModelSku - - record_cost_auto( - model=used_model, - sku=ModelSku.llm_prompt, - quantity=r.usage.prompt_tokens, - ) - record_cost_auto( - model=used_model, - sku=ModelSku.llm_completion, - quantity=r.usage.completion_tokens, - ) - return [choice.message.dict() for choice in r.choices] + ret = [choice.message.dict() for choice in r.choices] + record_openai_llm_usage(used_model, messages, ret) + return ret def _get_chat_completions_create(model: str, **kwargs): @@ -674,6 +663,12 @@ def _stream_openai_chunked( entry["content"] += entry["chunk"] yield ret + record_openai_llm_usage(used_model, messages, ret) + + +def record_openai_llm_usage( + used_model: str, messages: list[ConversationEntry], choices: list[ConversationEntry] +): from usage_costs.cost_utils import record_cost_auto from usage_costs.models import ModelSku @@ -687,7 +682,7 @@ def _stream_openai_chunked( record_cost_auto( model=used_model, sku=ModelSku.llm_completion, - quantity=sum(default_length_function(entry["content"]) for entry in ret), + quantity=sum(default_length_function(entry["content"]) for entry in choices), ) From d6481579f6e114dfaae25364efd3647e7eb067a6 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 8 Feb 2024 22:17:33 +0530 Subject: [PATCH 50/58] show custom urls in uppy rename Custom URLs -> Submit Links in Bulk --- daras_ai_v2/bot_integration_widgets.py | 3 +++ daras_ai_v2/doc_search_settings_widgets.py | 8 +++----- gooey_ui/components/__init__.py | 4 +--- recipes/QRCodeGenerator.py | 17 +++-------------- routers/root.py | 6 ++++++ 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 0790f0b43..73c6ccad0 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -19,6 +19,9 @@ def general_integration_settings(bi: BotIntegration): st.session_state[f"_bi_user_language_{bi.id}"] = BotIntegration._meta.get_field( "user_language" ).default + st.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( + BotIntegration._meta.get_field("streaming_enabled").default + ) st.session_state[f"_bi_show_feedback_buttons_{bi.id}"] = ( BotIntegration._meta.get_field("show_feedback_buttons").default ) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index f175cd911..a64572b99 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -40,11 +40,8 @@ def document_uploader( documents = st.session_state.get(key) or [] if isinstance(documents, str): documents = [documents] - has_custom_urls = not all(map(is_user_uploaded_url, documents)) custom_key = "__custom_" + key - if st.checkbox( - "Enter Custom URLs", key=f"__custom_checkbox_{key}", value=has_custom_urls - ): + if st.session_state.get(f"__custom_checkbox_{key}"): if not custom_key in st.session_state: st.session_state[custom_key] = "\n".join(documents) if accept_multiple_files: @@ -67,7 +64,7 @@ def document_uploader( **kwargs, ) if accept_multiple_files: - st.session_state[key] = text_value.strip().splitlines() + st.session_state[key] = filter(None, text_value.strip().splitlines()) else: st.session_state[key] = text_value else: @@ -79,6 +76,7 @@ def document_uploader( accept=accept, accept_multiple_files=accept_multiple_files, ) + st.checkbox("Submit Links in Bulk", key=f"__custom_checkbox_{key}") documents = st.session_state.get(key, []) try: documents = list(_expand_gdrive_folders(documents)) diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 408b950fb..0814d65e9 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -293,9 +293,7 @@ def text_area( # assert not value, "only one of value or key can be provided" # else: if not key: - key = md5_values( - "textarea", label, height, help, value, placeholder, label_visibility - ) + key = md5_values("textarea", label, height, help, placeholder, label_visibility) value = str(state.session_state.setdefault(key, value) or "") if label_visibility != "visible": label = None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 34b615225..741149afd 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -566,20 +566,9 @@ def vcard_form(*, key: str) -> VCARD: st.error("No contact info found for that email") else: vcard = imported_vcard - # clear inputs - st.js( - # language=js - """ - const form = document.getElementById("gooey-form"); - if (!form) return; - Object.entries(fields).forEach(([k, v]) => { - const field = form["__vcard_data__" + k]; - if (!field) return; - field.value = v; - }); - """, - fields=vcard.dict(), - ) + # update inputs + for k, v in vcard.dict().items(): + st.session_state[f"__vcard_data__{k}"] = v vcard.format_name = st.text_input( "Name*", diff --git a/routers/root.py b/routers/root.py index 6f854ba20..e5f27edc6 100644 --- a/routers/root.py +++ b/routers/root.py @@ -31,6 +31,7 @@ RedirectException, get_example_request_body, ) +from daras_ai_v2.bots import request_json from daras_ai_v2.copy_to_clipboard_button_widget import copy_to_clipboard_scripts from daras_ai_v2.db import FIREBASE_SESSION_COOKIE from daras_ai_v2.manage_api_keys_widget import manage_api_keys @@ -165,6 +166,11 @@ async def logout(request: Request): return RedirectResponse(request.query_params.get("next", DEFAULT_LOGOUT_REDIRECT)) +@app.post("/__/file-upload/url/meta") +async def file_upload(request: Request, body_json: dict = Depends(request_json)): + return dict(name=(body_json["url"]), type="url/undefined", size=None) + + @app.post("/__/file-upload/") def file_upload(request: Request, form_data: FormData = Depends(request_form_files)): from wand.image import Image From 25903559018aea9523e6e5998ede5aa048a6cacc Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:33:12 -0800 Subject: [PATCH 51/58] Added download buttons --- daras_ai_v2/base.py | 47 ++++++++++++++++--------- gooey_ui/components/__init__.py | 37 ++++++++++++++++++++ recipes/CompareText2Img.py | 1 + recipes/CompareUpscaler.py | 1 + recipes/DeforumSD.py | 3 +- recipes/FaceInpainting.py | 1 + recipes/GoogleImageGen.py | 1 + recipes/ImageSegmentation.py | 1 + recipes/Img2Img.py | 9 ++--- recipes/Lipsync.py | 62 +++++++++++++++++++-------------- recipes/LipsyncTTS.py | 37 ++++---------------- recipes/ObjectInpainting.py | 1 + recipes/QRCodeGenerator.py | 5 +-- recipes/Text2Audio.py | 5 +-- recipes/TextToSpeech.py | 1 + recipes/VideoBots.py | 4 +-- 16 files changed, 131 insertions(+), 85 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index f05a0c71c..8dab30806 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1264,7 +1264,9 @@ def _render_report_button(self): if not (self.request.user and run_id and uid): return - reported = st.button("❗Report") + reported = st.button( + ' Report', type="tertiary" + ) if not reported: return @@ -1351,7 +1353,6 @@ def _render_output_col(self, submitted: bool): def _render_completed_output(self): run_time = st.session_state.get(StateKeys.run_time, 0) - st.success(f"Success! Run Time: `{run_time:.2f}` seconds.") def _render_failed_output(self): err_msg = st.session_state.get(StateKeys.error_msg) @@ -1513,22 +1514,36 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): - col1, col2, col3 = st.columns([1, 1, 1], responsive=False) - col2.node.props[ - "className" - ] += " d-flex justify-content-center align-items-center" - col3.node.props["className"] += " d-flex justify-content-end align-items-center" + caption = "" + caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s' if "seed" in self.RequestModel.schema_json(): seed = st.session_state.get("seed") - with col1: - st.caption(f"*Seed\\\n`{seed}`*") - with col2: - randomize = st.button("♻️ Regenerate") - if randomize: - st.session_state[StateKeys.pressed_randomize] = True - st.experimental_rerun() - with col3: - self._render_report_button() + caption += f' with seed {seed} ' + try: + format_created_at = st.session_state.get( + "created_at", datetime.datetime.today() + ).strftime("%d %b %Y %-I:%M%p") + except: + format_created_at = format_created_at = st.session_state.get( + "created_at", datetime.datetime.today() + ) + caption += f' at {format_created_at}' + st.caption(caption, unsafe_allow_html=True) + + def render_buttons(self, url: str): + st.download_button( + label=' Download', + url=url, + type="secondary", + ) + if "seed" in self.RequestModel.schema_json(): + randomize = st.button( + ' Regenerate', type="tertiary" + ) + if randomize: + st.session_state[StateKeys.pressed_randomize] = True + st.experimental_rerun() + self._render_report_button() def state_to_doc(self, state: dict): ret = { diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 0814d65e9..7e435b725 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -453,6 +453,43 @@ def button( form_submit_button = button +def download_button( + label: str, + url: str, + key: str = None, + help: str = None, + *, + type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", + disabled: bool = False, + **props, +) -> bool: + """ + Example: + st.button("Primary", key="test0", type="primary") + st.button("Secondary", key="test1") + st.button("Tertiary", key="test3", type="tertiary") + st.button("Link Button", key="test3", type="link") + """ + if not key: + key = md5_values("button", label, help, type, props) + className = f"btn-{type} " + props.pop("className", "") + state.RenderTreeNode( + name="download-button", + props=dict( + type="submit", + value="yes", + url=url, + name=key, + label=dedent(label), + help=help, + disabled=disabled, + className=className, + **props, + ), + ).mount() + return bool(state.session_state.pop(key, False)) + + def expander(label: str, *, expanded: bool = False, **props): node = state.RenderTreeNode( name="expander", diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 5f44ca34b..79fd76273 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -247,6 +247,7 @@ def _render_outputs(self, state): output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: st.image(img, caption=Text2ImgModels[key].value) + self.render_buttons(img) def preview_description(self, state: dict) -> str: return "Create multiple AI photos from one prompt using Stable Diffusion (1.5 -> 2.1, Open/Midjourney), DallE, and other models. Find out which AI Image generator works best for your text prompt on comparing OpenAI, Stability.AI etc." diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index e4ed8865f..3206f532d 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -108,6 +108,7 @@ def _render_outputs(self, state): if not img: continue st.image(img, caption=UpscalerModels[key].value) + self.render_buttons(img) def get_raw_price(self, state: dict) -> int: selected_models = state.get("selected_models", []) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index 57b5bb9f7..d971e708c 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -423,8 +423,9 @@ def render_description(self): def render_output(self): output_video = st.session_state.get("output_video") if output_video: - st.write("Output Video") + st.write("#### Output Video") st.video(output_video, autoplay=True) + self.render_buttons(output_video) def estimate_run_duration(self): # in seconds diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index ee377b53d..3c14a8147 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -201,6 +201,7 @@ def render_output(self): url, caption="```" + text_prompt.replace("\n", "") + "```", ) + self.render_buttons(url) else: st.div() diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index c3dab52f0..f49657346 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -207,6 +207,7 @@ def render_output(self): if out_imgs: for img in out_imgs: st.image(img, caption="Generated Image") + self.render_buttons(img) else: st.div() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 751a0936b..0f5eacbee 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -343,6 +343,7 @@ def render_example(self, state: dict): input_image = state.get("input_image") if input_image: st.image(input_image, caption="Input Photo") + self.render_buttons(input_image) else: st.div() diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 68e696c08..7334ebb31 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -91,7 +91,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.file_uploader( """ - ### Input Image + #### Input Image """, key="input_image", upload_meta=dict(resize=f"{SD_IMG_MAX_SIZE[0] * SD_IMG_MAX_SIZE[1]}@>"), @@ -100,7 +100,7 @@ def render_form_v2(self): if st.session_state.get("selected_model") != InpaintingModels.dall_e.name: st.text_area( """ - ### Prompt + #### Prompt Describe your edits """, key="text_prompt", @@ -129,9 +129,10 @@ def render_usage_guide(self): def render_output(self): text_prompt = st.session_state.get("text_prompt", "") output_images = st.session_state.get("output_images", []) - + st.write("#### Output Image") for img in output_images: - st.image(img, caption="```" + text_prompt.replace("\n", "") + "```") + st.image(img) + self.render_buttons(img) def render_example(self, state: dict): col1, col2 = st.columns(2) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 5fe08a09b..8ea341aad 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -82,33 +82,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_example(self, state: dict): - col1, col2 = st.columns(2) - - with col1: - input_face = state.get("input_face") - if not input_face: - st.div() - elif input_face.endswith(".mp4") or input_face.endswith(".mov"): - st.write("Input Face (Video)") - st.video(input_face) - else: - st.write("Input Face (Image)") - st.image(input_face) - - input_audio = state.get("input_audio") - if input_audio: - st.write("Input Audio") - st.audio(input_audio) - else: - st.div() - - with col2: - output_video = state.get("output_video") - if output_video: - st.write("Output Video") - st.video(output_video, autoplay=True) - else: - st.div() + output_video = state.get("output_video") + if output_video: + st.write("#### Output Video") + st.video(output_video, autoplay=True) + self.render_buttons(output_video) + else: + st.div() def render_output(self): self.render_example(st.session_state) @@ -145,3 +125,31 @@ def get_raw_price(self, state: dict) -> float: total_mb = total_bytes / 1024 / 1024 return total_mb * CREDITS_PER_MB + + def download_blob(self, bucket_name, source_blob_name, destination_file_name): + """Downloads a blob from the bucket.""" + # The ID of your GCS bucket + # bucket_name = "your-bucket-name" + + # The ID of your GCS object + # source_blob_name = "storage-object-name" + + # The path to which the file should be downloaded + # destination_file_name = "local/path/to/file" + + storage_client = storage.Client() + + bucket = storage_client.bucket(bucket_name) + + # Construct a client side representation of a blob. + # Note `Bucket.blob` differs from `Bucket.get_blob` as it doesn't retrieve + # any content from Google Cloud Storage. As we don't need additional data, + # using `Bucket.blob` is preferred here. + blob = bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + + print( + "Downloaded storage object {} from bucket {} to local file {}.".format( + source_blob_name, bucket_name, destination_file_name + ) + ) diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 60886c017..d463d7483 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -127,37 +127,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield from LipsyncPage.run(self, state) def render_example(self, state: dict): - col1, col2 = st.columns(2) - - with col1: - input_face = state.get("input_face") - if not input_face: - pass - elif input_face.endswith(".mp4") or input_face.endswith(".mov"): - st.video(input_face, caption="Input Face (Video)") - else: - st.image(input_face, caption="Input Face (Image)") - - input_text = state.get("text_prompt") - if input_text: - st.write("**Input Text**") - st.write(input_text) - else: - st.div() - - # input_audio = state.get("input_audio") - # if input_audio: - # st.write("Synthesized Voice") - # st.audio(input_audio) - # else: - # st.empty() - - with col2: - output_video = state.get("output_video") - if output_video: - st.video(output_video, caption="Output Video", autoplay=True) - else: - st.div() + output_video = state.get("output_video") + if output_video: + st.video(output_video, caption="#### Output Video", autoplay=True) + self.render_buttons(output_video) + else: + st.div() def render_output(self): self.render_example(st.session_state) diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index d84c59069..4d7f7342b 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -202,6 +202,7 @@ def render_output(self): if output_images: for url in output_images: st.image(url, caption=f"{text_prompt}") + self.render_buttons(url) else: st.div() diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 741149afd..1fe283825 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -144,7 +144,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.text_area( """ - ##### πŸ‘©β€πŸ’» Prompt + #### πŸ‘©β€πŸ’» Prompt Describe the subject/scene of the QR Code. Choose clear prompts and distinguishable visuals to ensure optimal readability. """, @@ -208,7 +208,7 @@ def render_form_v2(self): st.file_uploader( """ - ##### 🏞️ Reference Image *[optional]* + #### 🏞️ Reference Image *[optional]* This image will be used as inspiration to blend with the QR Code. """, key="image_prompt", @@ -458,6 +458,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None): st.caption(f"{shortened_url} β†’ {qr_code_data} (Views: {clicks})") else: st.caption(f"{shortened_url} β†’ {qr_code_data}") + self.render_buttons(img) def run(self, state: dict) -> typing.Iterator[str | None]: request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state) diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 18270cc8c..0e0991c35 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -128,7 +128,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_output(self): - _render_output(st.session_state) + _render_output(self, st.session_state) def render_example(self, state: dict): col1, col2 = st.columns(2) @@ -141,9 +141,10 @@ def preview_description(self, state: dict) -> str: return "Generate AI Music with text instruction prompts. AudiLDM is capable of generating realistic audio samples by process any text input. Learn more [here](https://huggingface.co/cvssp/audioldm-m-full)." -def _render_output(state): +def _render_output(self, state): selected_models = state.get("selected_models", []) for key in selected_models: output: dict = state.get("output_audios", {}).get(key, []) for audio in output: st.audio(audio, caption=Text2AudioModels[key].value) + self.render_buttons(audio) diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index 721b2cb7a..de13bfbc0 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -135,6 +135,7 @@ def render_output(self): audio_url = st.session_state.get("audio_url") if audio_url: st.audio(audio_url) + self.render_buttons(audio_url) else: st.div() diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 5d23502ff..2d0bc26e0 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -312,7 +312,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ##### πŸ“ Prompt + #### πŸ“ Prompt High-level system instructions. """, key="bot_script", @@ -321,7 +321,7 @@ def render_form_v2(self): document_uploader( """ -##### πŸ“„ Documents (*optional*) +#### πŸ“„ Documents (*optional*) Upload documents or enter URLs to give your copilot a knowledge base. With each incoming user message, we'll search your documents via a vector DB query. """ ) From cd991b5664014b6d238e01419cf69d1f969a36ad Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:48:19 -0800 Subject: [PATCH 52/58] Fixed headings --- daras_ai_v2/img_model_settings_widgets.py | 2 +- daras_ai_v2/prompt_vars.py | 2 +- recipes/CompareText2Img.py | 2 +- recipes/CompareUpscaler.py | 2 +- recipes/DocExtract.py | 4 ++-- recipes/DocSearch.py | 6 +++--- recipes/DocSummary.py | 4 ++-- recipes/EmailFaceInpainting.py | 4 ++-- recipes/FaceInpainting.py | 10 ++++------ recipes/GoogleGPT.py | 2 +- recipes/GoogleImageGen.py | 6 +++--- recipes/ImageSegmentation.py | 2 +- recipes/ObjectInpainting.py | 4 ++-- recipes/SEOSummary.py | 4 ++-- recipes/SocialLookupEmail.py | 6 +++--- recipes/Text2Audio.py | 2 +- recipes/TextToSpeech.py | 2 +- recipes/asr.py | 4 ++-- 18 files changed, 33 insertions(+), 35 deletions(-) diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py index e86c18873..8249833a0 100644 --- a/daras_ai_v2/img_model_settings_widgets.py +++ b/daras_ai_v2/img_model_settings_widgets.py @@ -83,7 +83,7 @@ def model_selector( selected_model = enum_selector( models_enum, label=""" - ### πŸ€– Choose your preferred AI Model + #### πŸ€– Choose your preferred AI Model """, key="selected_model", use_selectbox=True, diff --git a/daras_ai_v2/prompt_vars.py b/daras_ai_v2/prompt_vars.py index 80f13d906..fabb3928c 100644 --- a/daras_ai_v2/prompt_vars.py +++ b/daras_ai_v2/prompt_vars.py @@ -28,7 +28,7 @@ def prompt_vars_widget(*keys: str, variables_key: str = "variables"): if not (template_vars or err): return - st.write("##### βŒ₯ Variables") + st.write("#### βŒ₯ Variables") old_state = st.session_state.get(variables_key, {}) new_state = {} for name in sorted(template_vars): diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 79fd76273..47347bb65 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -96,7 +96,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.text_area( """ - ### πŸ‘©β€πŸ’» Prompt + #### πŸ‘©β€πŸ’» Prompt Describe the scene that you'd like to generate. """, key="text_prompt", diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 3206f532d..4aad4ffb2 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -34,7 +34,7 @@ class ResponseModel(BaseModel): def render_form_v2(self): st.file_uploader( """ - ### Input Image + #### Input Image """, key="input_image", upload_meta=dict(resize=f"{SD_IMG_MAX_SIZE[0] * SD_IMG_MAX_SIZE[1]}@>"), diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 2672fffdb..ade752624 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -94,11 +94,11 @@ def preview_image(self, state: dict) -> str | None: def render_form_v2(self): document_uploader( - "##### πŸ€– Youtube URLS", + "#### πŸ€– Youtube URLS", accept=("audio/*", "application/pdf", "video/*"), ) st.text_input( - "##### πŸ“Š Google Sheets URL", + "#### πŸ“Š Google Sheets URL", key="sheet_url", ) diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 0b95464e1..6d7977c25 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -82,8 +82,8 @@ class ResponseModel(BaseModel): final_search_query: str | None def render_form_v2(self): - st.text_area("##### Search Query", key="search_query") - document_uploader("##### Documents") + st.text_area("#### Search Query", key="search_query") + document_uploader("#### Documents") prompt_vars_widget("task_instructions", "query_instructions") def validate_form_v2(self): @@ -111,7 +111,7 @@ def render_output(self): def render_example(self, state: dict): render_documents(state) - st.write("**Search Query**") + st.html("**Search Query**") st.write("```properties\n" + state.get("search_query", "") + "\n```") render_output_with_refs(state, 200) diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 498f2d405..476bd5da6 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -84,8 +84,8 @@ def preview_image(self, state: dict) -> str | None: return DEFAULT_DOC_SUMMARY_META_IMG def render_form_v2(self): - document_uploader("##### πŸ“Ž Documents") - st.text_area("##### πŸ‘©β€πŸ’» Instructions", key="task_instructions", height=150) + document_uploader("#### πŸ“Ž Documents") + st.text_area("#### πŸ‘©β€πŸ’» Instructions", key="task_instructions", height=150) def render_settings(self): st.text_area( diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index 74749bdb5..4b8b50d1e 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -115,7 +115,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Describe the scene that you'd like to generate around the face. """, key="text_prompt", @@ -130,7 +130,7 @@ def render_form_v2(self): source = st.radio( """ - ### Photo Source + #### Photo Source From where we should get the photo?""", options=["Email Address", "Twitter Handle"], key="__photo_source", diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index 3c14a8147..f6a550d48 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -105,7 +105,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Describe the character that you'd like to generate. """, key="text_prompt", @@ -114,7 +114,7 @@ def render_form_v2(self): st.file_uploader( """ - ### Face Photo + #### Face Photo Give us a photo of yourself, or anyone else """, key="input_image", @@ -196,11 +196,9 @@ def render_output(self): output_images = st.session_state.get("output_images") if output_images: + st.write("#### Output Image") for url in output_images: - st.image( - url, - caption="```" + text_prompt.replace("\n", "") + "```", - ) + st.image(url) self.render_buttons(url) else: st.div() diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 7a2d2591f..35d650acb 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -111,7 +111,7 @@ class ResponseModel(BaseModel): final_search_query: str | None def render_form_v2(self): - st.text_area("##### Google Search Query", key="search_query") + st.text_area("#### Google Search Query", key="search_query") st.text_input("Search on a specific site *(optional)*", key="site_filter") prompt_vars_widget("task_instructions", "query_instructions") diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index f49657346..fd5d2db80 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -180,7 +180,7 @@ def run(self, state: dict): def render_form_v2(self): st.text_input( """ - ### πŸ”Ž Google Image Search + #### πŸ”Ž Google Image Search Type a query you'd use in [Google image search](https://images.google.com/?gws_rd=ssl) """, key="search_query", @@ -188,7 +188,7 @@ def render_form_v2(self): model_selector(Img2ImgModels) st.text_area( """ - ### πŸ‘©β€πŸ’» Prompt + #### πŸ‘©β€πŸ’» Prompt Describe how you want to edit the photo in words """, key="text_prompt", @@ -206,7 +206,7 @@ def render_output(self): out_imgs = st.session_state.get("output_images") if out_imgs: for img in out_imgs: - st.image(img, caption="Generated Image") + st.image(img, caption="#### Generated Image") self.render_buttons(img) else: st.div() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 0f5eacbee..39713cb2b 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -86,7 +86,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.file_uploader( """ - ### Input Photo + #### Input Photo Give us a photo of anything """, key="input_image", diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index 4d7f7342b..c2c3b29c4 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -95,7 +95,7 @@ def related_workflows(self) -> list: def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Describe the scene that you'd like to generate. """, key="text_prompt", @@ -104,7 +104,7 @@ def render_form_v2(self): st.file_uploader( """ - ### Object Photo + #### Object Photo Give us a photo of anything """, key="input_image", diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index 78328d822..79d8092a8 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -160,7 +160,7 @@ def render_description(self): ) def render_form_v2(self): - st.write("### Inputs") + st.write("#### Inputs") st.text_input("Google Search Query", key="search_query") st.text_input("Website Name", key="title") st.text_input("Website URL", key="company_url") @@ -192,7 +192,7 @@ def render_settings(self): def render_output(self): output_content = st.session_state.get("output_content") if output_content: - st.write("### Generated Content") + st.write("#### Generated Content") for idx, text in enumerate(output_content): if st.session_state.get("enable_html"): scrollable_html(text) diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py index 1094cea53..b2518b1c8 100644 --- a/recipes/SocialLookupEmail.py +++ b/recipes/SocialLookupEmail.py @@ -120,7 +120,7 @@ def render_settings(self): def render_form_v2(self): st.text_input( """ - ### Email Address + #### Email Address Give us an email address and we'll try to get determine the profile data associated with it """, key="email_address", @@ -132,7 +132,7 @@ def render_form_v2(self): st.text_area( """ - ### Email Body + #### Email Body """, key="input_email_body", height=200, @@ -198,7 +198,7 @@ def _input_variables(self, state: dict): def render_output(self): st.text_area( """ - ### Email Body Output + #### Email Body Output """, disabled=True, value=st.session_state.get("output_email_body", ""), diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 0e0991c35..4d68b8eb5 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -64,7 +64,7 @@ def preview_image(self, state: dict) -> str | None: def render_form_v2(self): st.text_area( """ - ### πŸ‘©β€πŸ’» Prompt + #### πŸ‘©β€πŸ’» Prompt Describe the audio that you'd like to generate. """, key="text_prompt", diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index de13bfbc0..b90bb8a4b 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -100,7 +100,7 @@ def render_description(self): def render_form_v2(self): st.text_area( """ - ### Prompt + #### Prompt Enter text you want to convert to speech """, key="text_prompt", diff --git a/recipes/asr.py b/recipes/asr.py index a297bf9c3..7dcea4249 100644 --- a/recipes/asr.py +++ b/recipes/asr.py @@ -82,14 +82,14 @@ def related_workflows(self) -> list: def render_form_v2(self): document_uploader( - "##### Audio Files", + "#### Audio Files", accept=("audio/*", "video/*", "application/octet-stream"), ) col1, col2 = st.columns(2, responsive=False) with col1: selected_model = enum_selector( AsrModels, - label="##### ASR Model", + label="#### ASR Model", key="selected_model", use_selectbox=True, ) From e985052c1dc3b01aa64abd647dbcea3d19aed906 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:56:46 -0800 Subject: [PATCH 53/58] Fixed datetime format --- daras_ai_v2/base.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 8dab30806..75068f911 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1519,14 +1519,12 @@ def _render_after_output(self): if "seed" in self.RequestModel.schema_json(): seed = st.session_state.get("seed") caption += f' with seed {seed} ' - try: - format_created_at = st.session_state.get( - "created_at", datetime.datetime.today() - ).strftime("%d %b %Y %-I:%M%p") - except: - format_created_at = format_created_at = st.session_state.get( - "created_at", datetime.datetime.today() - ) + created_at = st.session_state.get( + StateKeys.created_at, datetime.datetime.today() + ) + if not isinstance(created_at, datetime.datetime): + created_at = datetime.datetime.fromisoformat(created_at) + format_created_at = created_at.strftime("%d %b %Y %-I:%M%p") caption += f' at {format_created_at}' st.caption(caption, unsafe_allow_html=True) From 6f75eca0c212cc5242bf787d8f9fdef2b69b4c57 Mon Sep 17 00:00:00 2001 From: clr-li <111320104+clr-li@users.noreply.github.com> Date: Fri, 26 Jan 2024 13:15:03 -0800 Subject: [PATCH 54/58] Regen button appears once --- daras_ai_v2/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 75068f911..28756f96a 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1514,6 +1514,13 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): + if "seed" in self.RequestModel.schema_json(): + randomize = st.button( + ' Regenerate', type="tertiary" + ) + if randomize: + st.session_state[StateKeys.pressed_randomize] = True + st.experimental_rerun() caption = "" caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s' if "seed" in self.RequestModel.schema_json(): @@ -1534,13 +1541,6 @@ def render_buttons(self, url: str): url=url, type="secondary", ) - if "seed" in self.RequestModel.schema_json(): - randomize = st.button( - ' Regenerate', type="tertiary" - ) - if randomize: - st.session_state[StateKeys.pressed_randomize] = True - st.experimental_rerun() self._render_report_button() def state_to_doc(self, state: dict): From 64b7dfcefe19b8da6f3bb3caa84b313786a43799 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 12 Feb 2024 11:57:49 +0530 Subject: [PATCH 55/58] fix report button re-usability fuckup put download button inside the audio/video/image components consistent datetime format fix half ass caption showing up when the recipe is running make download_button re-use the button code --- bots/models.py | 14 +++---- daras_ai_v2/base.py | 55 ++++++++++++------------- daras_ai_v2/settings.py | 2 + gooey_ui/components/__init__.py | 72 +++++++++++++++++---------------- recipes/CompareText2Img.py | 5 ++- recipes/CompareUpscaler.py | 3 +- recipes/DeforumSD.py | 3 +- recipes/FaceInpainting.py | 3 +- recipes/GoogleImageGen.py | 3 +- recipes/ImageSegmentation.py | 3 +- recipes/Img2Img.py | 6 +-- recipes/Lipsync.py | 3 +- recipes/LipsyncTTS.py | 8 +++- recipes/ObjectInpainting.py | 3 +- recipes/QRCodeGenerator.py | 3 +- recipes/Text2Audio.py | 9 +++-- recipes/TextToSpeech.py | 7 +--- 17 files changed, 101 insertions(+), 101 deletions(-) diff --git a/bots/models.py b/bots/models.py index 205dfb5ff..558512e8c 100644 --- a/bots/models.py +++ b/bots/models.py @@ -700,8 +700,8 @@ def to_df_format( .replace(tzinfo=None) ) row |= { - "Last Sent": last_time.strftime("%b %d, %Y %I:%M %p"), - "First Sent": first_time.strftime("%b %d, %Y %I:%M %p"), + "Last Sent": last_time.strftime(settings.SHORT_DATETIME_FORMAT), + "First Sent": first_time.strftime(settings.SHORT_DATETIME_FORMAT), "A7": not convo.d7(), "A30": not convo.d30(), "R1": last_time - first_time < datetime.timedelta(days=1), @@ -926,7 +926,7 @@ def to_df_format( "Message (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Feedback": ( message.feedbacks.first().get_display_text() if message.feedbacks.first() @@ -968,7 +968,7 @@ def to_df_analysis_format( "Answer (EN)": message.content, "Sent": message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Analysis JSON": message.analysis_result, } rows.append(row) @@ -1153,16 +1153,16 @@ def to_df_format( "Question Sent": feedback.message.get_previous_by_created_at() .created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Answer (EN)": feedback.message.content, "Answer Sent": feedback.message.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Rating": Feedback.Rating(feedback.rating).label, "Feedback (EN)": feedback.text_english, "Feedback Sent": feedback.created_at.astimezone(tz) .replace(tzinfo=None) - .strftime("%b %d, %Y %I:%M %p"), + .strftime(settings.SHORT_DATETIME_FORMAT), "Question Answered": feedback.message.question_answered, } rows.append(row) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 28756f96a..0a885bf0a 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1348,11 +1348,11 @@ def _render_output_col(self, submitted: bool): # render outputs self.render_output() - if run_state != "waiting": + if run_state != RecipeRunState.running: self._render_after_output() def _render_completed_output(self): - run_time = st.session_state.get(StateKeys.run_time, 0) + pass def _render_failed_output(self): err_msg = st.session_state.get(StateKeys.error_msg) @@ -1368,12 +1368,10 @@ def render_extra_waiting_output(self): if not estimated_run_time: return if created_at := st.session_state.get("created_at"): - if isinstance(created_at, datetime.datetime): - start_time = created_at - else: - start_time = datetime.datetime.fromisoformat(created_at) + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) with st.countdown_timer( - end_time=start_time + datetime.timedelta(seconds=estimated_run_time), + end_time=created_at + datetime.timedelta(seconds=estimated_run_time), delay_text="Sorry for the wait. Your run is taking longer than we expected.", ): if self.is_current_user_owner() and self.request.user.email: @@ -1514,6 +1512,8 @@ def clear_outputs(self): st.session_state.pop(field_name, None) def _render_after_output(self): + self._render_report_button() + if "seed" in self.RequestModel.schema_json(): randomize = st.button( ' Regenerate', type="tertiary" @@ -1521,27 +1521,8 @@ def _render_after_output(self): if randomize: st.session_state[StateKeys.pressed_randomize] = True st.experimental_rerun() - caption = "" - caption += f'\\\nGenerated in {st.session_state.get(StateKeys.run_time, 0):.2f}s' - if "seed" in self.RequestModel.schema_json(): - seed = st.session_state.get("seed") - caption += f' with seed {seed} ' - created_at = st.session_state.get( - StateKeys.created_at, datetime.datetime.today() - ) - if not isinstance(created_at, datetime.datetime): - created_at = datetime.datetime.fromisoformat(created_at) - format_created_at = created_at.strftime("%d %b %Y %-I:%M%p") - caption += f' at {format_created_at}' - st.caption(caption, unsafe_allow_html=True) - def render_buttons(self, url: str): - st.download_button( - label=' Download', - url=url, - type="secondary", - ) - self._render_report_button() + render_output_caption() def state_to_doc(self, state: dict): ret = { @@ -1918,6 +1899,26 @@ def is_current_user_owner(self) -> bool: ) +def render_output_caption(): + caption = "" + + run_time = st.session_state.get(StateKeys.run_time, 0) + if run_time: + caption += f'Generated in {run_time :.2f}s' + + if seed := st.session_state.get("seed"): + caption += f' with seed {seed} ' + + created_at = st.session_state.get(StateKeys.created_at, datetime.datetime.today()) + if created_at: + if isinstance(created_at, str): + created_at = datetime.datetime.fromisoformat(created_at) + format_created_at = created_at.strftime(settings.SHORT_DATETIME_FORMAT) + caption += f' at {format_created_at}' + + st.caption(caption, unsafe_allow_html=True) + + def get_example_request_body( request_model: typing.Type[BaseModel], state: dict, diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index deda81500..512fb3834 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -162,6 +162,8 @@ es_formats.DATETIME_FORMAT = DATETIME_FORMAT +SHORT_DATETIME_FORMAT = "%b %d, %Y %-I:%M %p" + # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.2/howto/static-files/ diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py index 7e435b725..6322cf629 100644 --- a/gooey_ui/components/__init__.py +++ b/gooey_ui/components/__init__.py @@ -217,6 +217,7 @@ def image( caption: str = None, alt: str = None, href: str = None, + show_download_button: bool = False, **props, ): if isinstance(src, np.ndarray): @@ -241,9 +242,18 @@ def image( **props, ), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) -def video(src: str, caption: str = None, autoplay: bool = False): +def video( + src: str, + caption: str = None, + autoplay: bool = False, + show_download_button: bool = False, +): autoplay_props = {} if autoplay: autoplay_props = { @@ -266,15 +276,23 @@ def video(src: str, caption: str = None, autoplay: bool = False): name="video", props=dict(src=src, caption=dedent(caption), **autoplay_props), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) -def audio(src: str, caption: str = None): +def audio(src: str, caption: str = None, show_download_button: bool = False): if not src: return state.RenderTreeNode( name="audio", props=dict(src=src, caption=dedent(caption)), ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) def text_area( @@ -415,8 +433,9 @@ def selectbox( return value -def button( +def download_button( label: str, + url: str, key: str = None, help: str = None, *, @@ -424,43 +443,26 @@ def button( disabled: bool = False, **props, ) -> bool: - """ - Example: - st.button("Primary", key="test0", type="primary") - st.button("Secondary", key="test1") - st.button("Tertiary", key="test3", type="tertiary") - st.button("Link Button", key="test3", type="link") - """ - if not key: - key = md5_values("button", label, help, type, props) - className = f"btn-{type} " + props.pop("className", "") - state.RenderTreeNode( - name="gui-button", - props=dict( - type="submit", - value="yes", - name=key, - label=dedent(label), - help=help, - disabled=disabled, - className=className, - **props, - ), - ).mount() - return bool(state.session_state.pop(key, False)) - - -form_submit_button = button + return button( + component="download-button", + url=url, + label=label, + key=key, + help=help, + type=type, + disabled=disabled, + **props, + ) -def download_button( +def button( label: str, - url: str, key: str = None, help: str = None, *, type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", disabled: bool = False, + component: typing.Literal["download-button", "gui-button"] = "gui-button", **props, ) -> bool: """ @@ -474,11 +476,10 @@ def download_button( key = md5_values("button", label, help, type, props) className = f"btn-{type} " + props.pop("className", "") state.RenderTreeNode( - name="download-button", + name=component, props=dict( type="submit", value="yes", - url=url, name=key, label=dedent(label), help=help, @@ -490,6 +491,9 @@ def download_button( return bool(state.session_state.pop(key, False)) +form_submit_button = button + + def expander(label: str, *, expanded: bool = False, **props): node = state.RenderTreeNode( name="expander", diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index 47347bb65..dc5ea1ae2 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -246,8 +246,9 @@ def _render_outputs(self, state): for key in selected_models: output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: - st.image(img, caption=Text2ImgModels[key].value) - self.render_buttons(img) + st.image( + img, caption=Text2ImgModels[key].value, show_download_button=True + ) def preview_description(self, state: dict) -> str: return "Create multiple AI photos from one prompt using Stable Diffusion (1.5 -> 2.1, Open/Midjourney), DallE, and other models. Find out which AI Image generator works best for your text prompt on comparing OpenAI, Stability.AI etc." diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 4aad4ffb2..13f62de91 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -107,8 +107,7 @@ def _render_outputs(self, state): img: dict = state.get("output_images", {}).get(key) if not img: continue - st.image(img, caption=UpscalerModels[key].value) - self.render_buttons(img) + st.image(img, caption=UpscalerModels[key].value, show_download_button=True) def get_raw_price(self, state: dict) -> int: selected_models = state.get("selected_models", []) diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index d971e708c..35a71a12c 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -424,8 +424,7 @@ def render_output(self): output_video = st.session_state.get("output_video") if output_video: st.write("#### Output Video") - st.video(output_video, autoplay=True) - self.render_buttons(output_video) + st.video(output_video, autoplay=True, show_download_button=True) def estimate_run_duration(self): # in seconds diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index f6a550d48..6d787bba6 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -198,8 +198,7 @@ def render_output(self): if output_images: st.write("#### Output Image") for url in output_images: - st.image(url) - self.render_buttons(url) + st.image(url, show_download_button=True) else: st.div() diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index fd5d2db80..278128a37 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -206,8 +206,7 @@ def render_output(self): out_imgs = st.session_state.get("output_images") if out_imgs: for img in out_imgs: - st.image(img, caption="#### Generated Image") - self.render_buttons(img) + st.image(img, caption="#### Generated Image", show_download_button=True) else: st.div() diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py index 39713cb2b..c247fdeda 100644 --- a/recipes/ImageSegmentation.py +++ b/recipes/ImageSegmentation.py @@ -342,8 +342,7 @@ def render_example(self, state: dict): with col1: input_image = state.get("input_image") if input_image: - st.image(input_image, caption="Input Photo") - self.render_buttons(input_image) + st.image(input_image, caption="Input Photo", show_download_button=True) else: st.div() diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 7334ebb31..e2713aaa2 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -127,12 +127,12 @@ def render_usage_guide(self): youtube_video("narcZNyuNAg") def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") output_images = st.session_state.get("output_images", []) + if not output_images: + return st.write("#### Output Image") for img in output_images: - st.image(img) - self.render_buttons(img) + st.image(img, show_download_button=True) def render_example(self, state: dict): col1, col2 = st.columns(2) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 8ea341aad..c2ac64369 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -85,8 +85,7 @@ def render_example(self, state: dict): output_video = state.get("output_video") if output_video: st.write("#### Output Video") - st.video(output_video, autoplay=True) - self.render_buttons(output_video) + st.video(output_video, autoplay=True, show_download_button=True) else: st.div() diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index d463d7483..1fc0f6c64 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -129,8 +129,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: def render_example(self, state: dict): output_video = state.get("output_video") if output_video: - st.video(output_video, caption="#### Output Video", autoplay=True) - self.render_buttons(output_video) + st.video( + output_video, + caption="#### Output Video", + autoplay=True, + show_download_button=True, + ) else: st.div() diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index c2c3b29c4..a1e0c2449 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -201,8 +201,7 @@ def render_output(self): if output_images: for url in output_images: - st.image(url, caption=f"{text_prompt}") - self.render_buttons(url) + st.image(url, caption=f"{text_prompt}", show_download_button=True) else: st.div() diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 1fe283825..1512e4523 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -436,7 +436,7 @@ def _render_outputs(self, state: dict, max_count: int | None = None): if max_count: output_images = output_images[:max_count] for img in output_images: - st.image(img) + st.image(img, show_download_button=True) qr_code_data = ( state.get(QrSources.qr_code_data.name) or state.get(QrSources.qr_code_input_image.name) @@ -458,7 +458,6 @@ def _render_outputs(self, state: dict, max_count: int | None = None): st.caption(f"{shortened_url} β†’ {qr_code_data} (Views: {clicks})") else: st.caption(f"{shortened_url} β†’ {qr_code_data}") - self.render_buttons(img) def run(self, state: dict) -> typing.Iterator[str | None]: request: QRCodeGeneratorPage.RequestModel = self.RequestModel.parse_obj(state) diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index 4d68b8eb5..f7ddf5aab 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -128,7 +128,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: ) def render_output(self): - _render_output(self, st.session_state) + _render_output(st.session_state) def render_example(self, state: dict): col1, col2 = st.columns(2) @@ -141,10 +141,11 @@ def preview_description(self, state: dict) -> str: return "Generate AI Music with text instruction prompts. AudiLDM is capable of generating realistic audio samples by process any text input. Learn more [here](https://huggingface.co/cvssp/audioldm-m-full)." -def _render_output(self, state): +def _render_output(state): selected_models = state.get("selected_models", []) for key in selected_models: output: dict = state.get("output_audios", {}).get(key, []) for audio in output: - st.audio(audio, caption=Text2AudioModels[key].value) - self.render_buttons(audio) + st.audio( + audio, caption=Text2AudioModels[key].value, show_download_button=True + ) diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py index b90bb8a4b..7494c992d 100644 --- a/recipes/TextToSpeech.py +++ b/recipes/TextToSpeech.py @@ -131,13 +131,8 @@ def render_usage_guide(self): # loom_video("2d853b7442874b9cbbf3f27b98594add") def render_output(self): - text_prompt = st.session_state.get("text_prompt", "") audio_url = st.session_state.get("audio_url") - if audio_url: - st.audio(audio_url) - self.render_buttons(audio_url) - else: - st.div() + st.audio(audio_url, show_download_button=True) def _get_elevenlabs_price(self, state: dict): _, is_user_provided_key = self._get_elevenlabs_api_key(state) From dc7056900c6c4569a0e6a204f501f7fb92940c07 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 12 Feb 2024 13:04:00 +0530 Subject: [PATCH 56/58] fix expand gdrive folders and accept_multiple_files clash --- daras_ai_v2/doc_search_settings_widgets.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index a64572b99..fb69f22a9 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -78,11 +78,12 @@ def document_uploader( ) st.checkbox("Submit Links in Bulk", key=f"__custom_checkbox_{key}") documents = st.session_state.get(key, []) - try: - documents = list(_expand_gdrive_folders(documents)) - except Exception as e: - capture_exception(e) - st.error(f"Error expanding gdrive folders: {e}") + if accept_multiple_files: + try: + documents = list(_expand_gdrive_folders(documents)) + except Exception as e: + capture_exception(e) + st.error(f"Error expanding gdrive folders: {e}") st.session_state[key] = documents st.session_state[custom_key] = "\n".join(documents) return documents From 6cd5617e898a57528b15eaf1440d37a5e6ca11b0 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 12 Feb 2024 13:24:47 +0530 Subject: [PATCH 57/58] add is_api_call to savedrun show api call in breadcrumbs --- bots/migrations/0059_savedrun_is_api_call.py | 18 +++++++++++++++ bots/models.py | 2 ++ celeryapp/tasks.py | 1 + daras_ai_v2/base.py | 23 ++++++++----------- daras_ai_v2/breadcrumbs.py | 5 +++- routers/api.py | 2 +- .../0004_alter_modelpricing_model_name.py | 18 +++++++++++++++ 7 files changed, 54 insertions(+), 15 deletions(-) create mode 100644 bots/migrations/0059_savedrun_is_api_call.py create mode 100644 usage_costs/migrations/0004_alter_modelpricing_model_name.py diff --git a/bots/migrations/0059_savedrun_is_api_call.py b/bots/migrations/0059_savedrun_is_api_call.py new file mode 100644 index 000000000..dc93057fc --- /dev/null +++ b/bots/migrations/0059_savedrun_is_api_call.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-02-12 07:54 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0058_alter_savedrun_unique_together_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='savedrun', + name='is_api_call', + field=models.BooleanField(default=False), + ), + ] diff --git a/bots/models.py b/bots/models.py index 558512e8c..e2c66148a 100644 --- a/bots/models.py +++ b/bots/models.py @@ -254,6 +254,8 @@ class SavedRun(models.Model): page_title = models.TextField(default="", blank=True, help_text="(Deprecated)") page_notes = models.TextField(default="", blank=True, help_text="(Deprecated)") + is_api_call = models.BooleanField(default=False) + objects = SavedRunQuerySet.as_manager() class Meta: diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 1868c7a18..5e1690b31 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -32,6 +32,7 @@ def gui_runner( ): page = page_cls(request=SimpleNamespace(user=AppUser.objects.get(id=user_id))) sr = page.run_doc_sr(run_id, uid) + sr.is_api_call = is_api_call st.set_session_state(state) run_time = 0 diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 0a885bf0a..906b73a08 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -243,7 +243,7 @@ def _render_header(self): if tbreadcrumbs: with st.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): - render_breadcrumbs(tbreadcrumbs) + render_breadcrumbs(tbreadcrumbs, current_run.is_api_call) author = self.run_user or current_run.get_creator() if not is_root_example: @@ -996,9 +996,7 @@ def get_or_create_root_published_run(cls) -> PublishedRun: workflow=cls.workflow, published_run_id="", defaults={ - "saved_run": lambda: cls.run_doc_sr( - run_id="", uid="", create=True, parent=None, parent_version=None - ), + "saved_run": lambda: cls.run_doc_sr(run_id="", uid="", create=True), "created_by": None, "last_edited_by": None, "title": cls.title, @@ -1022,15 +1020,11 @@ def run_doc_sr( run_id: str, uid: str, create: bool = False, - parent: SavedRun | None = None, - parent_version: PublishedRunVersion | None = None, + defaults: dict = None, ) -> SavedRun: config = dict(workflow=cls.workflow, uid=uid, run_id=run_id) if create: - return SavedRun.objects.get_or_create( - **config, - defaults=dict(parent=parent, parent_version=parent_version), - )[0] + return SavedRun.objects.get_or_create(**config, defaults=defaults)[0] else: return SavedRun.objects.get(**config) @@ -1408,7 +1402,7 @@ def should_submit_after_login(self) -> bool: and not self.request.user.is_anonymous ) - def create_new_run(self): + def create_new_run(self, is_api_call: bool = False): st.session_state[StateKeys.run_status] = "Starting..." st.session_state.pop(StateKeys.error_msg, None) st.session_state.pop(StateKeys.run_time, None) @@ -1443,8 +1437,11 @@ def create_new_run(self): run_id, uid, create=True, - parent=parent, - parent_version=parent_version, + defaults=dict( + parent=parent, + parent_version=parent_version, + is_api_call=is_api_call, + ), ).set(self.state_to_doc(st.session_state)) return None, run_id, uid diff --git a/daras_ai_v2/breadcrumbs.py b/daras_ai_v2/breadcrumbs.py index 96dc919e5..62ea058dd 100644 --- a/daras_ai_v2/breadcrumbs.py +++ b/daras_ai_v2/breadcrumbs.py @@ -31,7 +31,7 @@ def has_breadcrumbs(self): return bool(self.root_title or self.published_title) -def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs): +def render_breadcrumbs(breadcrumbs: TitleBreadCrumbs, is_api_call: bool = False): st.html( """