diff --git a/ann_benchmarks/runner.py b/ann_benchmarks/runner.py index 4916e13e..81428114 100644 --- a/ann_benchmarks/runner.py +++ b/ann_benchmarks/runner.py @@ -66,6 +66,10 @@ def single_query(v: numpy.array) -> Tuple[float, List[Tuple[int, float]]]: start = time.time() candidates = algo.query(v, count) total = time.time() - start + + # make sure all returned indices are unique + assert len(candidates) == len(set(candidates)), "Implementation returned duplicated candidates" + candidates = [ (int(idx), float(metrics[distance].distance(v, X_train[idx]))) for idx in candidates # noqa ] @@ -105,6 +109,11 @@ def batch_query(X: numpy.array) -> List[Tuple[float, List[Tuple[int, float]]]]: batch_latencies = algo.get_batch_latencies() else: batch_latencies = [total / float(len(X))] * len(X) + + # make sure all returned indices are unique + for res in results: + assert len(res) == len(set(res)), "Implementation returned duplicated candidates" + candidates = [ [(int(idx), float(metrics[distance].distance(v, X_train[idx]))) for idx in single_results] # noqa for v, single_results in zip(X, results)