From 92d3bb83d23e583e35ee137cdf00cc53de5aedc4 Mon Sep 17 00:00:00 2001
From: Thomas van Dongen
Date: Sun, 1 Dec 2024 12:45:13 +0100
Subject: [PATCH] feat: Add evaluator (#31)

* Added euclidean metric to basic backend

* Switched to mixins

* Updates

* Updates

* Aligned metrics

* Update

* WIP

* Update

* Update

* Added test

* Updates

* Fixed supported metric issue

* Update
---
 tests/test_vicinity.py | 20 ++++++++++
 vicinity/vicinity.py   | 83 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 102 insertions(+), 1 deletion(-)

diff --git a/tests/test_vicinity.py b/tests/test_vicinity.py
index ea07e5e..79b7c4d 100644
--- a/tests/test_vicinity.py
+++ b/tests/test_vicinity.py
@@ -220,3 +220,23 @@ def test_vicinity_delete_and_query(vicinity_instance: Vicinity, items: list[str]
 
     # Check that the queried item is in the results
     assert "item3" in returned_items
+
+
+def test_vicinity_evaluate(vicinity_instance: Vicinity, vectors: np.ndarray) -> None:
+    """
+    Test the evaluate method of the Vicinity instance.
+
+    :param vicinity_instance: A Vicinity instance.
+    :param vectors: The full dataset vectors used to build the index.
+    """
+    query_vectors = vectors[:10]
+    qps, recall = vicinity_instance.evaluate(vectors, query_vectors)
+
+    # Ensure the QPS and recall values are within valid ranges
+    assert qps > 0
+    assert 0 <= recall <= 1
+
+    # Test with an unsupported metric
+    vicinity_instance.backend.arguments.metric = "manhattan"
+    with pytest.raises(ValueError):
+        vicinity_instance.evaluate(vectors, query_vectors)
diff --git a/vicinity/vicinity.py b/vicinity/vicinity.py
index d469325..e6273ef 100644
--- a/vicinity/vicinity.py
+++ b/vicinity/vicinity.py
@@ -6,13 +6,15 @@
 import time
 from io import open
 from pathlib import Path
+from time import perf_counter
 from typing import Any, Sequence, Union
 
 import numpy as np
 import orjson
 from numpy import typing as npt
 
-from vicinity.backends import AbstractBackend, get_backend_class
+from vicinity import Metric
+from vicinity.backends import AbstractBackend, BasicBackend, get_backend_class
 from vicinity.datatypes import Backend, PathLike
 
 logger = logging.getLogger(__name__)
@@ -83,6 +85,11 @@ def dim(self) -> int:
         """The dimensionality of the vectors."""
         return self.backend.dim
 
+    @property
+    def metric(self) -> str:
+        """The metric used by the backend."""
+        return self.backend.arguments.metric
+
     def query(
         self,
         vectors: npt.NDArray,
@@ -229,3 +236,77 @@ def delete(self, tokens: Sequence[str]) -> None:
         # Delete items starting from the highest index
         for index in sorted(curr_indices, reverse=True):
             self.items.pop(index)
+
+    def evaluate(
+        self,
+        full_vectors: npt.NDArray,
+        query_vectors: npt.NDArray,
+        k: int = 10,
+        epsilon: float = 1e-3,
+    ) -> tuple[float, float]:
+        """
+        Evaluate the Vicinity instance on the given query vectors.
+
+        Computes recall and measures QPS (Queries Per Second).
+        For recall calculation, the same methodology is used as in the ann-benchmarks repository.
+
+        NOTE: this is only supported for Cosine and Euclidean metric backends.
+
+        :param full_vectors: The full dataset vectors used to build the index.
+        :param query_vectors: The query vectors to evaluate.
+        :param k: The number of nearest neighbors to retrieve. Must be at least 1.
+        :param epsilon: The epsilon threshold for recall calculation.
+        :return: A tuple of (QPS, recall). Recall is 0.0 when no query could be scored.
+        :raises ValueError: If k is smaller than 1, or if the metric is not supported by the BasicBackend.
+        """
+        if k < 1:
+            raise ValueError(f"k must be at least 1, got {k}.")
+
+        try:
+            # Validate and map the metric using Metric.from_string
+            metric_enum = Metric.from_string(self.metric)
+            if metric_enum not in BasicBackend.supported_metrics:
+                raise ValueError(f"Unsupported metric '{metric_enum.value}' for BasicBackend.")
+            basic_metric = metric_enum.value
+        except ValueError as e:
+            raise ValueError(
+                f"Unsupported metric '{self.metric}' for evaluation with BasicBackend. "
+                f"Supported metrics are: {[m.value for m in BasicBackend.supported_metrics]}"
+            ) from e
+
+        # Create ground truth Vicinity instance
+        gt_vicinity = Vicinity.from_vectors_and_items(
+            vectors=full_vectors,
+            items=self.items,
+            backend_type=Backend.BASIC,
+            metric=basic_metric,
+        )
+
+        # Compute ground truth results
+        gt_distances = [[dist for _, dist in neighbors] for neighbors in gt_vicinity.query(query_vectors, k=k)]
+
+        # Start timer for approximate query
+        start_time = perf_counter()
+        run_results = self.query(query_vectors, k=k)
+        elapsed_time = perf_counter() - start_time
+
+        # Compute QPS
+        num_queries = len(query_vectors)
+        qps = num_queries / elapsed_time if elapsed_time > 0 else float("inf")
+
+        # Extract approximate distances
+        approx_distances = [[dist for _, dist in neighbors] for neighbors in run_results]
+
+        # Compute recall using the ground truth and approximate distances
+        recalls = []
+        for _gt_distances, _approx_distances in zip(gt_distances, approx_distances):
+            # The index may hold fewer than k items; score against the neighbors that exist
+            num_expected = min(k, len(_gt_distances))
+            if num_expected == 0:
+                continue
+            t = _gt_distances[num_expected - 1] + epsilon
+            recall = sum(1 for dist in _approx_distances if dist <= t) / num_expected
+            recalls.append(recall)
+
+        mean_recall = float(np.mean(recalls)) if recalls else 0.0
+        return qps, mean_recall