Skip to content

Commit

Permalink
Used range_search for supported indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Nov 15, 2024
1 parent b5b1561 commit 2b9f234
Showing 1 changed file with 37 additions and 5 deletions.
42 changes: 37 additions & 5 deletions vicinity/backends/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,44 @@ def delete(self, indices: list[int]) -> None:
raise NotImplementedError("This FAISS index type does not support deletion.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
    """
    Query vectors within a distance threshold, using range_search if supported.

    For the "cosine" metric the query vectors are normalized first and the
    scores returned by the index are converted to distances with
    ``1 - score`` before being compared against ``threshold``.

    :param vectors: The query vectors, one per row.
    :param threshold: Exclusive upper bound on the distance of a match.
    :return: One array of matching index ids per query vector.
    """
    is_cosine = self.arguments.metric == "cosine"

    # Normalize query vectors if using cosine similarity
    if is_cosine:
        vectors = normalize(vectors)

    out: list[npt.NDArray] = []

    if isinstance(
        self.index,
        (faiss.IndexFlat, faiss.IndexIVFFlat, faiss.IndexScalarQuantizer, faiss.IndexIVFScalarQuantizer),
    ):
        # The `1 - score` conversion below implies a cosine index returns
        # similarities, and range_search on a similarity metric keeps hits
        # whose similarity is ABOVE the radius. The distance threshold must
        # therefore be translated into similarity space; passing it through
        # unchanged silently drops valid matches whenever threshold > 0.5.
        radius = 1 - threshold if is_cosine else threshold
        lims, scores, labels = self.index.range_search(vectors, radius)

        # `lims` delimits each query's slice of the flat result arrays.
        for i in range(vectors.shape[0]):
            start, end = lims[i], lims[i + 1]
            distances = scores[start:end]

            # Convert similarities to distances for cosine
            if is_cosine:
                distances = 1 - distances

            # Re-apply the threshold as a strict safeguard on the boundary
            out.append(labels[start:end][distances < threshold])
    else:
        # Fallback to search-based filtering for indexes that do not support
        # range_search. `k` is an arbitrary cap: at most 100 matches per
        # query can be returned on this path.
        scores, labels = self.index.search(vectors, 100)

        for dist, idx in zip(scores, labels):
            # Convert similarities to distances for cosine
            if is_cosine:
                dist = 1 - dist
            # FAISS pads rows with label -1 when fewer than `k` neighbours
            # exist; drop those explicitly instead of relying on the
            # sentinel distance failing the comparison.
            out.append(idx[(idx >= 0) & (dist < threshold)])

    return out

def save(self, base_path: Path) -> None:
Expand Down

0 comments on commit 2b9f234

Please sign in to comment.