From 6a1b0d5ef62f1ee48fafe70e4ab8a4fe09c32192 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 25 Jul 2024 15:02:12 +0530 Subject: [PATCH] Enhance vector search functionality to track user interaction Add `current_user` parameter to `get_top_k_references` for user-specific embedding tracking. Update `EmbeddedFile` model to include `created_by`, `query_count`, and `last_query_at`. Modify related functions to increment query counts and log last query time accordingly. Update various pages to pass current user context during document searches. Add embedded file references to user admin. --- app_users/admin.py | 34 ++++++++++++++++--- daras_ai_v2/vector_search.py | 34 +++++++++++++------ embeddings/admin.py | 6 ++++ ..._by_embeddedfile_last_query_at_and_more.py | 30 ++++++++++++++++ embeddings/models.py | 14 ++++++++ recipes/DocSearch.py | 1 + recipes/GoogleGPT.py | 1 + recipes/VideoBots.py | 1 + .../0016_alter_modelpricing_model_name.py | 18 ++++++++++ 9 files changed, 125 insertions(+), 14 deletions(-) create mode 100644 embeddings/migrations/0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more.py create mode 100644 usage_costs/migrations/0016_alter_modelpricing_model_name.py diff --git a/app_users/admin.py b/app_users/admin.py index caa61d223..56f325fca 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -4,7 +4,8 @@ from app_users import models from bots.admin_links import open_in_new_tab, list_related_html_url -from bots.models import SavedRun +from bots.models import SavedRun, PublishedRun +from embeddings.models import EmbeddedFile from usage_costs.models import UsageCost @@ -31,7 +32,12 @@ class AppUserAdmin(admin.ModelAdmin): "is_paying", "disable_safety_checker", "disable_rate_limits", - ("user_runs", "view_transactions"), + ( + "view_saved_runs", + "view_published_runs", + "view_embedded_files", + "view_transactions", + ), "created_at", "upgraded_from_anonymous_at", ("open_in_firebase", "open_in_stripe"), @@ -86,7 +92,9 @@ class AppUserAdmin(admin.ModelAdmin): "total_usage_cost", "created_at", "upgraded_from_anonymous_at", - "user_runs", + "view_saved_runs", + "view_published_runs", + "view_embedded_files", "view_transactions", "open_in_firebase", "open_in_stripe", @@ -95,7 +103,7 @@ class AppUserAdmin(admin.ModelAdmin): autocomplete_fields = ["handle", "subscription"] @admin.display(description="User Runs") - def user_runs(self, user: models.AppUser): + def view_saved_runs(self, user: models.AppUser): return list_related_html_url( SavedRun.objects.filter(uid=user.uid), query_param="uid", @@ -103,6 +111,24 @@ def user_runs(self, user: models.AppUser): show_add=False, ) + @admin.display(description="Published Runs") + def view_published_runs(self, user: models.AppUser): + return list_related_html_url( + PublishedRun.objects.filter(created_by=user), + query_param="created_by", + instance_id=user.id, + show_add=False, + ) + + @admin.display(description="Embedded Files") + def view_embedded_files(self, user: models.AppUser): + return list_related_html_url( + EmbeddedFile.objects.filter(created_by=user), + query_param="created_by", + instance_id=user.id, + show_add=False, + ) + @admin.display(description="Total Payments") def total_payments(self, user: models.AppUser): return "$" + str( diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 55b723124..095268d49 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -15,11 +15,14 @@ import numpy as np import requests from django.db import transaction +from django.db.models import F +from django.utils import timezone from furl import furl from loguru import logger from pydantic import BaseModel, Field import gooey_ui as gui +from app_users.models import AppUser from daras_ai.image_input import ( upload_file_from_bytes, safe_filename, @@ -112,8 +115,7 @@ def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str: def get_top_k_references( - request: DocSearchRequest, - is_user_url: bool = True, + request: DocSearchRequest, is_user_url: bool = True, current_user: AppUser = None ) -> typing.Generator[str, None, list[SearchReference]]: """ Get the top k documents that ref the search query @@ -121,6 +123,7 @@ def get_top_k_references( Args: request: the document search request is_user_url: whether the url is user-uploaded + current_user: the current user Returns: the top k documents @@ -148,8 +151,8 @@ def get_top_k_references( EmbeddedFile._meta.get_field("embedding_model").default ), ) - embedding_refs: list[EmbeddedFile] = map_parallel( - lambda f_url, file_meta: get_or_create_embeddings( + embedded_files: list[EmbeddedFile] = map_parallel( + lambda f_url, file_meta: get_or_create_embedded_file( f_url=f_url, file_meta=file_meta, max_context_words=request.max_context_words, @@ -158,20 +161,26 @@ def get_top_k_references( selected_asr_model=selected_asr_model, embedding_model=embedding_model, is_user_url=is_user_url, + current_user=current_user, ), file_urls, file_metas, max_workers=4, ) - if not embedding_refs: + if not embedded_files: yield "No embeddings found - skipping search" return [] - vespa_file_ids = [ref.vespa_file_id for ref in embedding_refs] + yield "Searching knowledge base..." + + vespa_file_ids = [ref.vespa_file_id for ref in embedded_files] + EmbeddedFile.objects.filter(id__in=[ref.id for ref in embedded_files]).update( + query_count=F("query_count") + 1, + last_query_at=timezone.now(), + ) # chunk_count = sum(len(ref.document_ids) for ref in embedding_refs) # logger.debug(f"Knowledge base has {len(file_ids)} documents ({chunk_count} chunks)") - yield "Searching knowledge base..." s = time() search_result = query_vespa( request.search_query, @@ -361,16 +370,17 @@ def yt_dlp_extract_info(url: str) -> dict: return data -def get_or_create_embeddings( +def get_or_create_embedded_file( + *, f_url: str, file_meta: FileMetadata, - *, max_context_words: int, scroll_jump: int, google_translate_target: str | None, selected_asr_model: str | None, embedding_model: EmbeddingModels, is_user_url: bool, + current_user: AppUser, ) -> EmbeddedFile: """ Return Vespa document ids and document tags @@ -408,7 +418,11 @@ def get_or_create_embeddings( file_meta.save() embedded_file = EmbeddedFile.objects.get_or_create( **lookup, - defaults=dict(metadata=file_meta, vespa_file_id=file_id), + defaults=dict( + metadata=file_meta, + vespa_file_id=file_id, + created_by=current_user, + ), )[0] for ref in refs: ref.embedded_file = embedded_file diff --git a/embeddings/admin.py b/embeddings/admin.py index 1bc9d2a91..40322a3ab 100644 --- a/embeddings/admin.py +++ b/embeddings/admin.py @@ -15,7 +15,9 @@ class EmbeddedFileAdmin(admin.ModelAdmin): "google_translate_target", "selected_asr_model", "vespa_file_id", + "query_count", "created_at", + "last_query_at", ] search_fields = [ "url", @@ -31,6 +33,7 @@ class EmbeddedFileAdmin(admin.ModelAdmin): "selected_asr_model", "created_at", "updated_at", + "last_query_at", ] readonly_fields = [ "metadata", @@ -38,7 +41,10 @@ class EmbeddedFileAdmin(admin.ModelAdmin): "created_at", "updated_at", "view_embeds", + "query_count", + "last_query_at", ] + autocomplete_fields = ["created_by"] ordering = ["-created_at"] @admin.display(description="View Embeds") diff --git a/embeddings/migrations/0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more.py b/embeddings/migrations/0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more.py new file mode 100644 index 000000000..90666b31e --- /dev/null +++ b/embeddings/migrations/0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.7 on 2024-07-25 10:35 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0019_alter_appusertransaction_reason'), + ('embeddings', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='embeddedfile', + name='created_by', + field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='embedded_files', to='app_users.appuser'), + ), + migrations.AddField( + model_name='embeddedfile', + name='last_query_at', + field=models.DateTimeField(blank=True, db_index=True, default=None, null=True), + ), + migrations.AddField( + model_name='embeddedfile', + name='query_count', + field=models.PositiveIntegerField(db_index=True, default=0), + ), + ] diff --git a/embeddings/models.py b/embeddings/models.py index 69a6cf703..f30b1a0bf 100644 --- a/embeddings/models.py +++ b/embeddings/models.py @@ -7,6 +7,15 @@ class EmbeddedFile(models.Model): + created_by = models.ForeignKey( + "app_users.AppUser", + on_delete=models.SET_NULL, + null=True, + blank=True, + default=None, + related_name="embedded_files", + ) + url = CustomURLField(help_text="The URL of the original resource (e.g. a document)") metadata = models.ForeignKey( "files.FileMetadata", @@ -32,6 +41,11 @@ class EmbeddedFile(models.Model): created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) + query_count = models.PositiveIntegerField(default=0, db_index=True) + last_query_at = models.DateTimeField( + null=True, blank=True, default=None, db_index=True + ) + def __str__(self): return f"{self.url} ({self.metadata})" diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index b6c1af301..a17ec1af1 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -175,6 +175,7 @@ def run_v2( "search_query": response.final_search_query, }, ), + current_user=self.request.user, ) # empty search result, abort! diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index c1214bc26..7f451ea13 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -254,6 +254,7 @@ def run_v2( }, ), is_user_url=False, + current_user=self.request.user, ) # add pretty titles to references for ref in response.references: diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index b3218d9ed..4028f033b 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -882,6 +882,7 @@ def run_v2( "keyword_query": response.final_keyword_query, }, ), + current_user=self.request.user, ) if request.use_url_shortener: for reference in response.references: diff --git a/usage_costs/migrations/0016_alter_modelpricing_model_name.py b/usage_costs/migrations/0016_alter_modelpricing_model_name.py new file mode 100644 index 000000000..c59ca921f --- /dev/null +++ b/usage_costs/migrations/0016_alter_modelpricing_model_name.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-07-25 10:35 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0015_alter_modelpricing_model_name_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_o', 'GPT-4o (openai)'), ('gpt_4_o_mini', 'GPT-4o-mini (openai)'), ('gpt_4_turbo_vision', 'GPT-4 Turbo with Vision (openai)'), ('gpt_4_vision', 'GPT-4 Vision (openai) 🔻'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai) 🔻'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('gpt_3_5_turbo_instruct', 'GPT-3.5 Instruct (openai) 🔻'), ('llama3_70b', 'Llama 3 70b (Meta AI)'), ('llama_3_groq_70b_tool_use', 'Llama 3 Groq 70b Tool Use'), ('llama3_8b', 'Llama 3 8b (Meta AI)'), ('llama_3_groq_8b_tool_use', 'Llama 3 Groq 8b Tool Use'), ('llama2_70b_chat', 'Llama 2 70b Chat [Deprecated] (Meta AI)'), ('mixtral_8x7b_instruct_0_1', 'Mixtral 8x7b Instruct v0.1 (Mistral)'), ('gemma_2_9b_it', 'Gemma 2 9B (Google)'), ('gemma_7b_it', 'Gemma 7B (Google)'), ('gemini_1_5_pro', 'Gemini 1.5 Pro (Google)'), ('gemini_1_pro_vision', 'Gemini 1.0 Pro Vision (Google)'), ('gemini_1_pro', 'Gemini 1.0 Pro (Google)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('claude_3_5_sonnet', 'Claude 3.5 Sonnet (Anthropic)'), ('claude_3_opus', 'Claude 3 Opus [L] (Anthropic)'), ('claude_3_sonnet', 'Claude 3 Sonnet [M] (Anthropic)'), ('claude_3_haiku', 'Claude 3 Haiku [S] (Anthropic)'), ('sea_lion_7b_instruct', 'SEA-LION-7B-Instruct (aisingapore)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 [Deprecated] (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 [Deprecated] (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('text_curie_001', 'Curie [Deprecated] (openai)'), ('text_babbage_001', 'Babbage [Deprecated] (openai)'), ('text_ada_001', 'Ada [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALL·E 2 (OpenAI)'), ('dall_e_3', 'DALL·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero)'), ('openjourney', 'Open Journey (PromptHero)'), ('analog_diffusion', 'Analog Diffusion (wavymulder)'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('wav2lip', 'LipSync (wav2lip)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + ]