Skip to content

Commit

Permalink
Merge pull request monarch-initiative#56 from monarch-initiative/fix-…
Browse files Browse the repository at this point in the history
…vec-dimension

Getting correct vector dimensions for search
  • Loading branch information
cmungall authored Aug 15, 2024
2 parents 76372d7 + 026a4dc commit a364507
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/curate_gpt/store/duckdb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,11 +399,13 @@ def _search(
include = set(include)
collection = self._get_collection(collection)
cm = self.collection_metadata(collection)
logger.info(f"Collection metadata={cm}")
if model is None:
if cm:
model = cm.model
if model is None:
model = self.default_model
logger.info(f"Model={model}")
where_conditions = []
if where:
where_conditions.append(where)
Expand All @@ -418,6 +420,8 @@ def _search(
query_embedding = self._embedding_function(text, model)
safe_collection_name = f'"{collection}"'

vec_dimension = self._get_embedding_dimension(model)

# TODO: !VERY IMPORTANT! The distance metrics in Chroma and DuckDB have very different, unclear implementations
# https://duckdb.org/docs/sql/functions/array.html#array_distancearray1-array2
# https://docs.trychroma.com/guides
Expand All @@ -426,8 +430,8 @@ def _search(
# than ChromaDB's distance metric
results = self.conn.execute(
f"""
SELECT *, array_distance(embeddings::FLOAT[{self.vec_dimension}],
{query_embedding}::FLOAT[{self.vec_dimension}]) as distance
SELECT *, array_distance(embeddings::FLOAT[{vec_dimension}],
{query_embedding}::FLOAT[{vec_dimension}]) as distance
FROM {safe_collection_name}
{where_clause}
ORDER BY distance
Expand Down Expand Up @@ -462,10 +466,11 @@ def _diversified_search(
where_clause = f"WHERE {where_clause}"
query_embedding = self._embedding_function(text, model=cm.model)
safe_collection_name = f'"{collection}"'
vec_dimension = self._get_embedding_dimension(cm.model)
results = self.conn.execute(
f"""
SELECT *, array_distance(embeddings::FLOAT[{self.vec_dimension}],
{query_embedding}::FLOAT[{self.vec_dimension}]) as distance
SELECT *, array_distance(embeddings::FLOAT[{vec_dimension}],
{query_embedding}::FLOAT[{vec_dimension}]) as distance
FROM {safe_collection_name}
{where_clause}
ORDER BY distance
Expand Down

0 comments on commit a364507

Please sign in to comment.