Skip to content

Commit

Permalink
feat: add default cache_knowledge to google bucket urls, updated vespa_id lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
milovate committed Dec 26, 2024
1 parent 6e72eed commit b2efeda
Showing 1 changed file with 95 additions and 40 deletions.
135 changes: 95 additions & 40 deletions daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from daras_ai_v2.embedding_model import create_embeddings_cached, EmbeddingModels
from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError
from daras_ai_v2.functional import (
flatmap_parallel,
map_parallel,
flatmap_parallel_ascompleted,
)
Expand Down Expand Up @@ -84,7 +83,6 @@ class DocSearchRequest(BaseModel):
scroll_jump: int | None

doc_extract_url: str | None

embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None
dense_weight: float | None = Field(
ge=0.0,
Expand Down Expand Up @@ -146,8 +144,31 @@ def get_top_k_references(
else:
selected_asr_model = google_translate_target = None

file_url_metas = flatmap_parallel(doc_or_yt_url_to_metadatas, input_docs)
file_urls, file_metas = zip(*file_url_metas)
# # Filter yt_dlp_able URLs from input_docs
# yt_dlp_able_docs = [doc for doc in input_docs if is_yt_dlp_able_url(doc)]
#
# # Get metadata and file URLs for yt_dlp_able URLs
# if yt_dlp_able_docs:
# # Get metadata and file URLs for yt_dlp_able URLs
# yt_dlp_file_url_metas = flatmap_parallel(
# doc_or_yt_url_to_metadatas, yt_dlp_able_docs
# )
# yt_dlp_file_urls, yt_dlp_file_metas = zip(*yt_dlp_file_url_metas)
# else:
# yt_dlp_file_urls, yt_dlp_file_metas = [], []
#
# # Remove original yt_dlp_able URLs from input_docs and extend with new file URLs
# file_urls = [doc for doc in input_docs if not is_yt_dlp_able_url(doc)] + list(
# yt_dlp_file_urls
# )
#
# # Create a file_meta list filled with None and extend it
# file_metas = [None] * (len(input_docs) - len(yt_dlp_able_docs))
# file_metas.extend(yt_dlp_file_metas)
#
# # todo: move this down into get_embeds remove the flatmap_parallel
# # file_url_metas = flatmap_parallel(doc_or_yt_url_to_metadatas, input_docs)
# # file_urls, file_url_metas = zip(*file_url_metas)

yield "Creating knowledge embeddings..."

Expand All @@ -158,9 +179,8 @@ def get_top_k_references(
),
)
embedded_files: list[EmbeddedFile] = map_parallel(
lambda f_url, file_meta: get_or_create_embedded_file(
lambda f_url: get_or_create_embedded_file(
f_url=f_url,
file_meta=file_meta,
max_context_words=request.max_context_words,
scroll_jump=request.scroll_jump,
google_translate_target=google_translate_target,
Expand All @@ -169,8 +189,7 @@ def get_top_k_references(
is_user_url=is_user_url,
current_user=current_user,
),
file_urls,
file_metas,
input_docs,
max_workers=4,
)
if not embedded_files:
Expand Down Expand Up @@ -270,6 +289,15 @@ def get_vespa_app():


def doc_or_yt_url_to_metadatas(f_url: str) -> list[tuple[str, FileMetadata]]:
"""
Retrieve metadata for a given document or YouTube URL.
Args:
f_url (str): The URL of the document or YouTube video.
Returns:
list[tuple[str, FileMetadata]]: A list of tuples containing the URL and its metadata.
"""
if is_yt_dlp_able_url(f_url):
entries = yt_dlp_get_video_entries(f_url)
return [
Expand Down Expand Up @@ -314,7 +342,13 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
else:
try:
if is_user_uploaded_url(f_url):
r = requests.head(f_url)
name = f.path.segments[-1]
return FileMetadata(
name=name,
etag=None,
mime_type=mimetypes.guess_type(name)[0],
total_bytes=0,
)
else:
r = requests.head(
f_url,
Expand Down Expand Up @@ -387,7 +421,6 @@ def yt_dlp_extract_info(url: str) -> dict:
def get_or_create_embedded_file(
*,
f_url: str,
file_meta: FileMetadata,
max_context_words: int,
scroll_jump: int,
google_translate_target: str | None,
Expand All @@ -402,46 +435,68 @@ def get_or_create_embedded_file(
"""
lookup = dict(
url=f_url,
metadata__name=file_meta.name,
metadata__etag=file_meta.etag,
metadata__mime_type=file_meta.mime_type,
metadata__total_bytes=file_meta.total_bytes,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model.name,
)
file_meta = None
file_id = hashlib.sha256(str(lookup).encode()).hexdigest()
with redis_lock(f"gooey/get_or_create_embeddings/v1/{file_id}"):
try:
return EmbeddedFile.objects.filter(**lookup).order_by("-updated_at")[0]
embedded_file = EmbeddedFile.objects.filter(**lookup).order_by(
"-updated_at"
)[0]
logger.debug(f" here is the embedded file from the lookup {embedded_file}")
if is_user_uploaded_url(f_url):
logger.debug(f"Using existing embeddings for {f_url}")
return embedded_file

if not is_yt_dlp_able_url(f_url):
file_meta = doc_or_yt_url_to_metadatas(f_url)[0][1]

if file_meta == embedded_file.metadata:
logger.debug(f"Using existing embeddings for {f_url}")
return embedded_file
pass

except IndexError:
refs = create_embeddings_in_search_db(
f_url=f_url,
file_meta=file_meta,
file_id=file_id,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model,
is_user_url=is_user_url,
)
with transaction.atomic():
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
**lookup,
defaults=dict(
metadata=file_meta,
vespa_file_id=file_id,
created_by=current_user,
),
)[0]
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(refs)
return embedded_file

for leaf_url, file_meta in doc_or_yt_url_to_metadatas(f_url):
lookup.update(
metadata__name=file_meta.name,
metadata__etag=file_meta.etag,
metadata__mime_type=file_meta.mime_type,
metadata__total_bytes=file_meta.total_bytes,
)
file_id = hashlib.sha256(str(lookup).encode()).hexdigest()

refs = create_embeddings_in_search_db(
f_url=leaf_url,
file_meta=file_meta,
file_id=file_id,
max_context_words=max_context_words,
scroll_jump=scroll_jump,
google_translate_target=google_translate_target or "",
selected_asr_model=selected_asr_model or "",
embedding_model=embedding_model,
is_user_url=is_user_url,
)
with transaction.atomic():
file_meta.save()
embedded_file = EmbeddedFile.objects.get_or_create(
**lookup,
defaults=dict(
metadata=file_meta,
vespa_file_id=file_id,
created_by=current_user,
),
)[0]
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(refs)
return embedded_file


def create_embeddings_in_search_db(
Expand Down

0 comments on commit b2efeda

Please sign in to comment.