From f4c6c560961f6f1351333fcab42c02ceeaf3ff42 Mon Sep 17 00:00:00 2001
From: iQuxLE
Date: Mon, 2 Dec 2024 13:18:55 +0000
Subject: [PATCH] handle upload and the 'documents' column name divergence

---
 src/curategpt/agents/huggingface_agent.py | 45 +++++++++++++++++++++++
 src/curategpt/cli.py                      |  5 ++-
 src/curategpt/store/duckdb_adapter.py     |  9 +----
 3 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/src/curategpt/agents/huggingface_agent.py b/src/curategpt/agents/huggingface_agent.py
index 5efcce1..91acad7 100644
--- a/src/curategpt/agents/huggingface_agent.py
+++ b/src/curategpt/agents/huggingface_agent.py
@@ -36,10 +36,55 @@ def upload(self, objects, metadata, repo_id, private=False, **kwargs):
         """
         embedding_file = "embeddings.parquet"
         metadata_file = "metadata.yaml"
+        print("\n\n")
+        print(objects[0][0])
+        print(objects[0][2]['_embeddings'])
+        print(objects[0][2]['documents'])
+        print("\n\n")
         try:
             df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['document']) for obj in objects])
         except Exception as e:
+            # df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['documents']) for obj in objects])
+            # logger.info(f"df changed")
             raise ValueError(f"Creation of Dataframe not successful: {e}") from e
+
+
+        with ExitStack() as stack:
+            tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
+            tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))
+
+            embedding_path = tmp_parquet.name
+            metadata_path = tmp_yaml.name
+
+            df.to_parquet(path=embedding_path, index=False)
+            with open(metadata_path, "w") as f:
+                yaml.dump(metadata.model_dump(), f)
+
+            self._create_repo(repo_id, private=private)
+
+            self._upload_files(repo_id, {
+                embedding_path : repo_id + "/" + embedding_file,
+                metadata_path : repo_id + "/" + metadata_file
+            })
+
+    def upload_duckdb(self, objects, metadata, repo_id, private=False, **kwargs):
+        """
+        Upload an entire collection to a Hugging Face repository.
+
+        :param objects: The objects to upload.
+        :param metadata: The metadata associated with the collection.
+        :param repo_id: The repository ID on Hugging Face.
+        :param private: Whether the repository should be private.
+        :param kwargs: Additional arguments such as batch size or metadata options.
+ """ + + embedding_file = "embeddings.parquet" + metadata_file = "metadata.yaml" + try: + df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['documents']) for obj in objects]) + except Exception as e: + raise ValueError(f"Creation of Dataframe not successful: {e}") from e + with ExitStack() as stack: tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True)) tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True)) diff --git a/src/curategpt/cli.py b/src/curategpt/cli.py index 4a6400a..53f231f 100644 --- a/src/curategpt/cli.py +++ b/src/curategpt/cli.py @@ -2476,7 +2476,10 @@ def upload_embeddings(path, collection, repo_id, private, adapter, database_type f"Unsupported adapter: {adapter} " f"currently only huggingface adapter is supported" ) try: - agent.upload(objects=objects, metadata=metadata, repo_id=repo_id, private=private) + if database_type == "chromadb": + agent.upload(objects=objects, metadata=metadata, repo_id=repo_id, private=private) + elif database_type == "duckdb": + agent.upload_duckdb(objects=objects, metadata=metadata, repo_id=repo_id, private=private) except Exception as e: print(f"Error uploading collection to {repo_id}: {e}") diff --git a/src/curategpt/store/duckdb_adapter.py b/src/curategpt/store/duckdb_adapter.py index c0fd7ef..9ce70b9 100644 --- a/src/curategpt/store/duckdb_adapter.py +++ b/src/curategpt/store/duckdb_adapter.py @@ -378,11 +378,7 @@ def _process_objects( logger.warning( f"Document with ID {ids[i]} exceeds the token limit alone and will be skipped." ) - # try: - # embeddings = OpenAIEmbeddings(model=model, tiktoken_model_name=model).embed_query(texts, - # embeddings.average model) - # batch_embeddings.extend(embeddings) - # skipping + # should not be happening as batched above i += 1 continue else: @@ -396,9 +392,6 @@ def _process_objects( embedding_model = llm.get_embedding_model(short_name) embeddings = list(embedding_model.embed_multi(texts)) batch_embeddings.extend(embeddings) - logger.info( - f"Trying to insert: {len(ids)} IDS, {len(metadatas)} METADATAS, {len(batch_embeddings)} EMBEDDINGS" - ) try: self.conn.execute("BEGIN TRANSACTION;") self.conn.executemany(