diff --git a/app_users/admin.py b/app_users/admin.py index caa61d223..56f325fca 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -4,7 +4,8 @@ from app_users import models from bots.admin_links import open_in_new_tab, list_related_html_url -from bots.models import SavedRun +from bots.models import SavedRun, PublishedRun +from embeddings.models import EmbeddedFile from usage_costs.models import UsageCost @@ -31,7 +32,12 @@ class AppUserAdmin(admin.ModelAdmin): "is_paying", "disable_safety_checker", "disable_rate_limits", - ("user_runs", "view_transactions"), + ( + "view_saved_runs", + "view_published_runs", + "view_embedded_files", + "view_transactions", + ), "created_at", "upgraded_from_anonymous_at", ("open_in_firebase", "open_in_stripe"), @@ -86,7 +92,9 @@ class AppUserAdmin(admin.ModelAdmin): "total_usage_cost", "created_at", "upgraded_from_anonymous_at", - "user_runs", + "view_saved_runs", + "view_published_runs", + "view_embedded_files", "view_transactions", "open_in_firebase", "open_in_stripe", @@ -95,7 +103,7 @@ class AppUserAdmin(admin.ModelAdmin): autocomplete_fields = ["handle", "subscription"] @admin.display(description="User Runs") - def user_runs(self, user: models.AppUser): + def view_saved_runs(self, user: models.AppUser): return list_related_html_url( SavedRun.objects.filter(uid=user.uid), query_param="uid", @@ -103,6 +111,24 @@ def user_runs(self, user: models.AppUser): show_add=False, ) + @admin.display(description="Published Runs") + def view_published_runs(self, user: models.AppUser): + return list_related_html_url( + PublishedRun.objects.filter(created_by=user), + query_param="created_by", + instance_id=user.id, + show_add=False, + ) + + @admin.display(description="Embedded Files") + def view_embedded_files(self, user: models.AppUser): + return list_related_html_url( + EmbeddedFile.objects.filter(created_by=user), + query_param="created_by", + instance_id=user.id, + show_add=False, + ) + @admin.display(description="Total Payments") def total_payments(self, user: models.AppUser): return "$" + str( diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 55b723124..095268d49 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -15,11 +15,14 @@ import numpy as np import requests from django.db import transaction +from django.db.models import F +from django.utils import timezone from furl import furl from loguru import logger from pydantic import BaseModel, Field import gooey_ui as gui +from app_users.models import AppUser from daras_ai.image_input import ( upload_file_from_bytes, safe_filename, @@ -112,8 +115,7 @@ def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str: def get_top_k_references( - request: DocSearchRequest, - is_user_url: bool = True, + request: DocSearchRequest, is_user_url: bool = True, current_user: AppUser = None ) -> typing.Generator[str, None, list[SearchReference]]: """ Get the top k documents that ref the search query @@ -121,6 +123,7 @@ def get_top_k_references( Args: request: the document search request is_user_url: whether the url is user-uploaded + current_user: the current user Returns: the top k documents @@ -148,8 +151,8 @@ def get_top_k_references( EmbeddedFile._meta.get_field("embedding_model").default ), ) - embedding_refs: list[EmbeddedFile] = map_parallel( - lambda f_url, file_meta: get_or_create_embeddings( + embedded_files: list[EmbeddedFile] = map_parallel( + lambda f_url, file_meta: get_or_create_embedded_file( f_url=f_url, file_meta=file_meta, max_context_words=request.max_context_words, @@ -158,20 +161,26 @@ def get_top_k_references( selected_asr_model=selected_asr_model, embedding_model=embedding_model, is_user_url=is_user_url, + current_user=current_user, ), file_urls, file_metas, max_workers=4, ) - if not embedding_refs: + if not embedded_files: yield "No embeddings found - skipping search" return [] - vespa_file_ids = [ref.vespa_file_id for ref in embedding_refs] + yield "Searching knowledge base..." + + vespa_file_ids = [ref.vespa_file_id for ref in embedded_files] + EmbeddedFile.objects.filter(id__in=[ref.id for ref in embedded_files]).update( + query_count=F("query_count") + 1, + last_query_at=timezone.now(), + ) # chunk_count = sum(len(ref.document_ids) for ref in embedding_refs) # logger.debug(f"Knowledge base has {len(file_ids)} documents ({chunk_count} chunks)") - yield "Searching knowledge base..." s = time() search_result = query_vespa( request.search_query, @@ -361,16 +370,17 @@ def yt_dlp_extract_info(url: str) -> dict: return data -def get_or_create_embeddings( +def get_or_create_embedded_file( + *, f_url: str, file_meta: FileMetadata, - *, max_context_words: int, scroll_jump: int, google_translate_target: str | None, selected_asr_model: str | None, embedding_model: EmbeddingModels, is_user_url: bool, + current_user: AppUser, ) -> EmbeddedFile: """ Return Vespa document ids and document tags @@ -408,7 +418,11 @@ def get_or_create_embeddings( file_meta.save() embedded_file = EmbeddedFile.objects.get_or_create( **lookup, - defaults=dict(metadata=file_meta, vespa_file_id=file_id), + defaults=dict( + metadata=file_meta, + vespa_file_id=file_id, + created_by=current_user, + ), )[0] for ref in refs: ref.embedded_file = embedded_file diff --git a/embeddings/admin.py b/embeddings/admin.py index 1bc9d2a91..40322a3ab 100644 --- a/embeddings/admin.py +++ b/embeddings/admin.py @@ -15,7 +15,9 @@ class EmbeddedFileAdmin(admin.ModelAdmin): "google_translate_target", "selected_asr_model", "vespa_file_id", + "query_count", "created_at", + "last_query_at", ] search_fields = [ "url", @@ -31,6 +33,7 @@ class EmbeddedFileAdmin(admin.ModelAdmin): "selected_asr_model", "created_at", "updated_at", + "last_query_at", ] readonly_fields = [ "metadata", @@ -38,7 +41,10 @@ class EmbeddedFileAdmin(admin.ModelAdmin): "created_at", "updated_at", "view_embeds", + "query_count", + "last_query_at", ] + autocomplete_fields = ["created_by"] ordering = ["-created_at"] @admin.display(description="View Embeds") diff --git a/embeddings/migrations/0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more.py b/embeddings/migrations/0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more.py new file mode 100644 index 000000000..90666b31e --- /dev/null +++ b/embeddings/migrations/0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.7 on 2024-07-25 10:35 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0019_alter_appusertransaction_reason'), + ('embeddings', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='embeddedfile', + name='created_by', + field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='embedded_files', to='app_users.appuser'), + ), + migrations.AddField( + model_name='embeddedfile', + name='last_query_at', + field=models.DateTimeField(blank=True, db_index=True, default=None, null=True), + ), + migrations.AddField( + model_name='embeddedfile', + name='query_count', + field=models.PositiveIntegerField(db_index=True, default=0), + ), + ] diff --git a/embeddings/models.py b/embeddings/models.py index 69a6cf703..f30b1a0bf 100644 --- a/embeddings/models.py +++ b/embeddings/models.py @@ -7,6 +7,15 @@ class EmbeddedFile(models.Model): + created_by = models.ForeignKey( + "app_users.AppUser", + on_delete=models.SET_NULL, + null=True, + blank=True, + default=None, + related_name="embedded_files", + ) + url = CustomURLField(help_text="The URL of the original resource (e.g. a document)") metadata = models.ForeignKey( "files.FileMetadata", @@ -32,6 +41,11 @@ class EmbeddedFile(models.Model): created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) + query_count = models.PositiveIntegerField(default=0, db_index=True) + last_query_at = models.DateTimeField( + null=True, blank=True, default=None, db_index=True + ) + def __str__(self): return f"{self.url} ({self.metadata})" diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index b6c1af301..a17ec1af1 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -175,6 +175,7 @@ def run_v2( "search_query": response.final_search_query, }, ), + current_user=self.request.user, ) # empty search result, abort! diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index c1214bc26..7f451ea13 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -254,6 +254,7 @@ def run_v2( }, ), is_user_url=False, + current_user=self.request.user, ) # add pretty titles to references for ref in response.references: diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index b3218d9ed..4028f033b 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -882,6 +882,7 @@ def run_v2( "keyword_query": response.final_keyword_query, }, ), + current_user=self.request.user, ) if request.use_url_shortener: for reference in response.references: diff --git a/usage_costs/migrations/0016_alter_modelpricing_model_name.py b/usage_costs/migrations/0016_alter_modelpricing_model_name.py new file mode 100644 index 000000000..c59ca921f --- /dev/null +++ b/usage_costs/migrations/0016_alter_modelpricing_model_name.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-07-25 10:35 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0015_alter_modelpricing_model_name_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_o', 'GPT-4o (openai)'), ('gpt_4_o_mini', 'GPT-4o-mini (openai)'), ('gpt_4_turbo_vision', 'GPT-4 Turbo with Vision (openai)'), ('gpt_4_vision', 'GPT-4 Vision (openai) 🔻'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai) 🔻'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('gpt_3_5_turbo_instruct', 'GPT-3.5 Instruct (openai) 🔻'), ('llama3_70b', 'Llama 3 70b (Meta AI)'), ('llama_3_groq_70b_tool_use', 'Llama 3 Groq 70b Tool Use'), ('llama3_8b', 'Llama 3 8b (Meta AI)'), ('llama_3_groq_8b_tool_use', 'Llama 3 Groq 8b Tool Use'), ('llama2_70b_chat', 'Llama 2 70b Chat [Deprecated] (Meta AI)'), ('mixtral_8x7b_instruct_0_1', 'Mixtral 8x7b Instruct v0.1 (Mistral)'), ('gemma_2_9b_it', 'Gemma 2 9B (Google)'), ('gemma_7b_it', 'Gemma 7B (Google)'), ('gemini_1_5_pro', 'Gemini 1.5 Pro (Google)'), ('gemini_1_pro_vision', 'Gemini 1.0 Pro Vision (Google)'), ('gemini_1_pro', 'Gemini 1.0 Pro (Google)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('claude_3_5_sonnet', 'Claude 3.5 Sonnet (Anthropic)'), ('claude_3_opus', 'Claude 3 Opus [L] (Anthropic)'), ('claude_3_sonnet', 'Claude 3 Sonnet [M] (Anthropic)'), ('claude_3_haiku', 'Claude 3 Haiku [S] (Anthropic)'), ('sea_lion_7b_instruct', 'SEA-LION-7B-Instruct (aisingapore)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 [Deprecated] (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 [Deprecated] (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('text_curie_001', 'Curie [Deprecated] (openai)'), ('text_babbage_001', 'Babbage [Deprecated] (openai)'), ('text_ada_001', 'Ada [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALL·E 2 (OpenAI)'), ('dall_e_3', 'DALL·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero)'), ('openjourney', 'Open Journey (PromptHero)'), ('analog_diffusion', 'Analog Diffusion (wavymulder)'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('wav2lip', 'LipSync (wav2lip)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + ]