Merge pull request #117 from iQuxLE/new_hf_upload-download
hf upload/download - duckdb impl
caufieldjh authored Dec 2, 2024
2 parents 45e1258 + 2a5e7a4 commit 00e916e
Showing 4 changed files with 140 additions and 32 deletions.
38 changes: 38 additions & 0 deletions src/curategpt/agents/huggingface_agent.py
@@ -36,10 +36,48 @@ def upload(self, objects, metadata, repo_id, private=False, **kwargs):

        embedding_file = "embeddings.parquet"
        metadata_file = "metadata.yaml"

        try:
            df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['document']) for obj in objects])
        except Exception as e:
            raise ValueError(f"Failed to create DataFrame: {e}") from e

        with ExitStack() as stack:
            tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
            tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))

            embedding_path = tmp_parquet.name
            metadata_path = tmp_yaml.name

            df.to_parquet(path=embedding_path, index=False)
            with open(metadata_path, "w") as f:
                yaml.dump(metadata.model_dump(), f)

            self._create_repo(repo_id, private=private)

            self._upload_files(repo_id, {
                embedding_path: repo_id + "/" + embedding_file,
                metadata_path: repo_id + "/" + metadata_file,
            })

    def upload_duckdb(self, objects, metadata, repo_id, private=False, **kwargs):
        """
        Upload an entire collection to a Hugging Face repository.

        :param objects: The objects to upload.
        :param metadata: The metadata associated with the collection.
        :param repo_id: The repository ID on Hugging Face.
        :param private: Whether the repository should be private.
        :param kwargs: Additional arguments such as batch size or metadata options.
        """
        embedding_file = "embeddings.parquet"
        metadata_file = "metadata.yaml"
        try:
            df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['documents']) for obj in objects])
        except Exception as e:
            raise ValueError(f"Failed to create DataFrame: {e}") from e

        with ExitStack() as stack:
            tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
            tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))
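For orientation, here is a minimal usage sketch of the new DuckDB upload path. The agent and adapter class names and the constructor arguments below are assumptions for illustration; only the method names and the object shape come from this diff and the cli.py hunk that follows.

# Sketch only: the HuggingFaceAgent / DuckDBAdapter class names and the
# adapter constructor argument are assumed, not confirmed by this diff.
from curategpt.agents.huggingface_agent import HuggingFaceAgent
from curategpt.store.duckdb_adapter import DuckDBAdapter

db = DuckDBAdapter("./db/my_store.duckdb")
agent = HuggingFaceAgent()

# Each object is expected to be indexable as (id, _, payload), where payload
# carries '_embeddings' and 'documents' keys (see the DataFrame construction above).
objects = list(db.fetch_all_objects_memory_safe(collection="my_collection"))
metadata = db.collection_metadata("my_collection")  # must support .model_dump()

agent.upload_duckdb(objects=objects, metadata=metadata, repo_id="my-org/my-collection", private=False)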
6 changes: 4 additions & 2 deletions src/curategpt/cli.py
@@ -2465,7 +2465,6 @@ def upload_embeddings(path, collection, repo_id, private, adapter, database_type
     try:
         objects = list(db.fetch_all_objects_memory_safe(collection=collection))
         metadata = db.collection_metadata(collection)
-        print(metadata)
     except Exception as e:
         print(f"Error accessing collection '{collection}' from database: {e}")
         return
@@ -2477,7 +2476,10 @@ def upload_embeddings(path, collection, repo_id, private, adapter, database_type
f"Unsupported adapter: {adapter} " f"currently only huggingface adapter is supported"
)
try:
agent.upload(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
if database_type == "chromadb":
agent.upload(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
elif database_type == "duckdb":
agent.upload_duckdb(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
except Exception as e:
print(f"Error uploading collection to {repo_id}: {e}")

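As more stores are added, the if/elif dispatch above could become table-driven. A hypothetical equivalent of the new branch, not part of this commit:

# Hypothetical refactor: map database_type to the matching agent method
# instead of chaining if/elif; names are those already in scope above.
upload_fns = {
    "chromadb": agent.upload,
    "duckdb": agent.upload_duckdb,
}
upload_fn = upload_fns.get(database_type)
if upload_fn is None:
    raise ValueError(f"Unsupported database_type: {database_type}")
upload_fn(objects=objects, metadata=metadata, repo_id=repo_id, private=private)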
26 changes: 17 additions & 9 deletions src/curategpt/store/chromadb_adapter.py
@@ -263,10 +263,13 @@ def insert_from_huggingface(
         collection = self._get_collection(collection)
         model = None

-        if venomx:
-            hf_metadata_model = venomx.venomx.embedding_model.name
-            if hf_metadata_model:
-                model = hf_metadata_model
+        try:
+            if venomx:
+                hf_metadata_model = venomx.venomx.embedding_model.name
+                if hf_metadata_model:
+                    model = hf_metadata_model
+        except Exception as e:
+            raise KeyError(f"Metadata from {collection} is not compatible with the current version of CurateGPT") from e

         venomx = self.populate_venomx(collection, model, venomx.venomx)
         cm = self.update_collection_metadata(
@@ -502,9 +505,11 @@ def _search(
         # want to accidentally set it
         collection = client.get_collection(name=self._get_collection(collection))
         metadata = collection.metadata
-        # deserialize _venomx str to venomx dict and put in Metadata model
-        metadata = json.loads(metadata["_venomx"])
-        metadata = Metadata(venomx=Index(**metadata))
+        try:
+            metadata = json.loads(metadata["_venomx"])
+            metadata = Metadata(venomx=Index(**metadata))
+        except KeyError as e:
+            raise KeyError(f"Metadata from {collection} is not compatible with the current version of CurateGPT") from e
         collection = client.get_collection(
             name=collection.name, embedding_function=self._embedding_function(metadata.venomx.embedding_model.name)
         )
@@ -601,8 +606,11 @@ def diversified_search(
         )
         collection_obj = self._get_collection_object(collection)
         metadata = collection_obj.metadata
-        metadata = json.loads(metadata["_venomx"])
-        metadata = Metadata(venomx=Index(**metadata))
+        try:
+            metadata = json.loads(metadata["_venomx"])
+            metadata = Metadata(venomx=Index(**metadata))
+        except KeyError as e:
+            raise KeyError(f"Metadata from {collection} is not compatible with the current version of CurateGPT") from e
         ef = self._embedding_function(metadata.venomx.embedding_model.name)
         if len(text) > self.default_max_document_length:
             logger.warning(
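The same _venomx deserialization guard now appears three times in this adapter (insert_from_huggingface, _search, diversified_search). A shared helper could centralize it; a minimal sketch, assuming collection metadata stores the venomx Index as a JSON string under the "_venomx" key, as all three call sites do. Metadata and Index are the store's existing models; the helper itself is hypothetical and not part of this commit.

import json

# Hypothetical helper consolidating the repeated guard above.
def _load_venomx_metadata(collection) -> Metadata:
    try:
        index_dict = json.loads(collection.metadata["_venomx"])
        return Metadata(venomx=Index(**index_dict))
    except (KeyError, TypeError) as e:
        raise KeyError(
            f"Metadata from {collection.name} is not compatible with the current version of CurateGPT"
        ) from e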
102 changes: 81 additions & 21 deletions src/curategpt/store/duckdb_adapter.py
@@ -204,7 +204,6 @@ def insert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
         :param kwargs:
         :return:
         """
-        logger.info(f"\n\nIn insert duckdb, {kwargs.get('model')}\n\n")
         self._process_objects(objs, method="insert", **kwargs)

     # DELETE first to ensure primary key constraint https://duckdb.org/docs/sql/indexes
@@ -219,9 +218,7 @@ def update(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
         ids = [self._id(o, self.id_field) for o in objs]
         safe_collection_name = f'"{collection}"'
         delete_sql = f"DELETE FROM {safe_collection_name} WHERE id = ?"
-        logger.info("DELETED collection: {collection}")
         self.conn.executemany(delete_sql, [(id_,) for id_ in ids])
-        logger.info(f"INSERTING collection: {collection}")
         self.insert(objs, **kwargs)

     def upsert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
@@ -232,8 +229,6 @@ def upsert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
         :return:
         """
         collection = kwargs.get("collection")
-        logger.info(f"\n\nUpserting objects into collection {collection}\n\n")
-        logger.info(f"model in upsert: {kwargs.get('model')}, distance: {self.distance_metric}")
         if collection not in self.list_collection_names():
             vec_dimension = self._get_embedding_dimension(kwargs.get("model"))
             self._create_table_if_not_exists(
@@ -251,11 +246,9 @@ def upsert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
         objs_to_update = [o for o in objs if self._id(o, self.id_field) in existing_ids]
         objs_to_insert = [o for o in objs if self._id(o, self.id_field) not in existing_ids]
         if objs_to_update:
-            logger.info(f"in Upsert and updating now in collection: {collection}")
             self.update(objs_to_update, **kwargs)

         if objs_to_insert:
-            logger.info(f"in Upsert and inserting now in collection: {collection}")
             self.insert(objs_to_insert, **kwargs)

     def _process_objects(
@@ -298,22 +291,18 @@ def _process_objects(
         )

         if collection not in self.list_collection_names():
-            logger.info(f"(process)Creating table for collection {collection}")
             self._create_table_if_not_exists(
                 collection, self.vec_dimension, venomx=updated_venomx,
             )

         # if collection already exists, update metadata here
         cm = self.update_collection_metadata(collection=collection, updated_venomx=updated_venomx)
-        # TODO continue here, and use this cm instead cm = self.collection_md down below
         if isinstance(objs, Iterable) and not isinstance(objs, str):
             objs = list(objs)
         else:
             objs = [objs]
         obj_count = len(objs)
         kwargs.update({"object_count": obj_count})
-        # no need for update_metadata cause in table creation we build it
-        # cm = self.collection_metadata(collection)
         if batch_size is None:
             batch_size = 100000
         if text_field is None:
@@ -356,7 +345,7 @@ def _process_objects(
             from transformers import GPT2Tokenizer

             tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-        for next_objs in chunk(objs, batch_size):  # Existing chunking
+        for next_objs in chunk(objs, batch_size):
             next_objs = list(next_objs)
             docs = [self._text(o, text_field) for o in next_objs]
             metadatas = [self._dict(o) for o in next_objs]
@@ -381,7 +370,6 @@ def _process_objects(
                     texts = [tokenizer.decode(tokens) for tokens in current_batch]
                     short_name, _ = MODEL_MAP[openai_model]
                     embedding_model = llm.get_embedding_model(short_name)
-                    logger.info(f"Number of texts/docs to embed in batch: {len(texts)}")
                     embeddings = list(embedding_model.embed_multi(texts, len(texts)))
                     logger.info(f"Number of Documents in batch: {len(embeddings)}")
                     batch_embeddings.extend(embeddings)
@@ -390,11 +378,7 @@ def _process_objects(
                     logger.warning(
                         f"Document with ID {ids[i]} exceeds the token limit alone and will be skipped."
                     )
-                    # try:
-                    #     embeddings = OpenAIEmbeddings(model=model, tiktoken_model_name=model).embed_query(texts,
-                    #     embeddings.average model)
-                    #     batch_embeddings.extend(embeddings)
-                    # skipping
+                    # should not be happening as batched above
                     i += 1
                     continue
                 else:
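The OpenAI branch above packs documents into sub-batches that stay under the model's token limit and skips any single document that alone exceeds it. The packing pattern in isolation, as a sketch: the GPT-2 tokenizer mirrors the code above, while the helper name and the 8192 limit are illustrative assumptions.

# Sketch of the greedy token-budget batching used above.
# The token_limit value is illustrative, not taken from the code.
from transformers import GPT2Tokenizer

def pack_batches(texts, token_limit=8192):
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    batches, current, current_total = [], [], 0
    for text in texts:
        n = len(tokenizer.encode(text))
        if n > token_limit:
            continue  # document exceeds the limit on its own; skip it, as above
        if current and current_total + n > token_limit:
            batches.append(current)  # flush the full batch before adding more
            current, current_total = [], 0
        current.append(text)
        current_total += n
    if current:
        batches.append(current)
    return batches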
@@ -408,9 +392,6 @@ def _process_objects(
                 embedding_model = llm.get_embedding_model(short_name)
                 embeddings = list(embedding_model.embed_multi(texts))
                 batch_embeddings.extend(embeddings)
-            logger.info(
-                f"Trying to insert: {len(ids)} IDS, {len(metadatas)} METADATAS, {len(batch_embeddings)} EMBEDDINGS"
-            )
             try:
                 self.conn.execute("BEGIN TRANSACTION;")
                 self.conn.executemany(
@@ -424,8 +405,87 @@ def _process_objects(
                 )
                 raise
             finally:
                 # TODO: move outside - check memory/time profile
                 self.create_index(collection)
+
+    def insert_from_huggingface(
+        self,
+        objs: Union[OBJECT, Iterable[OBJECT]],
+        collection: str = None,
+        batch_size: int = None,
+        text_field: Union[str, Callable] = None,
+        venomx: Optional[Metadata] = None,
+        object_type: Optional[str] = None,
+        distance: Optional[str] = None,
+        vec_dimension: Optional[int] = None,
+        method: str = "insert",
+        **kwargs,
+    ):
+        collection = self._get_collection(collection)
+        model = None
+        try:
+            if venomx:
+                hf_metadata_model = venomx.venomx.embedding_model.name
+                # object_type = venomx.object_type
+                distance = venomx.hnsw_space
+                # vec_dimension = venomx.venomx.embedding_dimension
+                if hf_metadata_model:
+                    model = hf_metadata_model
+                    vec_dimension = self._get_embedding_dimension(model)
+
+        except Exception as e:
+            raise KeyError(f"Metadata from {collection} is not compatible with the current version of CurateGPT") from e
+
+        updated_venomx = self.update_or_create_venomx(
+            venomx.venomx,
+            collection,
+            model,
+            distance,
+            object_type,
+            vec_dimension,
+        )
+        if collection not in self.list_collection_names():
+            self._create_table_if_not_exists(
+                collection, vec_dimension, venomx=updated_venomx,
+            )
+        updated_venomx.venomx.id = collection  # prevent name error
+        self.set_collection_metadata(collection_name=collection, metadata=updated_venomx)
+        if batch_size is None:
+            batch_size = 100000
+
+        if not isinstance(objs, list):
+            objs = list(objs)
+
+        obj_count = len(objs)
+        kwargs.update({"object_count": obj_count})
+
+        sql_command = self._generate_sql_command(collection, method)
+        sql_command = sql_command.format(collection=collection)
+
+        for next_objs in chunk(objs, batch_size):
+            next_objs = list(next_objs)
+            ids = [item['metadata']['id'] for item in next_objs]
+            metadatas = [self._dict(o) for o in next_objs]
+            documents = [item['document'] for item in next_objs]
+            embeddings = [item['embeddings'].tolist() if isinstance(item['embeddings'], np.ndarray)
+                          else item['embeddings'] for item in next_objs]
+            try:
+                self.conn.execute("BEGIN TRANSACTION;")
+                self.conn.executemany(
+                    sql_command, list(zip(ids, metadatas, embeddings, documents, strict=False))
+                )
+                self.conn.execute("COMMIT;")
+            except Exception as e:
+                self.conn.execute("ROLLBACK;")
+                logger.error(
+                    f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}"
+                )
+                raise
+            finally:
+                self.create_index(collection)
+
+
     def update_or_create_venomx(
         self,
         venomx: Optional[Index],
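For reference, the per-object shape insert_from_huggingface consumes, reconstructed from the loop above. The field names come from the diff; the concrete values are illustrative.

# One object as consumed by insert_from_huggingface; values are illustrative.
example_obj = {
    "metadata": {"id": "example:0001"},     # ids = item['metadata']['id']
    "document": "text that was embedded",   # documents = item['document']
    "embeddings": [0.01, -0.02, 0.03],      # list or numpy array; arrays are converted via .tolist()
}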
