Skip to content

Commit

Permalink
Enhance vector search functionality to track user interaction
Browse files Browse the repository at this point in the history
Add `current_user` parameter to `get_top_k_references` for user-specific embedding tracking. Update `EmbeddedFile` model to include `created_by`, `query_count`, and `last_query_at`. Modify related functions to increment query counts and log last query time accordingly. Update various pages to pass current user context during document searches. Add embedded file references to user admin.
  • Loading branch information
devxpy committed Jul 25, 2024
1 parent 0c48fd8 commit 6a1b0d5
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 14 deletions.
34 changes: 30 additions & 4 deletions app_users/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from app_users import models
from bots.admin_links import open_in_new_tab, list_related_html_url
from bots.models import SavedRun
from bots.models import SavedRun, PublishedRun
from embeddings.models import EmbeddedFile
from usage_costs.models import UsageCost


Expand All @@ -31,7 +32,12 @@ class AppUserAdmin(admin.ModelAdmin):
"is_paying",
"disable_safety_checker",
"disable_rate_limits",
("user_runs", "view_transactions"),
(
"view_saved_runs",
"view_published_runs",
"view_embedded_files",
"view_transactions",
),
"created_at",
"upgraded_from_anonymous_at",
("open_in_firebase", "open_in_stripe"),
Expand Down Expand Up @@ -86,7 +92,9 @@ class AppUserAdmin(admin.ModelAdmin):
"total_usage_cost",
"created_at",
"upgraded_from_anonymous_at",
"user_runs",
"view_saved_runs",
"view_published_runs",
"view_embedded_files",
"view_transactions",
"open_in_firebase",
"open_in_stripe",
Expand All @@ -95,14 +103,32 @@ class AppUserAdmin(admin.ModelAdmin):
autocomplete_fields = ["handle", "subscription"]

@admin.display(description="User Runs")
def user_runs(self, user: models.AppUser):
def view_saved_runs(self, user: models.AppUser):
return list_related_html_url(
SavedRun.objects.filter(uid=user.uid),
query_param="uid",
instance_id=user.uid,
show_add=False,
)

@admin.display(description="Published Runs")
def view_published_runs(self, user: models.AppUser):
return list_related_html_url(
PublishedRun.objects.filter(created_by=user),
query_param="created_by",
instance_id=user.id,
show_add=False,
)

@admin.display(description="Embedded Files")
def view_embedded_files(self, user: models.AppUser):
return list_related_html_url(
EmbeddedFile.objects.filter(created_by=user),
query_param="created_by",
instance_id=user.id,
show_add=False,
)

