Skip to content

Commit

Permalink
feat: Added scores to query threshold (#54)
Browse files Browse the repository at this point in the history
* Changed query threshold logic

* Updated method

* Used mask
  • Loading branch information
Pringled authored Jan 7, 2025
1 parent 2533a86 commit 2c35eac
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 45 deletions.
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions vicinity/backends/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def delete(self, indices: list[int]) -> None:
"""Delete vectors from the backend."""
raise NotImplementedError("Deletion is not supported in Annoy backend.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
"""Threshold the backend."""
out: list[npt.NDArray] = []
for x, y in self.query(vectors, 100):
out.append(x[y < threshold])
out: QueryResult = []
for x, y in self.query(vectors, max_k):
mask = y < threshold
out.append((x[mask], y[mask]))
return out
2 changes: 1 addition & 1 deletion vicinity/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def delete(self, indices: list[int]) -> None:
raise NotImplementedError()

@abstractmethod
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
"""Threshold the backend."""
raise NotImplementedError()

Expand Down
15 changes: 9 additions & 6 deletions vicinity/backends/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,22 +150,25 @@ def threshold(
self,
vectors: npt.NDArray,
threshold: float,
) -> list[npt.NDArray]:
max_k: int,
) -> QueryResult:
"""
Batched distance thresholding.
:param vectors: The vectors to threshold.
:param threshold: The threshold to use.
:return: A list of lists of indices of vectors that are below the threshold
:param max_k: The maximum number of neighbors to consider.
:return: A list of tuples with the indices and distances.
"""
out: list[npt.NDArray] = []
out: QueryResult = []
for i in range(0, len(vectors), 1024):
batch = vectors[i : i + 1024]
distances = self._dist(batch)
for dists in distances:
indices = np.flatnonzero(dists <= threshold)
sorted_indices = indices[np.argsort(dists[indices])]
out.append(sorted_indices)
mask = dists <= threshold
indices = np.flatnonzero(mask)
filtered_distances = dists[mask]
out.append((indices, filtered_distances))
return out

def query(
Expand Down
12 changes: 7 additions & 5 deletions vicinity/backends/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ def delete(self, indices: list[int]) -> None:
"""Delete vectors from the backend."""
raise NotImplementedError("Deletion is not supported in FAISS backends.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
"""Query vectors within a distance threshold, using range_search if supported."""
out: list[npt.NDArray] = []
out: QueryResult = []
if self.arguments.metric == "cosine":
vectors = normalize(vectors)

Expand All @@ -179,13 +179,15 @@ def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]
dist = D[start:end]
if self.arguments.metric == "cosine":
dist = 1 - dist
out.append(idx[dist < threshold])
mask = dist < threshold
out.append((idx[mask], dist[mask]))
else:
distances, indices = self.index.search(vectors, 100)
distances, indices = self.index.search(vectors, max_k)
for dist, idx in zip(distances, indices):
if self.arguments.metric == "cosine":
dist = 1 - dist
out.append(idx[dist < threshold])
mask = dist < threshold
out.append((idx[mask], dist[mask]))

return out

Expand Down
9 changes: 5 additions & 4 deletions vicinity/backends/hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,11 @@ def delete(self, indices: list[int]) -> None:
"""Delete vectors from the backend."""
raise NotImplementedError("Deletion is not supported in HNSW backend.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
"""Threshold the backend."""
out: list[npt.NDArray] = []
for x, y in self.query(vectors, 100):
out.append(x[y < threshold])
out: QueryResult = []
for x, y in self.query(vectors, max_k):
mask = y < threshold
out.append((x[mask], y[mask]))

return out
12 changes: 6 additions & 6 deletions vicinity/backends/pynndescent.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def delete(self, indices: list[int]) -> None:
"""Delete vectors from the backend."""
raise NotImplementedError("Deletion is not supported in PyNNDescent backend.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
"""Find neighbors within a distance threshold."""
normalized_vectors = normalize_or_copy(vectors)
indices, distances = self.index.query(normalized_vectors, k=100)
result = []
indices, distances = self.index.query(normalized_vectors, k=max_k)
out: QueryResult = []
for idx, dist in zip(indices, distances):
within_threshold = idx[dist < threshold]
result.append(within_threshold)
return result
mask = dist < threshold
out.append((idx[mask], dist[mask]))
return out

def save(self, base_path: Path) -> None:
"""Save the vectors and configuration to a specified path."""
Expand Down
15 changes: 9 additions & 6 deletions vicinity/backends/usearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@ def delete(self, indices: list[int]) -> None:
"""Delete vectors from the index (not supported by Usearch)."""
raise NotImplementedError("Dynamic deletion is not supported in Usearch.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
"""Threshold the backend and return filtered keys."""
return [
np.array(keys_row)[np.array(distances_row, dtype=np.float32) < threshold]
for keys_row, distances_row in self.query(vectors, 100)
]
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
"""Query vectors within a distance threshold and return keys and distances."""
out: QueryResult = []
for keys_row, distances_row in self.query(vectors, max_k):
keys_row = np.array(keys_row)
distances_row = np.array(distances_row, dtype=np.float32)
mask = distances_row < threshold
out.append((keys_row[mask], distances_row[mask]))
return out
9 changes: 5 additions & 4 deletions vicinity/backends/voyager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ def delete(self, indices: list[int]) -> None:
"""Delete vectors from the backend."""
raise NotImplementedError("Deletion is not supported in Voyager backend.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> QueryResult:
"""Threshold the backend."""
out: list[npt.NDArray] = []
for x, y in self.query(vectors, len(self)):
out.append(x[y < threshold])
out: list[tuple[npt.NDArray, npt.NDArray]] = []
for x, y in self.query(vectors, max_k):
mask = y < threshold
out.append((x[mask], y[mask]))

return out

Expand Down
20 changes: 12 additions & 8 deletions vicinity/vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from vicinity import Metric
from vicinity.backends import AbstractBackend, BasicBackend, BasicVectorStore, get_backend_class
from vicinity.datatypes import Backend, PathLike
from vicinity.datatypes import Backend, PathLike, QueryResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,7 +114,7 @@ def query(
self,
vectors: npt.NDArray,
k: int = 10,
) -> list[list[tuple[str, float]]]:
) -> list[QueryResult]:
"""
Find the nearest neighbors to some arbitrary vector.
Expand All @@ -140,22 +140,26 @@ def query_threshold(
self,
vectors: npt.NDArray,
threshold: float = 0.5,
) -> list[list[str]]:
max_k: int = 100,
) -> list[QueryResult]:
"""
Find the nearest neighbors to some arbitrary vector with some threshold.
Find the nearest neighbors to some arbitrary vector with some threshold. Note: the output is not sorted.
:param vectors: The vectors to find the most similar vectors to.
:param threshold: The threshold to use.
:param max_k: The maximum number of neighbors to consider for the threshold query.
:return: For each item in the input, all items above the threshold are returned.
:return: For each item in the input, the items whose distance is below the threshold are returned in the form of
(NAME, DISTANCE) tuples.
"""
vectors = np.array(vectors)
vectors = np.asarray(vectors)
if np.ndim(vectors) == 1:
vectors = vectors[None, :]

out = []
for indexes in self.backend.threshold(vectors, threshold):
out.append([self.items[idx] for idx in indexes])
for indices, distances in self.backend.threshold(vectors, threshold, max_k=max_k):
distances.clip(min=0, out=distances)
out.append([(self.items[idx], dist) for idx, dist in zip(indices, distances)])

return out

Expand Down

0 comments on commit 2c35eac

Please sign in to comment.