From ba2bc67a06c03fec891257518da93a96eee2119c Mon Sep 17 00:00:00 2001 From: Thomas van Dongen Date: Sun, 1 Dec 2024 12:08:57 +0100 Subject: [PATCH] feat: Align metrics (#30) * Added euclidean metric to basic backend * Switched to mixins * Updates * Updates * Aligned metrics * Update * Update * Resolved comments --- vicinity/__init__.py | 4 +- vicinity/backends/annoy.py | 52 +++++++++++--------- vicinity/backends/base.py | 7 +++ vicinity/backends/basic.py | 4 +- vicinity/backends/faiss.py | 84 +++++++++++++------------------- vicinity/backends/hnsw.py | 25 +++++++--- vicinity/backends/pynndescent.py | 33 ++++++++----- vicinity/backends/usearch.py | 48 +++++++++--------- vicinity/utils.py | 39 +++++++++++++++ 9 files changed, 177 insertions(+), 119 deletions(-) diff --git a/vicinity/__init__.py b/vicinity/__init__.py index 51507fd..e819e1b 100644 --- a/vicinity/__init__.py +++ b/vicinity/__init__.py @@ -1,8 +1,8 @@ """Small vector store.""" from vicinity.datatypes import Backend -from vicinity.utils import normalize +from vicinity.utils import Metric, normalize from vicinity.version import __version__ from vicinity.vicinity import Vicinity -__all__ = ["Backend", "Vicinity", "normalize", "__version__"] +__all__ = ["Backend", "Metric", "Vicinity", "normalize", "__version__"] diff --git a/vicinity/backends/annoy.py b/vicinity/backends/annoy.py index 6a86f0b..3961cc0 100644 --- a/vicinity/backends/annoy.py +++ b/vicinity/backends/annoy.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Union import numpy as np from annoy import AnnoyIndex @@ -10,19 +10,25 @@ from vicinity.backends.base import AbstractBackend, BaseArgs from vicinity.datatypes import Backend, QueryResult -from vicinity.utils import normalize +from vicinity.utils import Metric, normalize @dataclass class AnnoyArgs(BaseArgs): dim: int = 0 - metric: Literal["dot", "euclidean", "cosine"] = "cosine" + metric: str = "cosine" trees: int = 100 length: int | None = None class AnnoyBackend(AbstractBackend[AnnoyArgs]): argument_class = AnnoyArgs + supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN, Metric.INNER_PRODUCT} + inverse_metric_mapping = { + Metric.COSINE: "dot", + Metric.EUCLIDEAN: "euclidean", + Metric.INNER_PRODUCT: "dot", + } def __init__( self, @@ -40,25 +46,28 @@ def __init__( def from_vectors( cls: type[AnnoyBackend], vectors: npt.NDArray, - metric: Literal["dot", "euclidean", "cosine"], + metric: Union[str, Metric], trees: int, **kwargs: Any, ) -> AnnoyBackend: """Create a new instance from vectors.""" - dim = vectors.shape[1] - actual_metric: Literal["dot", "euclidean"] - if metric == "cosine": - actual_metric = "dot" + metric_enum = Metric.from_string(metric) + + if metric_enum not in cls.supported_metrics: + raise ValueError(f"Metric '{metric_enum.value}' is not supported by AnnoyBackend.") + + metric = cls._map_metric_to_string(metric_enum) + + if metric == "dot": vectors = normalize(vectors) - else: - actual_metric = metric - index = AnnoyIndex(f=dim, metric=actual_metric) + dim = vectors.shape[1] + index = AnnoyIndex(f=dim, metric=metric) # type: ignore for i, vector in enumerate(vectors): index.add_item(i, vector) index.build(trees) - arguments = AnnoyArgs(dim=dim, trees=trees, metric=metric, length=len(vectors)) + arguments = AnnoyArgs(dim=dim, metric=metric, trees=trees, length=len(vectors)) # type: ignore return AnnoyBackend(index, arguments=arguments) @property @@ -80,11 +89,7 @@ def load(cls: type[AnnoyBackend], base_path: Path) -> AnnoyBackend: """Load the vectors from a path.""" path = Path(base_path) / "index.bin" arguments = AnnoyArgs.load(base_path / "arguments.json") - - metric = arguments.metric - actual_metric = "dot" if metric == "cosine" else metric - - index = AnnoyIndex(arguments.dim, actual_metric) + index = AnnoyIndex(arguments.dim, arguments.metric) # type: ignore index.load(str(path)) return cls(index, arguments=arguments) @@ -93,7 +98,7 @@ def save(self, base_path: Path) -> None: """Save the vectors to a path.""" path = Path(base_path) / "index.bin" self.index.save(str(path)) - # NOTE: set the length before saving. + # Ensure the length is set before saving self.arguments.length = len(self) self.arguments.dump(base_path / "arguments.json") @@ -101,28 +106,27 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult: """Query the backend.""" out = [] for vec in vectors: - if self.arguments.metric == "cosine": + if self.arguments.metric == "dot": vec = normalize(vec) indices, scores = self.index.get_nns_by_vector(vec, k, include_distances=True) scores_array = np.asarray(scores) - if self.arguments.metric == "cosine": - # Turn cosine similarity into cosine distance. + if self.arguments.metric == "dot": + # Convert cosine similarity to cosine distance scores_array = 1 - scores_array out.append((np.asarray(indices), scores_array)) return out def insert(self, vectors: npt.NDArray) -> None: """Insert vectors into the backend.""" - raise NotImplementedError("Insertion is not supported in ANNOY backend.") + raise NotImplementedError("Insertion is not supported in Annoy backend.") def delete(self, indices: list[int]) -> None: """Delete vectors from the backend.""" - raise NotImplementedError("Deletion is not supported in ANNOY backend.") + raise NotImplementedError("Deletion is not supported in Annoy backend.") def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]: """Threshold the backend.""" out: list[npt.NDArray] = [] for x, y in self.query(vectors, 100): out.append(x[y < threshold]) - return out diff --git a/vicinity/backends/base.py b/vicinity/backends/base.py index 86857b3..4876047 100644 --- a/vicinity/backends/base.py +++ b/vicinity/backends/base.py @@ -8,6 +8,7 @@ from numpy import typing as npt +from vicinity import Metric from vicinity.datatypes import Backend, QueryResult @@ -34,6 +35,7 @@ def dict(self) -> dict[str, Any]: class AbstractBackend(ABC, Generic[ArgType]): argument_class: type[ArgType] + inverse_metric_mapping: dict[Metric, str] = {} def __init__(self, arguments: ArgType, *args: Any, **kwargs: Any) -> None: """Initialize the backend with vectors.""" @@ -93,5 +95,10 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult: """Query the backend.""" raise NotImplementedError() + @classmethod + def _map_metric_to_string(cls, metric: Metric) -> str: + """Map a Metric enum to a backend-specific metric string.""" + return cls.inverse_metric_mapping.get(metric, metric.value) + BaseType = TypeVar("BaseType", bound=AbstractBackend) diff --git a/vicinity/backends/basic.py b/vicinity/backends/basic.py index b58825f..23b33a8 100644 --- a/vicinity/backends/basic.py +++ b/vicinity/backends/basic.py @@ -10,7 +10,7 @@ from vicinity.backends.base import AbstractBackend, BaseArgs from vicinity.datatypes import Backend, Matrix, QueryResult -from vicinity.utils import normalize, normalize_or_copy +from vicinity.utils import Metric, normalize, normalize_or_copy @dataclass @@ -21,6 +21,7 @@ class BasicArgs(BaseArgs): class BasicBackend(AbstractBackend[BasicArgs], ABC): argument_class = BasicArgs _vectors: npt.NDArray + supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN} def __init__(self, arguments: BasicArgs) -> None: """Initialize the backend.""" @@ -116,7 +117,6 @@ def threshold( indices = np.flatnonzero(dists <= threshold) sorted_indices = indices[np.argsort(dists[indices])] out.append(sorted_indices) - return out def query( diff --git a/vicinity/backends/faiss.py b/vicinity/backends/faiss.py index d499f17..d726941 100644 --- a/vicinity/backends/faiss.py +++ b/vicinity/backends/faiss.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Union import faiss import numpy as np @@ -11,12 +11,17 @@ from vicinity.backends.base import AbstractBackend, BaseArgs from vicinity.datatypes import Backend, QueryResult -from vicinity.utils import normalize +from vicinity.utils import Metric, normalize logger = logging.getLogger(__name__) # FAISS indexes that support range_search -RANGE_SEARCH_INDEXES = (faiss.IndexFlat, faiss.IndexIVFFlat, faiss.IndexScalarQuantizer, faiss.IndexIVFScalarQuantizer) +RANGE_SEARCH_INDEXES = ( + faiss.IndexFlat, + faiss.IndexIVFFlat, + faiss.IndexScalarQuantizer, + faiss.IndexIVFScalarQuantizer, +) # FAISS indexes that need to be trained before adding vectors TRAINABLE_INDEXES = ( faiss.IndexIVFFlat, @@ -31,8 +36,8 @@ @dataclass class FaissArgs(BaseArgs): dim: int = 0 - index_type: Literal["flat", "ivf", "hnsw", "lsh", "scalar", "pq", "ivf_scalar", "ivfpq", "ivfpqr"] = "hnsw" - metric: Literal["cosine", "l2"] = "cosine" + index_type: str = "flat" + metric: str = "cosine" nlist: int = 100 m: int = 8 nbits: int = 8 @@ -41,6 +46,11 @@ class FaissArgs(BaseArgs): class FaissBackend(AbstractBackend[FaissArgs]): argument_class = FaissArgs + supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN} + inverse_metric_mapping = { + Metric.COSINE: faiss.METRIC_INNER_PRODUCT, + Metric.EUCLIDEAN: faiss.METRIC_L2, + } def __init__( self, @@ -55,43 +65,29 @@ def __init__( def from_vectors( # noqa: C901 cls: type[FaissBackend], vectors: npt.NDArray, - index_type: Literal["flat", "ivf", "hnsw", "lsh", "scalar", "pq", "ivf_scalar", "ivfpq", "ivfpqr"] = "flat", - metric: Literal["cosine", "l2"] = "cosine", + index_type: str = "flat", + metric: Union[str, Metric] = "cosine", nlist: int = 100, m: int = 8, nbits: int = 8, refine_nbits: int = 8, **kwargs: Any, ) -> FaissBackend: - """ - Create a new instance from vectors. - - :param vectors: The vectors to index. - :param index_type: The type of FAISS index to use. - :param metric: The metric to use for similarity search. - :param nlist: The number of cells for IVF indexes. - :param m: The number of subquantizers for PQ and HNSW indexes. - :param nbits: The number of bits for LSH and PQ indexes. - :param refine_nbits: The number of bits for the refinement stage in IVFPQR indexes. - :param **kwargs: Additional arguments to pass to the backend. - :return: A new FaissBackend instance. - :raises ValueError: If an invalid index type is provided. - """ - dim = vectors.shape[1] + """Create a new instance from vectors.""" + metric_enum = Metric.from_string(metric) + + if metric_enum not in cls.supported_metrics: + raise ValueError(f"Metric '{metric_enum.value}' is not supported by FaissBackend.") - # If using cosine, normalize vectors to unit length - if metric == "cosine": + faiss_metric = cls._map_metric_to_string(metric_enum) + if faiss_metric == faiss.METRIC_INNER_PRODUCT: vectors = normalize(vectors) - faiss_metric = faiss.METRIC_INNER_PRODUCT - else: - faiss_metric = faiss.METRIC_L2 - if index_type.startswith("ivf"): - # Create a quantizer for IVF indexes - quantizer = faiss.IndexFlatL2(dim) if faiss_metric == faiss.METRIC_L2 else faiss.IndexFlatIP(dim) + dim = vectors.shape[1] + # Handle index creation based on index_type if index_type == "flat": - index = faiss.IndexFlatL2(dim) if faiss_metric == faiss.METRIC_L2 else faiss.IndexFlatIP(dim) + index = faiss.IndexFlat(dim, faiss_metric) elif index_type == "hnsw": index = faiss.IndexHNSWFlat(dim, m) elif index_type == "lsh": @@ -100,13 +96,11 @@ def from_vectors( # noqa: C901 index = faiss.IndexScalarQuantizer(dim, faiss.ScalarQuantizer.QT_8bit) elif index_type == "pq": if not (1 <= nbits <= 16): - # Log a warning and adjust nbits to the maximum supported value for PQ logger.warning(f"Invalid nbits={nbits} for IndexPQ. Setting nbits to 16.") nbits = 16 index = faiss.IndexPQ(dim, m, nbits) elif index_type.startswith("ivf"): - # Create a quantizer for IVF indexes - quantizer = faiss.IndexFlatL2(dim) if faiss_metric == faiss.METRIC_L2 else faiss.IndexFlatIP(dim) + quantizer = faiss.IndexFlat(dim, faiss_metric) if index_type == "ivf": index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss_metric) elif index_type == "ivf_scalar": @@ -115,6 +109,8 @@ def from_vectors( # noqa: C901 index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, nbits) elif index_type == "ivfpqr": index = faiss.IndexIVFPQR(quantizer, dim, nlist, m, nbits, m, refine_nbits) + else: + raise ValueError(f"Unsupported FAISS index type: {index_type}") else: raise ValueError(f"Unsupported FAISS index type: {index_type}") @@ -127,7 +123,7 @@ def from_vectors( # noqa: C901 arguments = FaissArgs( dim=dim, index_type=index_type, - metric=metric, + metric=metric_enum.value, nlist=nlist, m=m, nbits=nbits, @@ -171,39 +167,25 @@ def delete(self, indices: list[int]) -> None: def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]: """Query vectors within a distance threshold, using range_search if supported.""" out: list[npt.NDArray] = [] - - # Normalize query vectors if using cosine similarity if self.arguments.metric == "cosine": vectors = normalize(vectors) if isinstance(self.index, RANGE_SEARCH_INDEXES): - # Use range_search for supported indexes radius = threshold lims, D, I = self.index.range_search(vectors, radius) - for i in range(vectors.shape[0]): start, end = lims[i], lims[i + 1] idx = I[start:end] dist = D[start:end] - - # Convert dist for cosine if needed if self.arguments.metric == "cosine": dist = 1 - dist - - # Only include idx within the threshold - within_threshold = idx[dist < threshold] - out.append(within_threshold) + out.append(idx[dist < threshold]) else: - # Fallback to search-based filtering for indexes that do not support range_search distances, indices = self.index.search(vectors, 100) - for dist, idx in zip(distances, indices): - # Convert distances for cosine if needed if self.arguments.metric == "cosine": dist = 1 - dist - # Filter based on the threshold - within_threshold = idx[dist < threshold] - out.append(within_threshold) + out.append(idx[dist < threshold]) return out diff --git a/vicinity/backends/hnsw.py b/vicinity/backends/hnsw.py index 827adab..480d9f5 100644 --- a/vicinity/backends/hnsw.py +++ b/vicinity/backends/hnsw.py @@ -2,25 +2,31 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, Union from hnswlib import Index as HnswIndex from numpy import typing as npt from vicinity.backends.base import AbstractBackend, BaseArgs from vicinity.datatypes import Backend, QueryResult +from vicinity.utils import Metric @dataclass class HNSWArgs(BaseArgs): dim: int = 0 - space: Literal["cosine", "l2"] = "cosine" + metric: str = "cosine" ef_construction: int = 200 m: int = 16 class HNSWBackend(AbstractBackend[HNSWArgs]): argument_class = HNSWArgs + supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN} + inverse_metric_mapping = { + Metric.COSINE: "cosine", + Metric.EUCLIDEAN: "l2", + } def __init__( self, @@ -35,17 +41,24 @@ def __init__( def from_vectors( cls: type[HNSWBackend], vectors: npt.NDArray, - space: Literal["cosine", "l2"], + metric: Union[str, Metric], ef_construction: int, m: int, **kwargs: Any, ) -> HNSWBackend: """Create a new instance from vectors.""" + metric_enum = Metric.from_string(metric) + + if metric_enum not in cls.supported_metrics: + raise ValueError(f"Metric '{metric_enum.value}' is not supported by HNSWBackend.") + + # Map Metric to HNSW's space parameter + metric = cls._map_metric_to_string(metric_enum) dim = vectors.shape[1] - index = HnswIndex(space=space, dim=dim) + index = HnswIndex(space=metric, dim=dim) index.init_index(max_elements=vectors.shape[0], ef_construction=ef_construction, M=m) index.add_items(vectors) - arguments = HNSWArgs(dim=dim, space=space, ef_construction=ef_construction, m=m) + arguments = HNSWArgs(dim=dim, metric=metric, ef_construction=ef_construction, m=m) return HNSWBackend(index, arguments=arguments) @property @@ -67,7 +80,7 @@ def load(cls: type[HNSWBackend], base_path: Path) -> HNSWBackend: """Load the vectors from a path.""" path = Path(base_path) / "index.bin" arguments = HNSWArgs.load(base_path / "arguments.json") - index = HnswIndex(space=arguments.space, dim=arguments.dim) + index = HnswIndex(space=arguments.metric, dim=arguments.dim) index.load_index(str(path)) return cls(index, arguments=arguments) diff --git a/vicinity/backends/pynndescent.py b/vicinity/backends/pynndescent.py index 726eb65..964ab08 100644 --- a/vicinity/backends/pynndescent.py +++ b/vicinity/backends/pynndescent.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Literal +from typing import Any, Union import numpy as np from numpy import typing as npt @@ -10,21 +10,18 @@ from vicinity.backends.base import AbstractBackend, BaseArgs from vicinity.datatypes import Backend, QueryResult -from vicinity.utils import normalize_or_copy +from vicinity.utils import Metric, normalize_or_copy @dataclass class PyNNDescentArgs(BaseArgs): n_neighbors: int = 15 - metric: Literal[ - "cosine", - "euclidean", - "manhattan", - ] = "cosine" + metric: str = "cosine" class PyNNDescentBackend(AbstractBackend[PyNNDescentArgs]): argument_class = PyNNDescentArgs + supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN, Metric.MANHATTAN} def __init__( self, @@ -40,10 +37,18 @@ def from_vectors( cls: type[PyNNDescentBackend], vectors: npt.NDArray, n_neighbors: int = 15, - metric: Literal["cosine", "euclidean", "manhattan"] = "cosine", + metric: Union[str, Metric] = "cosine", + **kwargs: Any, ) -> PyNNDescentBackend: """Create a new instance from vectors.""" - index = NNDescent(vectors, n_neighbors=n_neighbors, metric=metric) + metric_enum = Metric.from_string(metric) + + if metric_enum not in cls.supported_metrics: + raise ValueError(f"Metric '{metric_enum.value}' is not supported by PyNNDescentBackend.") + + metric = metric_enum.value + + index = NNDescent(vectors, n_neighbors=n_neighbors, metric=metric, **kwargs) arguments = PyNNDescentArgs(n_neighbors=n_neighbors, metric=metric) return cls(index=index, arguments=arguments) @@ -69,11 +74,11 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult: def insert(self, vectors: npt.NDArray) -> None: """Insert vectors into the backend.""" - raise NotImplementedError("Insertion is not supported in pynndescent backend.") + raise NotImplementedError("Insertion is not supported in PyNNDescent backend.") def delete(self, indices: list[int]) -> None: """Delete vectors from the backend.""" - raise NotImplementedError("Deletion is not supported in pynndescent backend.") + raise NotImplementedError("Deletion is not supported in PyNNDescent backend.") def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]: """Find neighbors within a distance threshold.""" @@ -99,7 +104,11 @@ def load(cls: type[PyNNDescentBackend], base_path: Path) -> PyNNDescentBackend: """Load the vectors and configuration from a specified path.""" arguments = PyNNDescentArgs.load(base_path / "arguments.json") vectors = np.load(Path(base_path) / "vectors.npy") - index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=arguments.metric) + + metric_enum = Metric.from_string(arguments.metric) + pynndescent_metric = metric_enum.value + + index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=pynndescent_metric) # Load the neighbor graph if it was saved neighbor_graph_path = base_path / "neighbor_graph.npy" diff --git a/vicinity/backends/usearch.py b/vicinity/backends/usearch.py index 470a335..2671bc8 100644 --- a/vicinity/backends/usearch.py +++ b/vicinity/backends/usearch.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Union import numpy as np from numpy import typing as npt @@ -10,12 +10,13 @@ from vicinity.backends.base import AbstractBackend, BaseArgs from vicinity.datatypes import Backend, QueryResult +from vicinity.utils import Metric @dataclass class UsearchArgs(BaseArgs): dim: int = 0 - metric: Literal["cos", "ip", "l2sq", "hamming", "tanimoto"] = "cos" + metric: str = "cos" connectivity: int = 16 expansion_add: int = 128 expansion_search: int = 64 @@ -23,6 +24,14 @@ class UsearchArgs(BaseArgs): class UsearchBackend(AbstractBackend[UsearchArgs]): argument_class = UsearchArgs + supported_metrics = {Metric.COSINE, Metric.INNER_PRODUCT, Metric.L2_SQUARED, Metric.HAMMING, Metric.TANIMOTO} + inverse_metric_mapping = { + Metric.COSINE: "cos", + Metric.INNER_PRODUCT: "ip", + Metric.L2_SQUARED: "l2sq", + Metric.HAMMING: "hamming", + Metric.TANIMOTO: "tanimoto", + } def __init__( self, @@ -37,23 +46,19 @@ def __init__( def from_vectors( cls: type[UsearchBackend], vectors: npt.NDArray, - metric: Literal["cos", "ip", "l2sq", "hamming", "tanimoto"], - connectivity: int, - expansion_add: int, - expansion_search: int, + metric: Union[str, Metric] = "cos", + connectivity: int = 16, + expansion_add: int = 128, + expansion_search: int = 64, **kwargs: Any, ) -> UsearchBackend: - """ - Create a new instance from vectors. - - :param vectors: The vectors to index. - :param metric: The metric to use. - :param connectivity: The connectivity parameter. - :param expansion_add: The expansion add parameter. - :param expansion_search: The expansion search parameter. - :param **kwargs: Additional keyword arguments. - :return: A new instance of the backend. - """ + """Create a new instance from vectors.""" + metric_enum = Metric.from_string(metric) + + if metric_enum not in cls.supported_metrics: + raise ValueError(f"Metric '{metric_enum.value}' is not supported by UsearchBackend.") + + metric = cls._map_metric_to_string(metric_enum) dim = vectors.shape[1] index = UsearchIndex( ndim=dim, @@ -70,9 +75,7 @@ def from_vectors( expansion_add=expansion_add, expansion_search=expansion_search, ) - backend = cls(index, arguments=arguments) - - return backend + return cls(index, arguments) @property def backend_type(self) -> Backend: @@ -93,6 +96,7 @@ def load(cls: type[UsearchBackend], base_path: Path) -> UsearchBackend: """Load the index from a path.""" path = Path(base_path) / "index.usearch" arguments = UsearchArgs.load(base_path / "arguments.json") + index = UsearchIndex( ndim=arguments.dim, metric=arguments.metric, @@ -121,8 +125,8 @@ def insert(self, vectors: npt.NDArray) -> None: self.index.add(None, vectors) # type: ignore def delete(self, indices: list[int]) -> None: - """Delete vectors from the index (not supported by usearch).""" - raise NotImplementedError("Dynamic deletion is not supported by usearch.") + """Delete vectors from the index (not supported by Usearch).""" + raise NotImplementedError("Dynamic deletion is not supported in Usearch.") def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]: """Threshold the backend and return filtered keys.""" diff --git a/vicinity/utils.py b/vicinity/utils.py index 76aef5e..1566a81 100644 --- a/vicinity/utils.py +++ b/vicinity/utils.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from enum import Enum from typing import Union import numpy as np @@ -51,3 +54,39 @@ def normalize_or_copy(vectors: npt.NDArray) -> npt.NDArray: if all_unit_length: return vectors return normalize(vectors, norms) + + +class Metric(Enum): + COSINE = "cosine" + EUCLIDEAN = "euclidean" + MANHATTAN = "manhattan" + INNER_PRODUCT = "inner_product" + L2_SQUARED = "l2sq" + HAMMING = "hamming" + TANIMOTO = "tanimoto" + + @classmethod + def from_string(cls, metric: Union[str, Metric]) -> Metric: + """Convert a string or Metric enum to a Metric enum member.""" + if isinstance(metric, cls): + return metric + if isinstance(metric, str): + mapping = { + "cos": cls.COSINE, + "cosine": cls.COSINE, + "dot": cls.COSINE, + "euclidean": cls.EUCLIDEAN, + "l2": cls.EUCLIDEAN, + "manhattan": cls.MANHATTAN, + "l1": cls.MANHATTAN, + "inner_product": cls.INNER_PRODUCT, + "ip": cls.INNER_PRODUCT, + "l2sq": cls.L2_SQUARED, + "l2_squared": cls.L2_SQUARED, + "hamming": cls.HAMMING, + "tanimoto": cls.TANIMOTO, + } + metric_str = metric.lower() + if metric_str in mapping: + return mapping[metric_str] + raise ValueError(f"Unsupported metric: {metric}")