@admin.display(description="Total Payments")
def total_payments(self, user: models.AppUser):
return "$" + str(
Expand Down
34 changes: 24 additions & 10 deletions daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
import numpy as np
import requests
from django.db import transaction
from django.db.models import F
from django.utils import timezone
from furl import furl
from loguru import logger
from pydantic import BaseModel, Field

import gooey_ui as gui
from app_users.models import AppUser
from daras_ai.image_input import (
upload_file_from_bytes,
safe_filename,
Expand Down Expand Up @@ -112,15 +115,15 @@ def references_as_prompt(references: list[SearchReference], sep="\n\n") -> str:


def get_top_k_references(
request: DocSearchRequest,
is_user_url: bool = True,
request: DocSearchRequest, is_user_url: bool = True, current_user: AppUser = None
) -> typing.Generator[str, None, list[SearchReference]]:
"""
Get the top k documents that ref the search query
Args:
request: the document search request
is_user_url: whether the url is user-uploaded
current_user: the current user
Returns:
the top k documents
Expand Down Expand Up @@ -148,8 +151,8 @@ def get_top_k_references(
EmbeddedFile._meta.get_field("embedding_model").default
),
)
embedding_refs: list[EmbeddedFile] = map_parallel(
lambda f_url, file_meta: get_or_create_embeddings(
embedded_files: list[EmbeddedFile] = map_parallel(
lambda f_url, file_meta: get_or_create_embedded_file(
f_url=f_url,
file_meta=file_meta,
max_context_words=request.max_context_words,
Expand All @@ -158,20 +161,26 @@ def get_top_k_references(
selected_asr_model=selected_asr_model,
embedding_model=embedding_model,
is_user_url=is_user_url,
current_user=current_user,
),
file_urls,
file_metas,
max_workers=4,
)
if not embedding_refs:
if not embedded_files:
yield "No embeddings found - skipping search"
return []

vespa_file_ids = [ref.vespa_file_id for ref in embedding_refs]
yield "Searching knowledge base..."

vespa_file_ids = [ref.vespa_file_id for ref in embedded_files]
EmbeddedFile.objects.filter(id__in=[ref.id for ref in embedded_files]).update(
query_count=F("query_count") + 1,
last_query_at=timezone.now(),
)
# chunk_count = sum(len(ref.document_ids) for ref in embedding_refs)
# logger.debug(f"Knowledge base has {len(file_ids)} documents ({chunk_count} chunks)")

yield "Searching knowledge base..."
s = time()
search_result = query_vespa(
request.search_query,
Expand Down Expand Up @@ -361,16 +370,17 @@ def yt_dlp_extract_info(url: str) -> dict:
return data


def get_or_create_embeddings(
def get_or_create_embedded_file(
*,
f_url: str,
file_meta: FileMetadata,
*,
max_context_words: int,
scroll_jump: int,
google_translate_target: str | None,
selected_asr_model: str | None,
embedding_model: EmbeddingModels,
is_user_url: bool,
current_user: AppUser,
) -> EmbeddedFile:
"""
Return Vespa document ids and document tags
Expand Down Expand Up @@ -408,7 +418,11 @@ def get_or_create_embeddings(
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
**lookup,
defaults=dict(metadata=file_meta, vespa_file_id=file_id),
defaults=dict(
metadata=file_meta,
vespa_file_id=file_id,
created_by=current_user,
),
)[0]
for ref in refs:
ref.embedded_file = embedded_file
Expand Down
6 changes: 6 additions & 0 deletions embeddings/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class EmbeddedFileAdmin(admin.ModelAdmin):
"google_translate_target",
"selected_asr_model",
"vespa_file_id",
"query_count",
"created_at",
"last_query_at",
]
search_fields = [
"url",
Expand All @@ -31,14 +33,18 @@ class EmbeddedFileAdmin(admin.ModelAdmin):
"selected_asr_model",
"created_at",
"updated_at",
"last_query_at",
]
readonly_fields = [
"metadata",
"vespa_file_id",
"created_at",
"updated_at",
"view_embeds",
"query_count",
"last_query_at",
]
autocomplete_fields = ["created_by"]
ordering = ["-created_at"]

@admin.display(description="View Embeds")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generated by Django 4.2.7 on 2024-07-25 10:35

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
('app_users', '0019_alter_appusertransaction_reason'),
('embeddings', '0001_initial'),
]

operations = [
migrations.AddField(
model_name='embeddedfile',
name='created_by',
field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='embedded_files', to='app_users.appuser'),
),
migrations.AddField(
model_name='embeddedfile',
name='last_query_at',
field=models.DateTimeField(blank=True, db_index=True, default=None, null=True),
),
migrations.AddField(
model_name='embeddedfile',
name='query_count',
field=models.PositiveIntegerField(db_index=True, default=0),
),
]
14 changes: 14 additions & 0 deletions embeddings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@


class EmbeddedFile(models.Model):
created_by = models.ForeignKey(
"app_users.AppUser",
on_delete=models.SET_NULL,
null=True,
blank=True,
default=None,
related_name="embedded_files",
)

