diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index f0fd51df8..806a50c8a 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -45,7 +45,6 @@ from daras_ai_v2.embedding_model import create_embeddings_cached, EmbeddingModels from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError from daras_ai_v2.functional import ( - flatmap_parallel, map_parallel, flatmap_parallel_ascompleted, ) @@ -146,9 +145,6 @@ def get_top_k_references( else: selected_asr_model = google_translate_target = None - file_url_metas = flatmap_parallel(doc_or_yt_url_to_metadatas, input_docs) - file_urls, file_metas = zip(*file_url_metas) - yield "Creating knowledge embeddings..." embedding_model = EmbeddingModels.get( @@ -158,9 +154,8 @@ def get_top_k_references( ), ) embedded_files: list[EmbeddedFile] = map_parallel( - lambda f_url, file_meta: get_or_create_embedded_file( + lambda f_url: get_or_create_embedded_file( f_url=f_url, - file_meta=file_meta, max_context_words=request.max_context_words, scroll_jump=request.scroll_jump, google_translate_target=google_translate_target, @@ -169,8 +164,7 @@ def get_top_k_references( is_user_url=is_user_url, current_user=current_user, ), - file_urls, - file_metas, + input_docs, max_workers=4, ) if not embedded_files: @@ -314,7 +308,13 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: else: try: if is_user_uploaded_url(f_url): - r = requests.head(f_url) + name = f.path.segments[-1] + return FileMetadata( + name=name, + etag=None, + mime_type=mimetypes.guess_type(name)[0], + total_bytes=0, + ) else: r = requests.head( f_url, @@ -387,7 +387,6 @@ def yt_dlp_extract_info(url: str) -> dict: def get_or_create_embedded_file( *, f_url: str, - file_meta: FileMetadata, max_context_words: int, scroll_jump: int, google_translate_target: str | None, @@ -402,51 +401,69 @@ def get_or_create_embedded_file( """ lookup = dict( url=f_url, - metadata__name=file_meta.name, - metadata__etag=file_meta.etag, - metadata__mime_type=file_meta.mime_type, - metadata__total_bytes=file_meta.total_bytes, max_context_words=max_context_words, scroll_jump=scroll_jump, google_translate_target=google_translate_target or "", selected_asr_model=selected_asr_model or "", embedding_model=embedding_model.name, ) + file_id = hashlib.sha256(str(lookup).encode()).hexdigest() with redis_lock(f"gooey/get_or_create_embeddings/v1/{file_id}"): try: - return EmbeddedFile.objects.filter(**lookup).order_by("-updated_at")[0] + embedded_file = EmbeddedFile.objects.filter(**lookup).order_by( + "-updated_at" + )[0] + if is_user_uploaded_url(f_url): + return embedded_file + + if not is_yt_dlp_able_url(f_url): + file_meta = doc_or_yt_url_to_metadatas(f_url)[0][1] + if file_meta == embedded_file.metadata: + return embedded_file + pass + except IndexError: - refs = create_embeddings_in_search_db( - f_url=f_url, - file_meta=file_meta, - file_id=file_id, - max_context_words=max_context_words, - scroll_jump=scroll_jump, - google_translate_target=google_translate_target or "", - selected_asr_model=selected_asr_model or "", - embedding_model=embedding_model, - is_user_url=is_user_url, - ) - with transaction.atomic(): - file_meta.save() - embedded_file = EmbeddedFile.objects.get_or_create( - **lookup, - defaults=dict( - metadata=file_meta, - vespa_file_id=file_id, - created_by=current_user, - ), - )[0] - for ref in refs: - ref.embedded_file = embedded_file - EmbeddingsReference.objects.bulk_create( - refs, - update_conflicts=True, - update_fields=["url", "title", "snippet", "updated_at"], - unique_fields=["vespa_doc_id"], + + for leaf_url, file_meta in doc_or_yt_url_to_metadatas(f_url): + lookup.update( + metadata__name=file_meta.name, + metadata__etag=file_meta.etag, + metadata__mime_type=file_meta.mime_type, + metadata__total_bytes=file_meta.total_bytes, + ) + file_id = hashlib.sha256(str(lookup).encode()).hexdigest() + + refs = create_embeddings_in_search_db( + f_url=f_url, + file_meta=file_meta, + file_id=file_id, + max_context_words=max_context_words, + scroll_jump=scroll_jump, + google_translate_target=google_translate_target or "", + selected_asr_model=selected_asr_model or "", + embedding_model=embedding_model, + is_user_url=is_user_url, ) - return embedded_file + with transaction.atomic(): + file_meta.save() + embedded_file = EmbeddedFile.objects.get_or_create( + **lookup, + defaults=dict( + metadata=file_meta, + vespa_file_id=file_id, + created_by=current_user, + ), + )[0] + for ref in refs: + ref.embedded_file = embedded_file + EmbeddingsReference.objects.bulk_create( + refs, + update_conflicts=True, + update_fields=["url", "title", "snippet", "updated_at"], + unique_fields=["vespa_doc_id"], + ) + return embedded_file def create_embeddings_in_search_db(