Skip to content

Commit

Permalink
batch embedding function call, using tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
iQuxLE committed Aug 20, 2024
1 parent cf1277b commit 12715a9
Showing 1 changed file with 125 additions and 29 deletions.
154 changes: 125 additions & 29 deletions src/curate_gpt/store/duckdb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, Optional, Union

from langchain_openai import OpenAIEmbeddings
import duckdb
import numpy as np
import openai
Expand Down Expand Up @@ -174,7 +174,37 @@ def create_index(self, collection: str):
"""
self.conn.execute(create_index_sql)

def _embedding_function(self, texts: Union[str, List[str]], model: str = None) -> list:
def _embedding_function_langchain(self, texts: Union[str, List[str], List[List[str]]], model: str = None) -> list:
    """
    Get the embeddings for the given texts using the specified model

    A list of lists is flattened so that all texts are embedded in a single
    ``embed_documents`` call, and the resulting embeddings are then regrouped
    to mirror the nesting of the input.

    :param texts: A single text or a list of texts or list of list of texts to embed
    :param model: Model to use for embedding, defaults to "text-embedding-ada-002" if none provided
    :return: A list of embeddings (one per text; a single string yields a one-element list),
        or a list of list of embeddings when the input was a list of lists
    """
    if model is None:
        model = "text-embedding-ada-002"  # Default model

    # A bare string is a single document; without this it would be iterated
    # character by character further down.
    if isinstance(texts, str):
        texts = [texts]

    # Flatten the input if it's a list of lists, remembering each group's size
    # so the nesting can be rebuilt afterwards.
    nested = any(isinstance(entry, list) for entry in texts)
    group_sizes = []
    if nested:
        # A bare string sitting next to sub-lists is treated as a group of one,
        # which keeps the sizes and the flattened texts aligned.
        groups = [entry if isinstance(entry, list) else [entry] for entry in texts]
        group_sizes = [len(group) for group in groups]
        texts = [text for group in groups for text in group]

    # Nothing to embed: skip the remote call but keep the input's shape.
    if not texts:
        return [[] for _ in group_sizes] if nested else []

    # The model is selected on the client; embed_documents' second positional
    # parameter is a chunk size, not a model name.
    embeddings = OpenAIEmbeddings(model=model).embed_documents(texts)

    if not nested:
        return embeddings

    # Rebuild the nested structure: consecutive slices, one per input sub-list.
    regrouped = []
    start = 0
    for size in group_sizes:
        regrouped.append(embeddings[start:start + size])
        start += size
    return regrouped

def _embedding_function(self, texts: Union[str, List[str], List[List[str]]], model: str = None) -> list:
"""
Get the embeddings for the given texts using the specified model
:param texts: A single text or a list of texts to embed
Expand Down Expand Up @@ -320,33 +350,99 @@ def _process_objects(
cumulative_len = 0
sql_command = self._generate_sql_command(collection, method)
sql_command = sql_command.format(collection=collection)
for next_objs in chunk(objs, batch_size):
next_objs = list(next_objs)
logger.info("Processing batch of objects in DuckDB process_objects ...")
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
cumulative_len += docs_len
if self._is_openai(collection) and cumulative_len > 3000000:
logger.warning(f"Cumulative length = {cumulative_len}, pausing ...")
time.sleep(60)
cumulative_len = 0
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]
embeddings = self._embedding_function(docs, cm.model)
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(
f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}"
)
raise
finally:
self.create_index(collection)
if not self._is_openai(collection):
for next_objs in chunk(objs, batch_size):
next_objs = list(next_objs)
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]
embeddings = self._embedding_function(docs, cm.model)
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}")
raise
finally:
self.create_index(collection)
else:
if model.startswith("openai:"):
openai_model = model.split(":", 1)[1]
if openai_model == "" or openai_model not in MODELS:
logger.info(f"The model {openai_model} is not "
f"one of {MODELS}. Defaulting to {MODELS[0]}")
openai_model = MODELS[0] #ada 002
else:
logger.error(f"Something went wonky ## model: {model}")
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
for next_objs in chunk(objs, batch_size): # Existing chunking
next_objs = list(next_objs)
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]

tokenized_docs = [tokenizer.encode(doc) for doc in docs]
current_batch = []
current_token_count = 0
batch_embeddings = []

i = 0
while i < len(tokenized_docs):
doc_tokens = tokenized_docs[i]
# peek
if current_token_count + len(doc_tokens) <= 8192:
current_batch.append(doc_tokens)
current_token_count += len(doc_tokens)
i += 1
else:
if current_batch:
logger.info(f"Curent token count to embed: {current_token_count}")
texts = [tokenizer.decode(tokens) for tokens in current_batch]
embeddings = OpenAIEmbeddings(model=openai_model, tiktoken_model_name=model).embed_documents(texts,
openai_model)
logger.info(f"len embeddings: {len(embeddings)}")
batch_embeddings.extend(embeddings)

if len(doc_tokens) > 8192:
logger.warning(
f"Document with ID {ids[i]} exceeds the token limit alone and will be skipped.")
# try:
# embeddings = OpenAIEmbeddings(model=model, tiktoken_model_name=model).embed_query(texts,
# model)
# batch_embeddings.extend(embeddings)
# skipping
i += 1
continue
else:
current_batch = []
current_token_count = 0

if current_batch:
texts = [tokenizer.decode(tokens) for tokens in current_batch]
embeddings = OpenAIEmbeddings(model=openai_model, tiktoken_model_name=openai_model).embed_documents(texts,
openai_model)
batch_embeddings.extend(embeddings)
logger.info(f"Trying to insert: {len(ids)} IDS, {len(metadatas)} METADATAS, {len(batch_embeddings)} EMBEDDINGS")
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, batch_embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(
f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}")
raise
finally:
self.create_index(collection)

def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
"""
Expand Down

0 comments on commit 12715a9

Please sign in to comment.