url = CustomURLField(help_text="The URL of the original resource (e.g. a document)")
metadata = models.ForeignKey(
"files.FileMetadata",
Expand All @@ -32,6 +41,11 @@ class EmbeddedFile(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)

query_count = models.PositiveIntegerField(default=0, db_index=True)
last_query_at = models.DateTimeField(
null=True, blank=True, default=None, db_index=True
)

def __str__(self):
return f"{self.url} ({self.metadata})"

Expand Down
1 change: 1 addition & 0 deletions recipes/DocSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def run_v2(
"search_query": response.final_search_query,
},
),
current_user=self.request.user,
)

# empty search result, abort!
Expand Down
1 change: 1 addition & 0 deletions recipes/GoogleGPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def run_v2(
},
),
is_user_url=False,
current_user=self.request.user,
)
# add pretty titles to references
for ref in response.references:
Expand Down
1 change: 1 addition & 0 deletions recipes/VideoBots.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ def run_v2(
"keyword_query": response.final_keyword_query,
},
),
current_user=self.request.user,
)
if request.use_url_shortener:
for reference in response.references:
Expand Down
18 changes: 18 additions & 0 deletions usage_costs/migrations/0016_alter_modelpricing_model_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.7 on 2024-07-25 10:35

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('usage_costs', '0015_alter_modelpricing_model_name_and_more'),
]

operations = [
migrations.AlterField(
model_name='modelpricing',
name='model_name',
field=models.CharField(choices=[('gpt_4_o', 'GPT-4o (openai)'), ('gpt_4_o_mini', 'GPT-4o-mini (openai)'), ('gpt_4_turbo_vision', 'GPT-4 Turbo with Vision (openai)'), ('gpt_4_vision', 'GPT-4 Vision (openai) 🔻'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai) 🔻'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('gpt_3_5_turbo_instruct', 'GPT-3.5 Instruct (openai) 🔻'), ('llama3_70b', 'Llama 3 70b (Meta AI)'), ('llama_3_groq_70b_tool_use', 'Llama 3 Groq 70b Tool Use'), ('llama3_8b', 'Llama 3 8b (Meta AI)'), ('llama_3_groq_8b_tool_use', 'Llama 3 Groq 8b Tool Use'), ('llama2_70b_chat', 'Llama 2 70b Chat [Deprecated] (Meta AI)'), ('mixtral_8x7b_instruct_0_1', 'Mixtral 8x7b Instruct v0.1 (Mistral)'), ('gemma_2_9b_it', 'Gemma 2 9B (Google)'), ('gemma_7b_it', 'Gemma 7B (Google)'), ('gemini_1_5_pro', 'Gemini 1.5 Pro (Google)'), ('gemini_1_pro_vision', 'Gemini 1.0 Pro Vision (Google)'), ('gemini_1_pro', 'Gemini 1.0 Pro (Google)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('claude_3_5_sonnet', 'Claude 3.5 Sonnet (Anthropic)'), ('claude_3_opus', 'Claude 3 Opus [L] (Anthropic)'), ('claude_3_sonnet', 'Claude 3 Sonnet [M] (Anthropic)'), ('claude_3_haiku', 'Claude 3 Haiku [S] (Anthropic)'), ('sea_lion_7b_instruct', 'SEA-LION-7B-Instruct (aisingapore)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 [Deprecated] (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 [Deprecated] (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('text_curie_001', 'Curie [Deprecated] (openai)'), ('text_babbage_001', 'Babbage [Deprecated] (openai)'), ('text_ada_001', 'Ada [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALL·E 2 (OpenAI)'), ('dall_e_3', 'DALL·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero)'), ('openjourney', 'Open Journey (PromptHero)'), ('analog_diffusion', 'Analog Diffusion (wavymulder)'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('wav2lip', 'LipSync (wav2lip)')], help_text='The name of the model. Only used for Display purposes.', max_length=255),
),
]

0 comments on commit 6a1b0d5

Please sign in to comment.