handling upload and 'documents' column name divergence
iQuxLE committed Dec 2, 2024
1 parent 3861469 commit f4c6c56
Showing 3 changed files with 50 additions and 9 deletions.
45 changes: 45 additions & 0 deletions src/curategpt/agents/huggingface_agent.py
@@ -36,10 +36,55 @@ def upload(self, objects, metadata, repo_id, private=False, **kwargs):

        embedding_file = "embeddings.parquet"
        metadata_file = "metadata.yaml"
print("\n\n")
print(objects[0][0])
print(objects[0][2]['_embeddings'])
print(objects[0][2]['documents'])
print("\n\n")
try:
df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['document']) for obj in objects])
except Exception as e:
# df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['documents']) for obj in objects])
# logger.info(f"df changed")
raise ValueError(f"Creation of Dataframe not successful: {e}") from e
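The DataFrame above is created with pandas' default integer column names. A possible refinement, sketched here as a suggestion rather than part of this commit, is to name the columns so the published parquet file is self-describing:

# Suggested variant (not in this commit): name the columns explicitly
# so consumers of the parquet file see meaningful headers.
df = pd.DataFrame(
    data=[(obj[0], obj[2]['_embeddings'], obj[2]['document']) for obj in objects],
    columns=["id", "embeddings", "document"],
)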


        with ExitStack() as stack:
            tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
            tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))

            embedding_path = tmp_parquet.name
            metadata_path = tmp_yaml.name

            df.to_parquet(path=embedding_path, index=False)
            with open(metadata_path, "w") as f:
                yaml.dump(metadata.model_dump(), f)

            self._create_repo(repo_id, private=private)

            self._upload_files(repo_id, {
                embedding_path: repo_id + "/" + embedding_file,
                metadata_path: repo_id + "/" + metadata_file,
            })

    def upload_duckdb(self, objects, metadata, repo_id, private=False, **kwargs):
        """
        Upload an entire DuckDB-backed collection to a Hugging Face repository.

        :param objects: The objects to upload.
        :param metadata: The metadata associated with the collection.
        :param repo_id: The repository ID on Hugging Face.
        :param private: Whether the repository should be private.
        :param kwargs: Additional arguments such as batch size or metadata options.
        """

        embedding_file = "embeddings.parquet"
        metadata_file = "metadata.yaml"
        try:
            # DuckDB results store the text under the plural 'documents' key
            df = pd.DataFrame(data=[(obj[0], obj[2]['_embeddings'], obj[2]['documents']) for obj in objects])
        except Exception as e:
            raise ValueError(f"Creation of DataFrame not successful: {e}") from e

        with ExitStack() as stack:
            tmp_parquet = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".parquet", delete=True))
            tmp_yaml = stack.enter_context(tempfile.NamedTemporaryFile(suffix=".yaml", delete=True))
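The two upload paths differ only in the key holding the text payload: 'document' for ChromaDB results versus 'documents' for DuckDB, the divergence named in the commit message. A possible consolidation, sketched here as a suggestion rather than part of this commit, is a tolerant accessor (the helper name is hypothetical):

# Hypothetical helper (not in this commit): tolerate both the ChromaDB
# 'document' key and the DuckDB 'documents' key when building rows.
def rows_from_objects(objects):
    rows = []
    for obj in objects:
        meta = obj[2]
        text = meta.get("document", meta.get("documents"))
        rows.append((obj[0], meta["_embeddings"], text))
    return rows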
5 changes: 4 additions & 1 deletion src/curategpt/cli.py
@@ -2476,7 +2476,10 @@ def upload_embeddings(path, collection, repo_id, private, adapter, database_type
f"Unsupported adapter: {adapter} " f"currently only huggingface adapter is supported"
)
try:
agent.upload(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
if database_type == "chromadb":
agent.upload(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
elif database_type == "duckdb":
agent.upload_duckdb(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
except Exception as e:
print(f"Error uploading collection to {repo_id}: {e}")

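Note that an unrecognized database_type now falls through the if/elif chain silently. A defensive variant, offered as a suggestion rather than part of this commit, would fail loudly:

# Suggested defensive variant (not in this commit): raise on an
# unsupported database_type instead of silently doing nothing.
if database_type == "chromadb":
    agent.upload(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
elif database_type == "duckdb":
    agent.upload_duckdb(objects=objects, metadata=metadata, repo_id=repo_id, private=private)
else:
    raise ValueError(f"Unsupported database_type: {database_type}")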
9 changes: 1 addition & 8 deletions src/curategpt/store/duckdb_adapter.py
@@ -378,11 +378,7 @@ def _process_objects(
                    logger.warning(
                        f"Document with ID {ids[i]} exceeds the token limit alone and will be skipped."
                    )
-                    # try:
-                    #     embeddings = OpenAIEmbeddings(model=model, tiktoken_model_name=model).embed_query(texts,
-                    #     embeddings.average model)
-                    #     batch_embeddings.extend(embeddings)
-                    # skipping
+                    # should not be happening as batched above
                    i += 1
                    continue
                else:
@@ -396,9 +392,6 @@
                embedding_model = llm.get_embedding_model(short_name)
                embeddings = list(embedding_model.embed_multi(texts))
                batch_embeddings.extend(embeddings)
-                logger.info(
-                    f"Trying to insert: {len(ids)} IDS, {len(metadatas)} METADATAS, {len(batch_embeddings)} EMBEDDINGS"
-                )
                try:
                    self.conn.execute("BEGIN TRANSACTION;")
                    self.conn.executemany(
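The insert statement itself is truncated in this diff. A minimal sketch of the transactional batch-insert pattern begun above, with the caveat that the table and column names here are assumptions, not the adapter's actual schema:

# Minimal sketch of the transactional batch-insert pattern (not the
# adapter's actual SQL, which is truncated in this diff).
import duckdb

def insert_batch(conn, collection, ids, metadatas, batch_embeddings, texts):
    # e.g. conn = duckdb.connect("store.duckdb")
    try:
        conn.execute("BEGIN TRANSACTION;")
        conn.executemany(
            f"INSERT OR REPLACE INTO {collection} (id, metadata, embeddings, documents) "
            "VALUES (?, ?, ?, ?)",
            list(zip(ids, metadatas, batch_embeddings, texts)),
        )
        conn.execute("COMMIT;")
    except Exception:
        # roll back the whole batch so a partial insert never persists
        conn.execute("ROLLBACK;")
        raise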
