From 026a4dcd5d781351495da2aa47b9d1fca80a810d Mon Sep 17 00:00:00 2001 From: Chris Mungall Date: Thu, 15 Aug 2024 14:57:18 -0700 Subject: [PATCH] Getting correct vector dimensions for search --- src/curate_gpt/store/duckdb_adapter.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/curate_gpt/store/duckdb_adapter.py b/src/curate_gpt/store/duckdb_adapter.py index 729f728..ef98151 100644 --- a/src/curate_gpt/store/duckdb_adapter.py +++ b/src/curate_gpt/store/duckdb_adapter.py @@ -399,11 +399,13 @@ def _search( include = set(include) collection = self._get_collection(collection) cm = self.collection_metadata(collection) + logger.info(f"Collection metadata={cm}") if model is None: if cm: model = cm.model if model is None: model = self.default_model + logger.info(f"Model={model}") where_conditions = [] if where: where_conditions.append(where) @@ -418,6 +420,8 @@ def _search( query_embedding = self._embedding_function(text, model) safe_collection_name = f'"{collection}"' + vec_dimension = self._get_embedding_dimension(model) + # TODO: !VERY IMPORTANT! distance metrics between Chroma and DuckDB have very different, unclear implementations # https://duckdb.org/docs/sql/functions/array.html#array_distancearray1-array2 # https://docs.trychroma.com/guides @@ -426,8 +430,8 @@ def _search( # than chromaDBs distance metric results = self.conn.execute( f""" - SELECT *, array_distance(embeddings::FLOAT[{self.vec_dimension}], - {query_embedding}::FLOAT[{self.vec_dimension}]) as distance + SELECT *, array_distance(embeddings::FLOAT[{vec_dimension}], + {query_embedding}::FLOAT[{vec_dimension}]) as distance FROM {safe_collection_name} {where_clause} ORDER BY distance @@ -462,10 +466,11 @@ def _diversified_search( where_clause = f"WHERE {where_clause}" query_embedding = self._embedding_function(text, model=cm.model) safe_collection_name = f'"{collection}"' + vec_dimension = self._get_embedding_dimension(cm.model) results = self.conn.execute( f""" - SELECT *, array_distance(embeddings::FLOAT[{self.vec_dimension}], - {query_embedding}::FLOAT[{self.vec_dimension}]) as distance + SELECT *, array_distance(embeddings::FLOAT[{vec_dimension}], + {query_embedding}::FLOAT[{vec_dimension}]) as distance FROM {safe_collection_name} {where_clause} ORDER BY distance