Skip to content

Commit

Permalink
feat: Add evaluator (#31)
Browse files Browse the repository at this point in the history
* Added euclidean metric to basic backend

* Switched to mixins

* Updates

* Updates

* Aligned metrics

* Update

* WIP

* Update

* Update

* Added test

* Updates

* Fixed supported metric issue

* Update
  • Loading branch information
Pringled authored Dec 1, 2024
1 parent ba2bc67 commit 92d3bb8
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
20 changes: 20 additions & 0 deletions tests/test_vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,23 @@ def test_vicinity_delete_and_query(vicinity_instance: Vicinity, items: list[str]

# Check that the queried item is in the results
assert "item3" in returned_items


def test_vicinity_evaluate(vicinity_instance: Vicinity, vectors: np.ndarray) -> None:
    """
    Test the evaluate method of the Vicinity instance.

    :param vicinity_instance: A Vicinity instance.
    :param vectors: The full dataset vectors used to build the index.
    """
    queries = vectors[:10]
    qps, recall = vicinity_instance.evaluate(vectors, queries)

    # QPS must be strictly positive; recall is a fraction in [0, 1].
    assert qps > 0
    assert 0 <= recall <= 1

    # Forcing a metric that evaluate() does not support must raise.
    vicinity_instance.backend.arguments.metric = "manhattan"
    with pytest.raises(ValueError):
        vicinity_instance.evaluate(vectors, queries)
76 changes: 75 additions & 1 deletion vicinity/vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import time
from io import open
from pathlib import Path
from time import perf_counter
from typing import Any, Sequence, Union

import numpy as np
import orjson
from numpy import typing as npt

from vicinity.backends import AbstractBackend, get_backend_class
from vicinity import Metric
from vicinity.backends import AbstractBackend, BasicBackend, get_backend_class
from vicinity.datatypes import Backend, PathLike

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -83,6 +85,11 @@ def dim(self) -> int:
"""The dimensionality of the vectors."""
return self.backend.dim

@property
def metric(self) -> str:
    """Name of the distance metric configured on the underlying backend."""
    backend_arguments = self.backend.arguments
    return backend_arguments.metric

def query(
self,
vectors: npt.NDArray,
Expand Down Expand Up @@ -229,3 +236,70 @@ def delete(self, tokens: Sequence[str]) -> None:
# Delete items starting from the highest index
for index in sorted(curr_indices, reverse=True):
self.items.pop(index)

def evaluate(
    self,
    full_vectors: npt.NDArray,
    query_vectors: npt.NDArray,
    k: int = 10,
    epsilon: float = 1e-3,
) -> tuple[float, float]:
    """
    Evaluate the Vicinity instance on the given query vectors.

    Computes recall and measures QPS (Queries Per Second).
    For recall calculation, the same methodology is used as in the
    ann-benchmarks repository: an approximate neighbor counts as a hit
    when its distance is within epsilon of the k-th ground-truth distance.
    NOTE: this is only supported for Cosine and Euclidean metric backends.

    :param full_vectors: The full dataset vectors used to build the index.
    :param query_vectors: The query vectors to evaluate.
    :param k: The number of nearest neighbors to retrieve.
    :param epsilon: The epsilon threshold for recall calculation.
    :return: A tuple of (QPS, recall).
    :raises ValueError: If the metric is not supported by the BasicBackend.
    """
    # Validate and map the metric; any failure (unknown name or a metric
    # the BasicBackend cannot compute) surfaces as one uniform ValueError.
    unsupported_msg = (
        f"Unsupported metric '{self.metric}' for evaluation with BasicBackend. "
        f"Supported metrics are: {[m.value for m in BasicBackend.supported_metrics]}"
    )
    try:
        metric_enum = Metric.from_string(self.metric)
    except ValueError as e:
        raise ValueError(unsupported_msg) from e
    if metric_enum not in BasicBackend.supported_metrics:
        raise ValueError(unsupported_msg)
    basic_metric = metric_enum.value

    # Exact (brute-force) ground truth over the full dataset.
    gt_vicinity = Vicinity.from_vectors_and_items(
        vectors=full_vectors,
        items=self.items,
        backend_type=Backend.BASIC,
        metric=basic_metric,
    )
    gt_distances = [[dist for _, dist in neighbors] for neighbors in gt_vicinity.query(query_vectors, k=k)]

    # Time only the approximate query to compute QPS.
    start_time = perf_counter()
    run_results = self.query(query_vectors, k=k)
    elapsed_time = perf_counter() - start_time

    num_queries = len(query_vectors)
    qps = num_queries / elapsed_time if elapsed_time > 0 else float("inf")

    approx_distances = [[dist for _, dist in neighbors] for neighbors in run_results]

    # Per-query recall: fraction of the k approximate neighbors whose distance
    # is within epsilon of the k-th ground-truth distance.
    recalls = []
    for _gt_distances, _approx_distances in zip(gt_distances, approx_distances):
        if not _gt_distances:
            # No ground-truth neighbors for this query (empty index); skip.
            continue
        # Guard against the index holding fewer than k vectors, which would
        # otherwise make _gt_distances[k - 1] raise IndexError.
        t = _gt_distances[min(k, len(_gt_distances)) - 1] + epsilon
        recalls.append(sum(1 for dist in _approx_distances if dist <= t) / k)

    # np.mean([]) would emit a RuntimeWarning and return nan; report 0.0 instead.
    mean_recall = float(np.mean(recalls)) if recalls else 0.0
    return qps, mean_recall

0 comments on commit 92d3bb8

Please sign in to comment.