Skip to content

Commit

Permalink
Add script for ingesting embeddings to mongodb
Browse files Browse the repository at this point in the history
  • Loading branch information
binkjakub committed Jun 1, 2024
1 parent c0089c6 commit b34b116
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions scripts/embed/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import math
from pathlib import Path
from dotenv import load_dotenv
import torch
import tqdm
import typer

from juddges.data.database import BatchDatabaseUpdate, BatchedDatabaseCursor, get_mongo_collection

load_dotenv()

BATCH_SIZE = 64


def main(
mongo_uri: str = typer.Option(..., envvar="MONGO_URI"),
batch_size: int = typer.Option(BATCH_SIZE),
embeddings_file: Path = typer.Option(...),
):
collection = get_mongo_collection(mongo_uri)
query = {"embedding": {"$exists": False}}
cursor = collection.find(query, {"_id": 1})
num_docs_to_update = collection.count_documents(query)
batched_cursor = BatchedDatabaseCursor(cursor, batch_size=BATCH_SIZE, prefetch=True)

embeddings = torch.load(embeddings_file)
ingest_embeddings = BatchDatabaseUpdate(mongo_uri, lambda doc: embeddings.get(doc["_id"]))

for batch in tqdm(batched_cursor, total=math.ceil(num_docs_to_update / batch_size)):
ingest_embeddings(batch)


if __name__ == "__main__":
typer.run(main)

0 comments on commit b34b116

Please sign in to comment.