diff --git a/src/curategpt/cli.py b/src/curategpt/cli.py
index f5ff96c..2e16f7c 100644
--- a/src/curategpt/cli.py
+++ b/src/curategpt/cli.py
@@ -49,6 +49,8 @@
     "main",
 ]
 
+from venomx.model.venomx import Dataset, Index, Model
+
 
 def dump(
     obj: Union[str, AnnotatedObject, Dict],
@@ -2086,8 +2088,9 @@ def list_collections(database_type, path, peek: bool, minimal: bool, derived: bo
         cm = db.collection_metadata(cn, include_derived=derived)
         if database_type == "chromadb":
             # TODO: make get_or_create abstract and implement in DBAdapter?
-            c = db.client.get_or_create_collection(cn)
-            print(f"## Collection: {cn} N={c.count()} meta={c.metadata} // {cm}")
+            c = db.client.get_collection(cn)
+            print(f"## Collection: {cn} N={c.count()} meta={c.metadata}\n"
+                  f"Metadata: {cm}\n")
             if peek:
                 r = c.peek()
                 for id_ in r["ids"]:
@@ -2322,7 +2325,7 @@ def index_ontology_command(
         curategpt ontology index -p stagedb/duck.db -c ont-hp sqlite:obo:hp -D duckdb
     """
     s = time.time()
     oak_adapter = get_adapter(ont)
     view = OntologyWrapper(oak_adapter=oak_adapter)
@@ -2343,8 +2345,25 @@ def _text_lookup(obj: Dict):
     if not append:
         db.remove_collection(collection, exists_ok=True)
     click.echo(f"Indexing {len(list(view.objects()))} objects")
-    db.insert(view.objects(), collection=collection, model=model)
-    db.update_collection_metadata(collection, object_type="OntologyClass")
+
+    venomx = Index(
+        id=collection,
+        dataset=Dataset(
+            name=ont
+        ),
+        embedding_model=Model(
+            name=model
+        )
+    )
+
+    db.insert(
+        view.objects(),
+        collection=collection,
+        model=model,
+        venomx=venomx,
+        object_type="OntologyClass"
+    )
+
     e = time.time()
     click.echo(f"Indexed {len(list(view.objects()))} in {e - s} seconds")
diff --git a/src/curategpt/store/__init__.py b/src/curategpt/store/__init__.py
index a59c9e5..9be5c96 100644
--- a/src/curategpt/store/__init__.py
+++ b/src/curategpt/store/__init__.py
@@ -19,7 +19,7 @@
 from .chromadb_adapter import ChromaDBAdapter
 from .db_adapter import DBAdapter
 from .duckdb_adapter import DuckDBAdapter
-from .metadata import CollectionMetadata
+from .metadata import Metadata
 from .schema_proxy import SchemaProxy
 
 __all__ = [
@@ -27,7 +27,7 @@
     "ChromaDBAdapter",
     "DuckDBAdapter",
     "SchemaProxy",
-    "CollectionMetadata",
+    "Metadata",
     "get_store",
 ]
 
@@ -40,7 +40,7 @@ def get_all_subclasses(cls):
     ]
 
 
-def get_store(name: str, *args, **kwargs) -> DBAdapter:  # duckdb_vss or chromadb
+def get_store(name: str, *args, **kwargs) -> DBAdapter:  # duckdb or chromadb
     from .in_memory_adapter import InMemoryAdapter  # noqa F401
 
     # noqa I005
diff --git a/src/curategpt/store/chromadb_adapter.py b/src/curategpt/store/chromadb_adapter.py
index 1b6c983..a4aab53 100644
--- a/src/curategpt/store/chromadb_adapter.py
+++ b/src/curategpt/store/chromadb_adapter.py
@@ -5,22 +5,22 @@
 import os
 import time
 from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Iterable, Iterator, List, Mapping, Optional, Union
+from typing import Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, Optional, Union
 
 import chromadb
 import yaml
 from chromadb import ClientAPI as API
 from chromadb import Settings
 from chromadb.api import EmbeddingFunction
-from chromadb.types import Collection
 from chromadb.utils import embedding_functions
 from linkml_runtime.dumpers import json_dumper
 from linkml_runtime.utils.yamlutils import YAMLRoot
 from oaklib.utilities.iterator_utils import chunk
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
+from venomx.model.venomx import Index, Model
 
 from curategpt.store.db_adapter import DBAdapter
-from curategpt.store.metadata import CollectionMetadata
+from curategpt.store.metadata import Metadata
 from curategpt.store.vocab import OBJECT, PROJECTION, QUERY, SEARCH_RESULT
 from curategpt.utils.vector_algorithms import mmr_diversified_search
 
@@ -34,7 +34,7 @@ class ChromaDBAdapter(DBAdapter):
     """
 
     name: ClassVar[str] = "chromadb"
-    default_model = "all-MiniLM-L6-v2"
+    default_model: str = "all-MiniLM-L6-v2"
     client: API = None
     id_field: str = field(default="id")
     text_lookup: Optional[Union[str, Callable]] = field(default="text")
@@ -111,6 +111,7 @@ def _object_metadata(self, obj: OBJECT):
             k: v for k, v in dict_obj.items() if not isinstance(v, (dict, list)) and v is not None
         }
 
+
     def reset(self):
         """
         Reset/delete the database.
@@ -124,8 +125,6 @@ def _embedding_function(self, model: str = None) -> EmbeddingFunction:
         :param model:
         :return:
         """
-        if model is None:
-            raise ValueError("Model must be specified")
         if model.startswith("openai:"):
             return embedding_functions.OpenAIEmbeddingFunction(
                 api_key=os.environ.get("OPENAI_API_KEY"),
@@ -149,6 +148,7 @@ def _insert_or_update(
         object_type: str = None,
         model: str = None,
         text_field: Union[str, Callable] = None,
+        venomx: Optional[Index] = None,
         **kwargs,
     ):
         """
@@ -161,22 +161,36 @@
         """
         client = self.client
         collection = self._get_collection(collection)
-        cm = self.collection_metadata(collection)
+        # cm is None only when inserting into a new collection;
+        # otherwise it holds the Metadata fetched from the collection
+        cm = self.collection_metadata(collection, **kwargs)
         if model is None:
-            if cm:
-                model = cm.model
+            # if the collection does not exist, cm is None
+            if cm and cm.venomx and cm.venomx.embedding_model:
+                model = cm.venomx.embedding_model.name
             if model is None:
                 model = self.default_model
-        cm = self.update_collection_metadata(collection, model=model, object_type=object_type)
-        ef = self._embedding_function(cm.model)
-        # cm = CollectionMetadata(name=collection, model=self.model, object_type=object_type)
-        cm_dict = cm.dict(exclude_none=True)
+                logger.info(f"No model specified; defaulting to {self.default_model}")
+        # populate_venomx is called even when venomx is None: if the CLI did not
+        # supply a model, the embedding model must still be populated here, since
+        # _is_openai would otherwise crash trying to access a None value
+        venomx = self.populate_venomx(collection, model, venomx)
+        cm = self.update_collection_metadata(
+            collection,
+            model=model,
+            object_type=object_type,
+            venomx=venomx
+        )
+        ef = self._embedding_function(model)
+        # serialize the metadata for insertion into the db to fit ChromaDB requirements
+        adapter_metadata = cm.serialize_venomx_metadata_for_adapter(self.name)
         collection_obj = client.get_or_create_collection(
             name=collection,
             embedding_function=ef,
-            metadata=cm_dict,
+            metadata=adapter_metadata,
         )
-        if self._is_openai(collection_obj) and batch_size is None:
+        if self._is_openai(venomx) and batch_size is None:
             # TODO: see https://github.com/chroma-core/chroma/issues/709
             batch_size = 100
         if batch_size is None:
@@ -194,12 +208,11 @@
                 docs_len = sum([len(d) for d in docs])
                 cumulative_len += docs_len
                 # TODO: use tiktoken to get a better estimate
-                if self._is_openai(collection_obj) and cumulative_len > 3000000:
+                if self._is_openai(venomx) and cumulative_len > 3000000:
                     logger.warning(f"Cumulative length = {cumulative_len}, pausing ...")
                     # TODO: this is too conservative; it should be based on time of start of batch
                     time.sleep(60)
                     cumulative_len = 0
-            logger.debug(f"Example doc (tf={text_field}): {docs[0]}")
             logger.info("Preparing metadatas...")
             metadatas = [self._object_metadata(o) for o in next_objs]
             logger.info("Preparing ids...")
@@ -261,7 +274,7 @@ def list_collection_names(self) -> List[str]:
 
     def collection_metadata(
         self, collection_name: Optional[str] = None, include_derived=False, **kwargs
-    ) -> Optional[CollectionMetadata]:
+    ) -> Optional[Metadata]:
         """
         Get the metadata for a collection.
 
@@ -273,19 +286,34 @@
         """
         collection_name = self._get_collection(collection_name)
         try:
-            logger.info(f"Getting collection object {collection_name}")
             collection_obj = self.client.get_collection(name=collection_name)
-        except Exception:
+            logger.info(f"Getting metadata from collection object: {collection_obj.metadata}")
+        except Exception as e:
+            logger.error(f"Failed to get collection {collection_name}: {e}")
             return None
-        cm = CollectionMetadata(**collection_obj.metadata)
+        metadata_data = {**collection_obj.metadata, **kwargs}
+        logger.debug(f"Metadata from collection object: {metadata_data}")
+        try:
+            logger.info("Deserializing _venomx metadata")
+            cm = Metadata.deserialize_venomx_metadata_from_adapter(metadata_data, self.name)
+            logger.info(f"Metadata: {cm}")
+        except ValidationError as ve:
+            logger.error(f"Deserializing failed. Metadata validation error: {ve}")
+            logger.debug("No '_venomx' in metadata; creating a new, clean Metadata(venomx=Index()) object.")
+            cm = Metadata(venomx=Index())
+            logger.info(f"New clean Metadata: {cm}")
         if include_derived:
-            logger.info(f"Getting object count for {collection_name}")
-            cm.object_count = collection_obj.count()
+            try:
+                logger.info(f"Getting object count for {collection_name}")
+                cm.object_count = collection_obj.count()
+            except Exception as e:
+                logger.error(f"Failed to get object count: {e}")
         return cm
 
@@ -293,43 +321,105 @@
     def set_collection_metadata(
-        self, collection_name: Optional[str], metadata: CollectionMetadata, **kwargs
-    ):
+        self, collection_name: Optional[str], metadata: Metadata, **kwargs
+    ) -> Union[Metadata, Dict]:
         """
         Set the metadata for a collection.
 
         :param collection_name:
         :param metadata:
         :return:
         """
-        self.update_collection_metadata(
-            collection_name=collection_name, **metadata.dict(exclude_none=True)
+        current_metadata = self.collection_metadata(collection_name=collection_name)
+
+        if metadata:
+            if metadata.venomx.id != collection_name:
+                raise ValueError(
+                    f"venomx.id: {metadata.venomx.id} must match collection_name {collection_name}"
+                )
+
+            new_model = metadata.venomx.embedding_model.name
+            prev_model = current_metadata.venomx.embedding_model.name
+            if prev_model and new_model != prev_model:
+                if self.client.get_or_create_collection(name=collection_name).count() > 0:
+                    raise ValueError(f"Cannot change model from {prev_model} to {new_model}")
+
+        chromadb_metadata = metadata.serialize_venomx_metadata_for_adapter(self.name)
+        self.client.get_or_create_collection(
+            name=collection_name,
+            metadata=chromadb_metadata
         )
+        return chromadb_metadata
 
-    def update_collection_metadata(self, collection_name: str, **kwargs) -> CollectionMetadata:
+    def update_collection_metadata(self, collection_name: str, **kwargs) -> Metadata:
         """
-        Update the metadata for a collection.
+        Update the metadata for a collection based on the adapter.
 
-        :param collection_name:
-        :param kwargs:
-        :return:
+        :param collection_name: Name of the collection.
+        :param kwargs: Additional metadata fields.
+        :return: Updated Metadata instance.
         """
         collection_name = self._get_collection(collection_name)
         metadata = self.collection_metadata(collection_name=collection_name)
-        if metadata is None:
-            metadata = CollectionMetadata(**kwargs)
-        else:
-            prev_model = metadata.model
-            metadata = metadata.copy(update=kwargs)
-            if prev_model and metadata.model != prev_model:
+        logger.info(f"Metadata: {metadata}")
+
+        # ensure venomx.id matches collection_name if venomx is present
+        if metadata:
+            if metadata.venomx.id != collection_name:
+                raise ValueError(
+                    f"venomx.id: {metadata.venomx.id} must match collection_name {collection_name}"
+                )
+
+        # if metadata is available from the collection
+        if metadata is not None:
+            # any additional scalar param, e.g. model or object_type
+            scalar_updates = {k: v for k, v in kwargs.items() if k != "venomx"}
+            metadata = metadata.model_copy(update=scalar_updates)
+
+            prev_model = metadata.venomx.embedding_model.name
+            # set_collection_metadata called from the CLI passes model=None when no
+            # model is given, so default it here before comparing against prev_model
+            if kwargs.get("model") is None:
+                kwargs["model"] = self.default_model
+            if prev_model and kwargs.get("model") != prev_model:
                 if self.client.get_or_create_collection(name=collection_name).count() > 0:
-                    raise ValueError(f"Cannot change model from {prev_model} to {metadata.model}")
-                else:
-                    logger.info(
-                        f"Changing (empty collection) model from {prev_model} to {metadata.model}"
-                    )
-        # self.set_collection_metadata(collection_name=collection_name, metadata=metadata)
-        if metadata.name:
-            assert metadata.name == collection_name
+                    raise ValueError(
+                        f"Cannot change model from {prev_model} to {kwargs.get('model')}"
+                    )
+
+            # assign venomx to the metadata object
+            if "venomx" in kwargs and kwargs.get("venomx") is not None:
+                metadata = Metadata(venomx=kwargs.get("venomx"))
         else:
-            metadata.name = collection_name
-            metadata.hnsw_space = "cosine"
+            metadata = Metadata(
+                venomx=kwargs.get("venomx"),
+                hnsw_space=kwargs.get("hnsw_space", "cosine"),
+                object_type=kwargs.get("object_type"),
+            )
+        logger.info(metadata)
+
+        chromadb_metadata = metadata.serialize_venomx_metadata_for_adapter(self.name)
         self.client.get_or_create_collection(
-            name=collection_name, metadata=metadata.dict(exclude_none=True)
+            name=collection_name,
+            metadata=chromadb_metadata
         )
         return metadata
 
+    @staticmethod
+    def populate_venomx(
+        collection: Optional[str], model: Optional[str], existing_venomx: Index
+    ) -> Index:
+        """
+        Populate venomx with the data available at insertion time.
+
+        :param collection:
+        :param model:
+        :param existing_venomx:
+        :return:
+        """
+        logger.info("Populating venomx")
+        venomx = Index(
+            id=f"{collection}",
+            embedding_model=Model(
+                name=model
+            )
+        )
+        logger.info(f"Constructed venomx: {venomx}")
+        if existing_venomx:
+            logger.info("Merging the venomx provided by the CLI command with the resolved defaults")
+            # keep caller-supplied fields (e.g. dataset) but force the resolved id and
+            # embedding model, which are never None at this point
+            venomx = existing_venomx.model_copy(
+                update={"id": venomx.id, "embedding_model": venomx.embedding_model}
+            )
+        return venomx
+
     def search(self, text: str, **kwargs) -> Iterator[SEARCH_RESULT]:
         yield from self._search(text=text, **kwargs)
 
@@ -365,8 +455,11 @@ def _search(
         # want to accidentally set it
         collection = client.get_collection(name=self._get_collection(collection))
         metadata = collection.metadata
+        # deserialize the _venomx string into a venomx dict and wrap it in a Metadata model
+        metadata = json.loads(metadata["_venomx"])
+        metadata = Metadata(venomx=Index(**metadata))
         collection = client.get_collection(
-            name=collection.name, embedding_function=self._embedding_function(metadata["model"])
+            name=collection.name, embedding_function=self._embedding_function(metadata.venomx.embedding_model.name)
         )
         logger.debug(f"Collection metadata: {metadata}")
         if text:
@@ -461,7 +554,9 @@ def diversified_search(
         )
         collection_obj = self._get_collection_object(collection)
         metadata = collection_obj.metadata
-        ef = self._embedding_function(metadata["model"])
+        metadata = json.loads(metadata["_venomx"])
+        metadata = Metadata(venomx=Index(**metadata))
+        ef = self._embedding_function(metadata.venomx.embedding_model.name)
         if len(text) > self.default_max_document_length:
             logger.warning(
                 f"Text too long ({len(text)}), truncating to {self.default_max_document_length}"
@@ -522,10 +617,12 @@ def collections(self) -> Iterator[str]:
         for c in client.list_collections():
             yield c.name
 
-    def _is_openai(self, collection: Collection):
-        if collection.metadata.get("model", "").startswith("openai:"):
+    @staticmethod
+    def _is_openai(venomx):
+        if venomx.embedding_model.name.startswith("openai:"):
             return True
 
+
     def peek(self, collection: str = None, limit=5, offset: int = 0, **kwargs) -> Iterator[OBJECT]:
         c = self.client.get_collection(name=self._get_collection(collection))
         logger.debug(f"Peeking at {collection} with limit={limit}, offset={offset}")
@@ -589,13 +686,14 @@ def dump_then_load(self, collection: str = None, target: DBAdapter = None):
         if not isinstance(target, ChromaDBAdapter):
             raise ValueError("Target must be a ChromaDBAdapter")
         cm = self.collection_metadata(collection)
-        ef = self._embedding_function(cm.model)
+        adapter_metadata = cm.serialize_venomx_metadata_for_adapter(self.name)
+        ef = self._embedding_function(cm.venomx.embedding_model.name)
         # this currently prevents interadapter copying (duck to chroma)
         # target.get_collection (abstract) should be implemented
         target_collection_obj = target.client.get_or_create_collection(
             name=collection,
             embedding_function=ef,
-            metadata=cm.dict(exclude_none=True),
+            metadata=adapter_metadata
         )
         result = collection_obj.get(include=["metadatas", "documents", "embeddings"])
         if not result["ids"]:
diff --git a/src/curategpt/store/db_adapter.py b/src/curategpt/store/db_adapter.py
index 7d08665..c65c574 100644
--- a/src/curategpt/store/db_adapter.py
+++ b/src/curategpt/store/db_adapter.py
@@ -12,7 +12,7 @@
 from click.utils import LazyFile
 from jsonlines import jsonlines
 
-from curategpt.store.metadata import CollectionMetadata
+from curategpt.store.metadata import Metadata
 from curategpt.store.schema_proxy import SchemaProxy
 from curategpt.store.vocab import (
     DEFAULT_COLLECTION,
@@ -205,7 +205,7 @@ def list_collection_names(self) -> List[str]:
     @abstractmethod
     def collection_metadata(
         self, collection_name: Optional[str] = None, include_derived=False, **kwargs
-    ) -> Optional[CollectionMetadata]:
+    ) -> Optional[Metadata]:
         """
         Get the metadata for a collection.
 
@@ -215,15 +215,17 @@
         """
 
     def set_collection_metadata(
-        self, collection_name: Optional[str], metadata: CollectionMetadata, **kwargs
+        self, collection_name: Optional[str], metadata: Metadata, **kwargs
     ):
         """
         Set the metadata for a collection.
 
         >>> from curategpt.store import get_store
-        >>> from curategpt.store import CollectionMetadata
+        >>> from curategpt.store import Metadata
         >>> store = get_store("in_memory")
-        >>> cm = CollectionMetadata(name="People", description="People in the database")
+        >>> md = store.collection_metadata("people")
+        >>> md.venomx.id == "People"
+        >>> md.venomx.embedding_model.name == "openai:"
+        >>> store.set_collection_metadata("people", md)
 
         :param collection_name:
@@ -231,7 +233,7 @@
         """
         raise NotImplementedError
 
-    def update_collection_metadata(self, collection_name: str, **kwargs) -> CollectionMetadata:
+    def update_collection_metadata(self, collection_name: str, **kwargs) -> Metadata:
         """
         Update the metadata for a collection.
diff --git a/src/curategpt/store/duckdb_adapter.py b/src/curategpt/store/duckdb_adapter.py
index 0a771c6..eae3d1c 100644
--- a/src/curategpt/store/duckdb_adapter.py
+++ b/src/curategpt/store/duckdb_adapter.py
@@ -6,9 +6,19 @@
 import json
 import logging
 import os
-import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    ClassVar,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 import duckdb
 import llm
@@ -22,10 +32,12 @@
 from openai import OpenAI
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
+from venomx.model.venomx import Index, Model
 
 from curategpt.store.db_adapter import DBAdapter
+from curategpt.store.duckdb_connection_handler import DuckDBConnectionAndRecoveryHandler
 from curategpt.store.duckdb_result import DuckDBSearchResult
-from curategpt.store.metadata import CollectionMetadata
+from curategpt.store.metadata import Metadata
 from curategpt.store.vocab import (
     DEFAULT_MODEL,
     DEFAULT_OPENAI_MODEL,
@@ -62,37 +74,12 @@ class DuckDBAdapter(DBAdapter):
     openai_client: OpenAI = field(default=None)
 
     def __post_init__(self):
-        if not self.path:
-            self.path = "./db/db_file.duckdb"
-        if os.path.isdir(self.path):
-            self.path = os.path.join("./db", self.path, "db_file.duckdb")
-            os.makedirs(os.path.dirname(self.path), exist_ok=True)
-            logger.info(
-                f"Path {self.path} is a directory. Using {self.path} as the database path, "
-                f"as duckdb needs a file path"
-            )
+        self.connection_handler = DuckDBConnectionAndRecoveryHandler(self.path)
         self.ef_construction = self._validate_ef_construction(self.ef_construction)
         self.ef_search = self._validate_ef_search(self.ef_search)
         self.M = self._validate_m(self.M)
-        logger.info(f"Using DuckDB at {self.path}")
-        # handling concurrency
-        try:
-            self.conn = duckdb.connect(self.path, read_only=False)
-        except duckdb.IOException as e:
-            match = re.search(r"PID (\d+)", str(e))
-            if match:
-                pid = int(match.group(1))
-                logger.info(f"Got {e}. Attempting to kill process with PID: {pid}")
-                self.kill_process(pid)
-                self.conn = duckdb.connect(self.path, read_only=False)
-            else:
-                logger.error(f"{e} without PID information.")
-                raise
-        self.conn.execute("INSTALL vss;")
-        self.conn.execute("LOAD vss;")
-        self.conn.execute("SET hnsw_enable_experimental_persistence=true;")
-        if self.default_model is None:
-            self.model = self.default_model
+        self.conn = self.connection_handler.connect()
+        self.model = self.default_model
         self.vec_dimension = self._get_embedding_dimension(self.default_model)
 
     def _initialize_openai_client(self):
@@ -117,22 +104,13 @@ def _get_collection_name(self, collection: Optional[str] = None) -> str:
         return self._get_collection(collection)
 
     def _create_table_if_not_exists(
-        self, collection: str, vec_dimension: int, distance: str, model: str = None
+        self, collection: str, vec_dimension: int, venomx: Metadata = None
     ):
         """
         Create a table for the given collection if it does not exist
 
         :param collection:
         :return:
         """
-        logger.info(
-            f"Table {collection} does not exist, creating ...: PARAMS: model: {model}, distance: {distance}, "
-            f"vec_dimension: {vec_dimension}"
-        )
-        if model is None:
-            model = self.default_model
-        logger.info(f"Model in create_table_if_not_exists: {model}")
-        if distance is None:
-            distance = self.distance_metric
         safe_collection_name = f'"{collection}"'
         create_table_sql = f"""
             CREATE TABLE IF NOT EXISTS {safe_collection_name} (
@@ -144,16 +122,16 @@
         """
         self.conn.execute(create_table_sql)
 
-        metadata = CollectionMetadata(name=collection, model=model, hnsw_space=distance)
-        metadata_json = json.dumps(metadata.dict(exclude_none=True))
-        safe_collection_name = f'"{collection}"'
-        self.conn.execute(
-            f"""
-            INSERT INTO {safe_collection_name} (id, metadata) VALUES ('__metadata__', ?)
-            ON CONFLICT (id) DO NOTHING
-            """,
-            [metadata_json],
-        )
+        if venomx:
+            venomx = venomx.model_dump(exclude_none=True)
+            # venomx metadata insertion
+            self.conn.execute(
+                f"""
+                INSERT INTO {safe_collection_name} (id, metadata) VALUES ('__venomx__', ?)
+                ON CONFLICT (id) DO NOTHING
+                """,
+                [venomx],
+            )
 
     def create_index(self, collection: str):
         """
@@ -289,6 +267,7 @@ def _process_objects(
         model: str = None,
         distance: str = None,
         text_field: Union[str, Callable] = None,
+        venomx: Optional[Metadata] = None,
         method: str = "insert",
         **kwargs,
     ):
@@ -305,21 +284,36 @@
         :return:
         """
         collection = self._get_collection_name(collection)
-        logger.info(f"Processing objects for collection {collection}")
+        if model is None:
+            model = self.default_model
         self.vec_dimension = self._get_embedding_dimension(model)
-        logger.info(f"(process_objects: Model: {model}, vec_dimension: {self.vec_dimension}")
+
+        updated_venomx = self.update_or_create_venomx(
+            venomx,
+            collection,
+            model,
+            distance,
+            object_type,
+            self.vec_dimension,
+        )
+
         if collection not in self.list_collection_names():
             logger.info(f"(process) Creating table for collection {collection}")
             self._create_table_if_not_exists(
-                collection, self.vec_dimension, model=model, distance=distance
+                collection,
+                self.vec_dimension,
+                venomx=updated_venomx,
             )
+
+        # if the collection already exists, update its metadata here; the Index
+        # fields are unpacked so they merge into the stored venomx dict
+        cm = self.update_collection_metadata(
+            collection=collection, **updated_venomx.venomx.model_dump(exclude_none=True)
+        )
         if isinstance(objs, Iterable) and not isinstance(objs, str):
             objs = list(objs)
         else:
             objs = [objs]
         obj_count = len(objs)
         kwargs.update({"object_count": obj_count})
-        cm = self.collection_metadata(collection)
+        # no need to re-fetch the metadata here; it is built during table creation
+        # cm = self.collection_metadata(collection)
         if batch_size is None:
             batch_size = 100000
         if text_field is None:
@@ -333,15 +327,12 @@
             docs = [self._text(o, text_field) for o in next_objs]
             metadatas = [self._dict(o) for o in next_objs]
             ids = [self._id(o, id_field) for o in next_objs]
-            embeddings = self._embedding_function(docs, cm.model)
+            embeddings = self._embedding_function(docs, cm.venomx.embedding_model.name)
             try:
                 self.conn.execute("BEGIN TRANSACTION;")
                 self.conn.executemany(
                     sql_command, list(zip(ids, metadatas, embeddings, docs))  # noqa: B905
                 )
-                # reason to block B905: codequality check
-                # blocking 3.11 because only code quality issue and 3.9 gives value error with keyword strict
-                # TODO: delete after PR#76 is merged
                 self.conn.execute("COMMIT;")
             except Exception as e:
                 self.conn.execute("ROLLBACK;")
@@ -435,6 +426,68 @@
         finally:
             self.create_index(collection)
 
+    def update_or_create_venomx(
+        self,
+        venomx: Optional[Index],
+        collection: str,
+        model: str,
+        distance: str,
+        object_type: str,
+        embeddings_dimension: Optional[int],
+    ) -> Metadata:
+        """
+        Update an existing venomx Index with additional values, or create a new one
+        if none is provided.
+        """
+        # If venomx already exists, update its nested fields (e.g. the embeddings
+        # dimension is computed here and is not supplied by the caller)
+        if venomx:
+            new_embedding_model = Model(name=model)
+            updated_index = venomx.model_copy(update={  # the given venomx is an Index
+                "embedding_model": new_embedding_model,
+                "embeddings_dimensions": embeddings_dimension,
+            })
+
+            venomx = Metadata(
+                venomx=updated_index,
+                hnsw_space=distance,
+                object_type=object_type
+            )
+        else:
+            if distance is None:
+                distance = self.distance_metric
+            venomx = self.populate_venomx(collection, model, distance, object_type, embeddings_dimension)
+
+        return venomx
+
+    @staticmethod
+    def populate_venomx(
+        collection: Optional[str],
+        model: Optional[str],
+        distance: str,
+        object_type: str,
+        embeddings_dimension: int,
+    ) -> Metadata:
+        """
+        Populate venomx with the data available at insertion time.
+
+        :param collection:
+        :param model:
+        :param distance:
+        :param object_type:
+        :param embeddings_dimension:
+        :return:
+        """
+        venomx = Metadata(
+            venomx=Index(
+                id=collection,
+                embedding_model=Model(name=model),
+                embeddings_dimensions=embeddings_dimension,
+            ),
+            hnsw_space=distance,
+            object_type=object_type
+        )
+        return venomx
+
     def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
         """
         Remove the collection from the database
@@ -513,7 +566,7 @@ def _search(
         logger.info(f"Collection metadata={cm}")
         if model is None:
             if cm:
-                model = cm.model
+                model = cm.venomx.embedding_model.name
             if model is None:
                 model = self.default_model
         logger.info(f"Model={model}")
@@ -575,9 +628,9 @@ def _diversified_search(
         where_clause = " AND ".join(where_conditions)
         if where_clause:
             where_clause = f"WHERE {where_clause}"
-        query_embedding = self._embedding_function(text, model=cm.model)
+        query_embedding = self._embedding_function(text, model=cm.venomx.embedding_model.name)
         safe_collection_name = f'"{collection}"'
-        vec_dimension = self._get_embedding_dimension(cm.model)
+        vec_dimension = self._get_embedding_dimension(cm.venomx.embedding_model.name)
         results = self.conn.execute(
             f"""
             SELECT *, array_distance(embeddings::FLOAT[{vec_dimension}],
@@ -610,7 +663,7 @@ def list_collection_names(self):
 
     def collection_metadata(
         self, collection_name: Optional[str] = None, include_derived=False, **kwargs
-    ) -> Optional[CollectionMetadata]:
+    ) -> Optional[Metadata]:
         """
         Get the metadata for the collection
         :param collection_name:
@@ -622,16 +675,17 @@
         safe_collection_name = f'"{collection_name}"'
         try:
             result = self.conn.execute(
-                f"SELECT metadata FROM {safe_collection_name} WHERE id = '__metadata__'"
+                f"SELECT metadata FROM {safe_collection_name} WHERE id = '__venomx__'"
             ).fetchone()
             if result:
                 metadata = json.loads(result[0])
-                metadata_instance = CollectionMetadata(**metadata)
+                metadata = Metadata(**metadata)
                 if include_derived:
                     # not implemented yet
                     # metadata_instance.object_count = compute_object_count(collection_name
                     pass
-                return metadata_instance
+                return metadata
         except Exception as e:
             logger.error(f"Failed to retrieve metadata for collection {collection_name}: {str(e)}")
             return None
@@ -645,29 +699,48 @@
         :param kwargs:
         :return:
        """
         if not collection:
             raise ValueError("Collection name must be provided.")
-        current_metadata = self.collection_metadata(collection)
-        if current_metadata is None:
-            current_metadata = CollectionMetadata(**kwargs)
+        metadata = self.collection_metadata(collection)
+        current_venomx = {**kwargs}
+        if metadata is None:  # should not be possible
+            logger.warning(
+                f"No existing metadata found for collection {collection}. Initializing new metadata."
+            )
+            metadata = Metadata(venomx=Index(**current_venomx))
         else:
-            for key, value in kwargs.items():
-                if hasattr(current_metadata, key):
-                    setattr(current_metadata, key, value)
-        metadata_dict = current_metadata.dict(exclude_none=True)
-        metadata_json = json.dumps(metadata_dict)
+            metadata_dict = metadata.model_dump(exclude_none=True)
+            # Check whether the existing venomx has an embedding model and whether
+            # it matches the one in kwargs
+            if "venomx" in metadata_dict and metadata_dict["venomx"].get("embedding_model"):
+                existing_model_name = metadata_dict["venomx"]["embedding_model"].get("name")
+                new_model_name = current_venomx.get("embedding_model", {}).get("name")
+
+                if new_model_name and existing_model_name and new_model_name != existing_model_name:
+                    raise ValueError(
+                        f"Cannot change the embedding model name from '{existing_model_name}' to '{new_model_name}'. "
+                        f"Model dimensions are incompatible with changes to the model."
+                    )
+
+            # Merge current_venomx (from kwargs) into the nested venomx dictionary
+            if "venomx" in metadata_dict and isinstance(metadata_dict["venomx"], dict):
+                metadata_dict["venomx"].update(current_venomx)
+            else:
+                metadata_dict["venomx"] = current_venomx
+            # Reconstruct the Metadata object from the updated dictionary
+            metadata = Metadata(**metadata_dict)
+        updated_metadata_dict = metadata.model_dump(exclude_none=True)
+
         safe_collection_name = f'"{collection}"'
         self.conn.execute(
             f"""
             UPDATE {safe_collection_name}
             SET metadata = ?
-            WHERE id = '__metadata__'
+            WHERE id = '__venomx__'
             """,
-            [metadata_json],
+            [updated_metadata_dict],
         )
-        return current_metadata
+        return metadata
 
     def set_collection_metadata(
-        self, collection_name: Optional[str], metadata: CollectionMetadata, **kwargs
+        self, collection_name: Optional[str], metadata: Metadata, **kwargs
     ):
         """
         Set the metadata for the collection
@@ -679,15 +752,28 @@
         if collection_name is None:
             raise ValueError("Collection name must be provided.")
 
-        metadata_json = json.dumps(metadata.dict(exclude_none=True))
+        current_metadata = self.collection_metadata(collection_name)
+
+        if metadata:
+            if metadata.venomx.id != collection_name:
+                raise ValueError(
+                    f"venomx.id: {metadata.venomx.id} must match collection_name {collection_name}"
+                )
+
+            new_model = metadata.venomx.embedding_model.name
+            prev_model = current_metadata.venomx.embedding_model.name
+            if prev_model and new_model != prev_model:
+                raise ValueError(f"Cannot change model from {prev_model} to {new_model}")
+
+        metadata = metadata.model_dump(exclude_none=True)
         safe_collection_name = f'"{collection_name}"'
         self.conn.execute(
             f"""
             UPDATE {safe_collection_name}
             SET metadata = ?
-            WHERE id = '__metadata__'
+            WHERE id = '__venomx__'
             """,
-            [metadata_json],
+            [metadata],
         )
 
     def find(
@@ -991,7 +1077,7 @@ def parse_duckdb_result(results, include) -> Iterator[SEARCH_RESULT]:
     ----------
     """
     for res in results:
-        if res[0] != "__metadata__":
+        if res[0] not in ("__metadata__", "__venomx__"):
             D = DuckDBSearchResult(
                 ids=res[0],
                 metadatas=json.loads(res[1]),
diff --git a/src/curategpt/store/duckdb_connection_handler.py b/src/curategpt/store/duckdb_connection_handler.py
new file mode 100644
index 0000000..5e05c83
--- /dev/null
+++ b/src/curategpt/store/duckdb_connection_handler.py
@@ -0,0 +1,107 @@
+import logging
+import os
+import re
+from pathlib import Path
+from typing import Optional
+
+import duckdb
+
+logger = logging.getLogger(__name__)
+
+
+class DuckDBConnectionAndRecoveryHandler:
+    def __init__(self, path: str):
+        self.path = self._setup_path(path)
+        self.conn: Optional[duckdb.DuckDBPyConnection] = None
+
+    @staticmethod
+    def _setup_path(path: str) -> str:
+        """Handle path setup logic."""
+        if not path:
+            path = "./db/db_file.duckdb"
+        if os.path.isdir(path):
+            path = os.path.join("./db", path, "db_file.duckdb")
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            logger.info(
+                f"Path {path} is a directory. Using {path} as the database path, "
+                f"as duckdb needs a file path"
+            )
+        return path
+
+    @staticmethod
+    def _kill_process(pid: int) -> None:
+        """Kill a process if it's holding the database lock."""
+        try:
+            import psutil
+
+            if psutil.pid_exists(pid):
+                process = psutil.Process(pid)
+                process.terminate()
+                process.wait(timeout=5)
+                logger.info(f"Successfully terminated process {pid}")
+        except Exception as e:
+            logger.warning(f"Failed to kill process {pid}: {e}")
+
+    @staticmethod
+    def _load_vss_extensions(conn: duckdb.DuckDBPyConnection) -> None:
+        """Load VSS extensions for a connection."""
+        conn.execute("INSTALL vss;")
+        conn.execute("LOAD vss;")
+        conn.execute("SET hnsw_enable_experimental_persistence=true;")
+
+    def connect(self) -> duckdb.DuckDBPyConnection:
+        """
+        Establish a database connection with error handling and recovery.
+
+        Workflow as described in:
+        https://duckdb.org/docs/extensions/vss.html#persistence
+
+        In case of any WAL-related issue:
+        - Create a temporary workspace (in-memory database with VSS)
+        - Temporarily bring in the broken database (ATTACH)
+        - Fix it (WAL recovery happens)
+        - Save changes (CHECKPOINT)
+        - Put the fixed database back (DETACH)
+        - Clean up the temporary workspace (close)
+        - Now safely open the fixed database normally
+        """
+        wal_path = Path(self.path + ".wal")
+        if wal_path.exists():
+            logger.info("Found WAL file, attempting recovery...")
+            try:
+                temp_conn = duckdb.connect(":memory:")
+                self._load_vss_extensions(temp_conn)
+                temp_conn.execute(f"ATTACH '{self.path}' AS main_db")
+                temp_conn.execute("CHECKPOINT;")
+                temp_conn.execute("DETACH main_db")
+                temp_conn.close()
+            except Exception as e:
+                logger.warning(f"WAL recovery attempt failed: {e}")
+
+        try:
+            self.conn = duckdb.connect(self.path, read_only=False)
+        except duckdb.Error as e:
+            match = re.search(r"PID (\d+)", str(e))
+            if match:
+                pid = int(match.group(1))
+                logger.info(f"Got {e}. Attempting to kill process with PID: {pid}")
+                self._kill_process(pid)
+                self.conn = duckdb.connect(self.path, read_only=False)
+            else:
+                logger.error(f"Connection error without PID information: {e}")
+                raise
+
+        self._load_vss_extensions(self.conn)
+
+        return self.conn
+
+    def close(self) -> None:
+        """Safely close the database connection."""
+        if self.conn:
+            try:
+                self.conn.execute("CHECKPOINT;")
+                self.conn.close()
+                logger.info("Database connection closed successfully")
+            except Exception as e:
+                logger.error(f"Error closing database connection: {e}")
diff --git a/src/curategpt/store/in_memory_adapter.py b/src/curategpt/store/in_memory_adapter.py
index e92bee3..736610b 100644
--- a/src/curategpt/store/in_memory_adapter.py
+++ b/src/curategpt/store/in_memory_adapter.py
@@ -2,13 +2,14 @@
 
 import logging
 from dataclasses import dataclass, field
-from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, Union, get_origin
+from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Union, get_origin
 
 from pydantic import BaseModel, ConfigDict
+from venomx.model.venomx import Index, Model, ModelInputMethod
 
 from curategpt import DBAdapter
 from curategpt.store.db_adapter import OBJECT, PROJECTION, QUERY, SEARCH_RESULT
-from curategpt.store.metadata import CollectionMetadata
+from curategpt.store.metadata import Metadata
 
 logger = logging.getLogger(__name__)
 
@@ -23,6 +24,9 @@ class Collection(BaseModel):
     def add(self, object: Dict) -> None:
         self.objects.append(object)
 
+    def add_metadata(self, venomx: Metadata) -> None:
+        self.metadata.update(venomx.model_dump(exclude_none=True))
+
     def delete(self, key_value: str, key: str) -> None:
         self.objects = [obj for obj in self.objects if obj[key] != key_value]
 
@@ -61,22 +65,22 @@ def _get_collection_object(self, collection_name: str) -> Collection:
         collection_obj = self.collection_index.get_collection(self._get_collection(collection_name))
         return collection_obj
 
-    def insert(self, objs: Union[OBJECT, Iterable[OBJECT]], collection: str = None, **kwargs):
+    def update(self, objs: Union[OBJECT, List[OBJECT]], collection: str = None, **kwargs):
         """
-        Insert an object or list of objects into the store.
+        Update an object or list of objects in the store.
 
         :param objs:
         :param collection:
         :return:
         """
         collection_obj = self._get_collection_object(collection)
-        if get_origin(type(objs)) is not Dict:
+        if isinstance(objs, OBJECT):
             objs = [objs]
         collection_obj.add(objs)
 
-    def update(self, objs: Union[OBJECT, List[OBJECT]], collection: str = None, **kwargs):
+    def upsert(self, objs: Union[OBJECT, List[OBJECT]], collection: str = None, **kwargs):
         """
-        Update an object or list of objects in the store.
+        Upsert an object or list of objects in the store.
 
         :param objs:
         :param collection:
@@ -87,18 +91,65 @@
             objs = [objs]
         collection_obj.add(objs)
 
-    def upsert(self, objs: Union[OBJECT, List[OBJECT]], collection: str = None, **kwargs):
+    def insert(self, objs: Union[OBJECT, Iterable[OBJECT]], collection: str = None, **kwargs):
         """
-        Upsert an object or list of objects in the store.
+        Insert an object or list of objects into the store.
 
         :param objs:
         :param collection:
         :return:
         """
+        self._insert(objs, collection, **kwargs)
+
+    def _insert(
+        self,
+        objs: Union[OBJECT, Iterable[OBJECT]],
+        collection: str = None,
+        venomx: Metadata = None,
+        **kwargs,
+    ):
         collection_obj = self._get_collection_object(collection)
-        if isinstance(objs, OBJECT):
+        if venomx is None:
+            venomx = self.populate_venomx(
+                collection=collection,
+            )
+        if get_origin(type(objs)) is not Dict:
             objs = [objs]
+        collection_obj.add(objs)
+        collection_obj.add_metadata(venomx)
+
+    @staticmethod
+    def populate_venomx(
+        collection: Optional[str],
+        model: Optional[str] = None,
+        distance: str = None,
+        object_type: str = None,
+        embeddings_dimension: int = None,
+        index_fields: Optional[Union[List[str], Tuple[str]]] = None,
+    ) -> Metadata:
+        """
+        Populate venomx with the data available at insertion time.
+
+        :param collection:
+        :param model:
+        :param distance:
+        :param object_type:
+        :param embeddings_dimension:
+        :param index_fields:
+        :return:
+        """
+        venomx = Metadata(
+            venomx=Index(
+                id=collection,
+                embedding_model=Model(name=model),
+                embeddings_dimensions=embeddings_dimension,
+                embedding_input_method=ModelInputMethod(fields=index_fields) if index_fields else None
+            ),
+            hnsw_space=distance,
+            object_type=object_type
+        )
+        return venomx
 
     def delete(self, id: str, collection: str = None, **kwargs):
         """
@@ -134,7 +185,7 @@ def list_collection_names(self) -> List[str]:
 
     def collection_metadata(
         self, collection_name: Optional[str] = None, include_derived=False, **kwargs
-    ) -> Optional[CollectionMetadata]:
+    ) -> Optional[Metadata]:
         """
         Get the metadata for a collection.
 
@@ -144,13 +195,13 @@
         """
         collection_obj = self._get_collection_object(collection_name)
         md_dict = collection_obj.metadata
-        cm = CollectionMetadata(**md_dict)
+        cm = Metadata(**md_dict)
         if include_derived:
             cm.object_count = len(collection_obj.objects)
         return cm
 
     def set_collection_metadata(
-        self, collection_name: Optional[str], metadata: CollectionMetadata, **kwargs
+        self, collection_name: Optional[str], metadata: Metadata, **kwargs
     ):
         """
         Set the metadata for a collection.
 
         :param collection_name:
         :return:
         """
         collection_obj = self._get_collection_object(collection_name)
-        collection_obj.metadata = metadata.dict()
+        # TODO: allow this for now, as there is no embedding functionality here
+        # if metadata.venomx.id != collection_name:
+        #     raise ValueError(
+        #         f"venomx.id: {metadata.venomx.id} must match collection_name {collection_name} "
+        #         f"and should not be changed"
+        #     )
+        collection_obj.metadata = metadata.model_dump(exclude_none=True)
 
-    def update_collection_metadata(self, collection_name: str, **kwargs) -> CollectionMetadata:
+    def update_collection_metadata(self, collection_name: str, **kwargs) -> Metadata:
         """
         Update the metadata for a collection.
diff --git a/src/curategpt/store/metadata.py b/src/curategpt/store/metadata.py
index df3433b..9d1dbd6 100644
--- a/src/curategpt/store/metadata.py
+++ b/src/curategpt/store/metadata.py
@@ -1,38 +1,88 @@
+import json
 from typing import Dict, Optional
 
 from pydantic import BaseModel, ConfigDict
+from venomx.model.venomx import Index
 
+"""
+ChromaDB constraints:
+    Metadata must be scalar: ChromaDB only accepts metadata values that are scalar types (str, int, float, bool).
+    No None values: metadata fields cannot have None as a value.
+DuckDB capabilities:
+    Nested objects supported: DuckDB can handle nested objects directly within metadata.
+"""
 
-class CollectionMetadata(BaseModel):
-    """
-    Metadata about a collection.
-
-    This is an open class, so additional metadata can be added.
-    """
 
+class Metadata(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
-    name: Optional[str] = None
-    """Name of the collection"""
+    # Application-level field for 'duckdb', kept as a pydantic model in code for 'chromadb'
+    venomx: Optional[Index] = None
+    """
+    Retains the complex venomx Index object for internal application use.
+    Index is the main object of venomx:
+    https://github.com/cmungall/venomx
+    """
 
-    description: Optional[str] = None
-    """Description of the collection"""
+    # Serialized field to store venomx in adapters that require scalar metadata (e.g., 'chromadb')
+    _venomx: Optional[str] = None
+    """Stores the serialized JSON string of the venomx object for ChromaDB."""
 
-    model: Optional[str] = None
-    """Name of any ML model"""
+    hnsw_space: Optional[str] = None
+    """Space used for the hnsw index (e.g. 'cosine')"""
 
     object_type: Optional[str] = None
     """Type of object in the collection"""
 
-    source: Optional[str] = None
-    """Source of the collection"""
+    object_count: Optional[int] = None
+    """Number of objects in the collection"""
 
-    # DEPRECATED
-    annotations: Optional[Dict] = None
-    """Additional metadata"""
-
-    object_count: Optional[int] = None
-    """Number of objects in the collection"""
-
-    hnsw_space: Optional[str] = None
-    """Space used for hnsw index (e.g. 'cosine')"""
+    @classmethod
+    def deserialize_venomx_metadata_from_adapter(cls, metadata_dict: dict, adapter: str) -> "Metadata":
+        """
+        Create a Metadata instance from an adapter-specific metadata dictionary.
+        ChromaDB: _venomx is deserialized back into venomx (str to dict).
+        DuckDB: venomx is accessed directly as a nested object.
+        :param metadata_dict: Metadata dictionary from the adapter.
+        :param adapter: Adapter name (e.g., 'chromadb', 'duckdb').
+        :return: Metadata instance.
+        """
+        if adapter == 'chromadb':
+            # Deserialize '_venomx' (str) back into 'venomx' (dict)
+            if "_venomx" in metadata_dict:
+                venomx_json = metadata_dict.pop("_venomx")
+                metadata_dict["venomx"] = Index(**json.loads(venomx_json))
+        # for 'duckdb', 'venomx' remains as is
+        return cls(**metadata_dict)
+
+    def serialize_venomx_metadata_for_adapter(self, adapter: str) -> dict:
+        """
+        Convert the Metadata instance to a dictionary suitable for the specified adapter.
+        ChromaDB: venomx is serialized into _venomx before storing (dict to str).
+        DuckDB: venomx remains as an Index object without serialization.
+        :param adapter: Adapter name (e.g., 'chromadb', 'duckdb').
+        :return: Metadata dictionary.
+        """
+        if adapter == 'chromadb':
+            # Serialize 'venomx' (dict) into '_venomx' (str)
+            metadata_dict = self.model_dump(
+                exclude={"venomx"},
+                exclude_unset=True,
+                exclude_none=True
+            )
+            if self.venomx:
+                metadata_dict["_venomx"] = json.dumps(self.venomx.model_dump())
+            return metadata_dict
+        elif adapter == 'duckdb':
+            metadata_dict = self.model_dump(
+                exclude={"_venomx"},
+                exclude_unset=True,
+                exclude_none=True
+            )
+            return metadata_dict
+        else:
+            raise ValueError(f"Unsupported adapter: {adapter}")
diff --git a/tests/store/test_chromadb_adapter.py b/tests/store/test_chromadb_adapter.py
index a05d84a..8177300 100644
--- a/tests/store/test_chromadb_adapter.py
+++ b/tests/store/test_chromadb_adapter.py
@@ -44,48 +44,71 @@ def simple_schema_manager() -> SchemaProxy:
     )
     return SchemaProxy(sb.schema)
 
-
-def test_store(simple_schema_manager, example_texts):
+@pytest.mark.parametrize(
+    "model, requires_key, change_field, expected_error",
+    [
+        pytest.param("openai:", True, "model", True, marks=requires_openai_api_key),
+        ("all-MiniLM-L6-v2", False, "model", True),
+        (None, False, "model", True),
+        pytest.param("openai:", True, "id", True, marks=requires_openai_api_key),
+        ("all-MiniLM-L6-v2", False, "id", True),
+        (None, False, "id", True),
+    ],
+)
+def test_store(simple_schema_manager, example_texts, model, requires_key, change_field, expected_error):
     db = ChromaDBAdapter(str(OUTPUT_CHROMA_DB_PATH))
     db.schema_proxy = simple_schema_manager
     db.client.reset()
     assert db.list_collection_names() == []
     collection = "test"
     objs = terms_to_objects(example_texts)
-    db.insert(objs, collection=collection)
-    md = db.collection_metadata(collection)
-    md.description = "test collection"
-    db.set_collection_metadata(collection, md)
-    assert db.collection_metadata(collection).description == "test collection"
-    db2 = ChromaDBAdapter(str(OUTPUT_CHROMA_DB_PATH))
-    assert db2.collection_metadata(collection).description == "test collection"
-    assert db.list_collection_names() == ["test"]
-    results = list(db.search("fox", collection=collection))
-    print(results)
-    for obj in objs:
-        print(f"QUERYING: {obj}")
-        for match in db.matches(obj, collection=collection):
-            print(f" - MATCH: {match}")
-    db.update(objs, collection=collection)
-    assert db.collection_metadata(collection).description == "test collection"
-    canines = list(db.find(where={"text": {"$eq": "canine"}}, collection=collection))
-    print(f"CANINES: {canines}")
-    long_words = list(db.find(where={"wordlen": {"$gt": 12}}, collection=collection))
-    print(long_words)
-    assert len(long_words) == 2
-    db.remove_collection(collection)
-    db.insert(objs, collection=collection)
-    results2 = list(db.search("fox", collection=collection))
+    if model:
+        db.insert(objs, collection=collection, model=model)
+    else:
+        db.insert(objs, collection=collection)
 
-    def _id(obj, _dist, _meta):
-        return obj["id"]
+    md = db.collection_metadata(collection)
 
-    assert _id(*results[0]) == _id(*results2[0])
-    limit = 5
-    results2 = list(db.find({}, limit=5, collection=collection))
-    assert len(results2) == limit
-    results2 = list(db.find({}, limit=10000000, collection=collection))
-    assert len(results2) > limit
+    if change_field == "model":
+        # ensure changing the model is not allowed
+        if model == "openai:":
+            new_model = "all-MiniLM-L6-v2"
+        else:
+            new_model = "openai:"
+        md.venomx.embedding_model.name = new_model
+    elif change_field == "id":
+        # ensure changing venomx.id is not allowed if inconsistent with the collection name
+        md.venomx.id = "different_collection_name"
+
+    if expected_error:
+        with pytest.raises(ValueError):
+            db.set_collection_metadata(collection, md)
+    else:
+        # case: no error
+        db.set_collection_metadata(collection, md)
+        assert md.venomx.id == db.collection_metadata(collection).venomx.id
+        assert md.venomx.id == collection
+        assert db.collection_metadata(collection).venomx.id == collection
+
+    results = list(db.search("fox", collection=collection))
+    results2 = list(db.search("fox", collection=collection))
+
+    def _id(obj, _dist, _meta):
+        return obj["id"]
+
+    assert _id(*results[0]) == _id(*results2[0])
+
+    db.remove_collection(collection)
+    db.update(objs, collection=collection)
+    canines = list(db.find(where={"text": {"$eq": "canine"}}, collection=collection))
+    print(f"CANINES: {canines}")
+    long_words = list(db.find(where={"wordlen": {"$gt": 12}}, collection=collection))
+    print(long_words)
+    assert len(long_words) == 2
+
+    limit = 5
+    results2 = list(db.find({}, limit=limit, collection=collection))
+    assert len(results2) == limit, f"Expected {limit} results, but got {len(results2)}"
+    results2 = list(db.find({}, limit=10000, collection=collection))
+    assert len(results2) > limit, f"Expected more than {limit} results, but got {len(results2)}"
 
 
 def test_fetch_all_memory_safe(example_texts):
@@ -125,16 +148,14 @@ def test_embedding_function(simple_schema_manager, example_texts):
     db.insert(objs[1:])
     db.insert(objs[1:], collection="default_ef", model=None)
     db.insert(objs[1:], collection="openai", model="openai:")
-    assert db.collection_metadata("default_ef").name == "default_ef"
-    assert db.collection_metadata("openai").name == "openai"
-    assert db.collection_metadata(None).model == "all-MiniLM-L6-v2"
-    assert db.collection_metadata("default_ef").model == "all-MiniLM-L6-v2"
-    assert db.collection_metadata("openai").model == "openai:"
+    assert db.collection_metadata("default_ef").venomx.id == "default_ef"
+    assert db.collection_metadata("openai").venomx.id == "openai"
+    assert db.collection_metadata(None).venomx.embedding_model.name == "all-MiniLM-L6-v2"
+    assert db.collection_metadata("default_ef").venomx.embedding_model.name == "all-MiniLM-L6-v2"
+    assert db.collection_metadata("openai").venomx.embedding_model.name == "openai:"
     db.insert([objs[0]])
     db.insert([objs[0]], collection="default_ef")
     db.insert([objs[0]], collection="openai")
-    assert db.collection_metadata("default_ef").model == "all-MiniLM-L6-v2"
-    assert db.collection_metadata("openai").model == "openai:"
     results_ef = list(db.search("fox", collection="default_ef"))
     results_oai = list(db.search("fox", collection="openai"))
     assert len(results_ef) > 0
diff --git a/tests/store/test_duckdb_adapter.py b/tests/store/test_duckdb_adapter.py
index 3ba6a4b..0847091 100644
--- a/tests/store/test_duckdb_adapter.py
+++ b/tests/store/test_duckdb_adapter.py
@@ -55,14 +55,17 @@ def simple_schema_manager() -> SchemaProxy:
 
 
 @pytest.mark.parametrize(
-    "model, requires_key",
+    "model, requires_key, change_field, expected_error",
     [
-        pytest.param("openai:", True, marks=requires_openai_api_key),
-        ("all-MiniLM-L6-v2", False),
-        (None, False),
+        pytest.param("openai:", True, "model", True, marks=requires_openai_api_key),
+        ("all-MiniLM-L6-v2", False, "model", True),
+        (None, False, "model", True),
+        pytest.param("openai:", True, "id", True, marks=requires_openai_api_key),
+        ("all-MiniLM-L6-v2", False, "id", True),
+        (None, False, "id", True),
     ],
 )
-def test_store_variations(simple_schema_manager, example_texts, model, requires_key):
+def test_store_variations(simple_schema_manager, example_texts, model, requires_key, change_field, expected_error):
     db = DuckDBAdapter(OUTPUT_DUCKDB_PATH)
     for i in db.list_collection_names():
         db.remove_collection(i)
@@ -76,39 +79,44 @@
         db.insert(objs, collection=collection)
 
     md = db.collection_metadata(collection)
-    md.description = "test collection"
-    db.set_collection_metadata(collection, md)
-    assert db.collection_metadata(collection).description == "test collection"
-    if model:
-        assert db.collection_metadata(collection).model == model
+    if change_field == "model":
+        # ensure changing the model is not allowed
+        if model == "openai:":
+            new_model = "all-MiniLM-L6-v2"
+        else:
+            new_model = "openai:"
+        md.venomx.embedding_model.name = new_model
+    elif change_field == "id":
+        # ensure changing venomx.id is not allowed if inconsistent with the collection name
+        md.venomx.id = "different_collection_name"
+
+    if expected_error:
+        with pytest.raises(ValueError):
+            db.set_collection_metadata(collection, md)
     else:
-        assert db.collection_metadata(collection).model == "all-MiniLM-L6-v2"
+        db.set_collection_metadata(collection, md)
 
-    db2 = DuckDBAdapter(str(OUTPUT_DUCKDB_PATH))
-    assert db2.collection_metadata(collection).description == "test collection"
     if model:
-        assert db2.collection_metadata(collection).model == model
+        assert db.collection_metadata(collection).venomx.embedding_model.name == model
     else:
-        assert db2.collection_metadata(collection).model == "all-MiniLM-L6-v2"
-    assert db.list_collection_names() == ["test_collection"]
+        assert db.collection_metadata(collection).venomx.embedding_model.name == "all-MiniLM-L6-v2"
     results = list(db.search("fox", collection=collection, include=["metadatas"]))
     if model:
         db.update(objs, collection=collection, model=model)
     else:
         db.update(objs, collection=collection)
-    assert db.collection_metadata(collection).description == "test collection"
+    assert db.collection_metadata(collection).venomx.id == collection
     long_words = list(db.find(where={"wordlen": {"$gt": 12}}, collection=collection))
     assert len(long_words) == 2
+
     db.remove_collection(collection)
     if model:
         db.insert(objs, collection=collection, model=model)
     else:
         db.insert(objs, collection=collection)
     results2 = list(db.search("fox", collection=collection, include=["metadatas"]))
-    peek = list(db.fetch_all_objects_memory_safe(collection=collection, batch_size=2))
-    assert len(peek) == 7
 
     def _id(obj, dist, meta):
         return obj["id"]
@@ -164,8 +172,8 @@ def test_the_embedding_function_variations(
     db.insert(objs, collection=collection, model=model)
     expected_model = model if model else "all-MiniLM-L6-v2"
     expected_name = collection
-    assert db.collection_metadata(collection).model == expected_model
-    assert db.collection_metadata(collection).name == expected_name
+    assert db.collection_metadata(collection).venomx.embedding_model.name == expected_model
+    assert db.collection_metadata(collection).venomx.id == expected_name
     assert db.collection_metadata(collection).hnsw_space == "cosine"
 
 
@@ -282,7 +290,6 @@ def test_load_in_batches(ontology_db, batch_size):
     ontology_db.insert(view.objects(), batch_size=batch_size, collection="other_collection")
     # end = time.time()
    # print(f"Time to insert {len(list(view.objects()))} objects with batch of {batch_size}: {end - start}")
-
     objs = list(ontology_db.find(collection="other_collection", limit=2000))
     assert len(objs) > 100
diff --git a/tests/store/test_in_memory_adapter.py b/tests/store/test_in_memory_adapter.py
index 2185fcd..6f02fbf 100644
--- a/tests/store/test_in_memory_adapter.py
+++ b/tests/store/test_in_memory_adapter.py
@@ -48,6 +48,6 @@ def test_store(simple_schema_manager, example_texts):
     objs = terms_to_objects(example_texts)
     db.insert(objs, collection=collection)
     md = db.collection_metadata(collection)
-    md.description = "test collection"
+    md.venomx.id = "test collection"
     db.set_collection_metadata(collection, md)
-    assert db.collection_metadata(collection).description == "test collection"
+    assert db.collection_metadata(collection).venomx.id == "test collection"
diff --git a/tests/wrappers/test_ontology.py b/tests/wrappers/test_ontology.py
index fb45ffc..d518514 100644
--- a/tests/wrappers/test_ontology.py
+++ b/tests/wrappers/test_ontology.py
@@ -1,9 +1,11 @@
 import logging
+from pathlib import Path
 from pprint import pprint
 
 import pytest
 from oaklib import get_adapter
 from oaklib.datamodels.obograph import GraphDocument
+from venomx.model.venomx import Dataset, Index, Model, ModelInputMethod
 
 from curategpt.extract import BasicExtractor
 from curategpt.wrappers.ontology.ontology_wrapper import OntologyWrapper
@@ -21,6 +23,78 @@
 logger = logging.root
 logger.setLevel(logging.DEBUG)
 
+
+# for debugging while implementing
+def test_insert_without_venomx():
+    db = setup_db(Path("../db"))
+    collection_name = "test_collection_without_venomx_set_upfront"
+
+    extractor = BasicExtractor()
+    adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
+    wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=extractor)
+
+    db.insert(
+        wrapper.objects(),
+        collection=collection_name,
+    )
+    names = db.list_collection_names()
+    print(names)
+
+    col = db.client.get_collection(f"{collection_name}")
+    metadata = col.metadata
+    print(metadata)
+
+
+# for debugging while implementing
+def test_insert_with_venomx():
+    db = setup_db(Path("../db"))
+    collection_name = "test_collection_with_venomx_set_upfront"
+
+    extractor = BasicExtractor()
+    adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
+    wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=extractor)
+
+    venomx = Index(
+        id=f"{collection_name}",
+        dataset=Dataset(
+            name="test_ont_hp"
+        ),
+        embedding_model=Model(
+            name=db.default_model
+        ),
+        embedding_input_method=ModelInputMethod(
+            fields=['label']
+        )
+    )
+
+    db.insert(
+        wrapper.objects(),
+        collection=collection_name,
+        venomx=venomx
+    )
+    names = db.list_collection_names()
+    print(names)
+    col = db.client.get_collection(f"{collection_name}")
+    metadata = col.metadata
+    print(metadata)
+
 
 @pytest.fixture
 def vstore(request, tmp_path):
@@ -33,7 +107,7 @@ def vstore(request, tmp_path):
     try:
         wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=extractor)
         db.insert(wrapper.objects())
-        yield wrapper
+        yield wrapper, db
     except Exception as e:
         raise e
     finally:
@@ -44,11 +118,12 @@
 @pytest.mark.parametrize('vstore', [TEMP_OAK_OBJ], indirect=True)
 def test_oak_objects(vstore):
     """Test that the objects are extracted from the oak adapter."""
-    objs = list(vstore.objects())
+    wrapper, _ = vstore
+    objs = list(wrapper.objects())
     [nucleus] = [obj for obj in objs if obj["id"] == "Nucleus"]
     assert nucleus["label"] == "nucleus"
     assert nucleus["original_id"] == "GO:0005634"
-    reversed = vstore.unwrap_object(nucleus, store=vstore.local_store)
+    reversed = wrapper.unwrap_object(nucleus, store=wrapper.local_store)
     nucleus = reversed.graphs[0].nodes[0]
     assert nucleus["lbl"] == "nucleus"
     assert nucleus["id"] == "GO:0005634"
@@ -58,14 +133,15 @@
 @pytest.mark.parametrize('vstore', [TEMP_OAK_IND], indirect=True)
 def test_oak_index(vstore):
     """Test that the objects are indexed in the local store."""
-    g = vstore.unwrap_object(
+    wrapper, _ = vstore
+    g = wrapper.unwrap_object(
         {
             "id": "Nucleus",
             "label": "nucleus",
             "relationships": [{"predicate": "rdfs:subClassOf", "target": "Organelle"}],
             "original_id": "GO:0005634",
         },
-        store=vstore.local_store,
+        store=wrapper.local_store,
     )
     if isinstance(g, GraphDocument):
         pprint(g.__dict__, width=100, indent=2)
@@ -84,6 +160,7 @@
 @requires_openai_api_key
 def test_oak_search(vstore):
     """Test that the objects are indexed and searchable in the local store."""
-    results = list(vstore.search("nucl"))
+    _, db = vstore
+    results = list(db.search("nucl"))
     assert len(results) > 0
     assert any("nucleus" in result[0]["label"] for result in results)