Skip to content

Commit

Permalink
feat: Align metrics (#30)
Browse files Browse the repository at this point in the history
* Added euclidean metric to basic backend

* Switched to mixins

* Updates

* Updates

* Aligned metrics

* Update

* Update

* Resolved comments
  • Loading branch information
Pringled authored Dec 1, 2024
1 parent c5a3434 commit ba2bc67
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 119 deletions.
4 changes: 2 additions & 2 deletions vicinity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Small vector store."""

from vicinity.datatypes import Backend
from vicinity.utils import normalize
from vicinity.utils import Metric, normalize
from vicinity.version import __version__
from vicinity.vicinity import Vicinity

__all__ = ["Backend", "Vicinity", "normalize", "__version__"]
__all__ = ["Backend", "Metric", "Vicinity", "normalize", "__version__"]
52 changes: 28 additions & 24 deletions vicinity/backends/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,33 @@

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from typing import Any, Union

import numpy as np
from annoy import AnnoyIndex
from numpy import typing as npt

from vicinity.backends.base import AbstractBackend, BaseArgs
from vicinity.datatypes import Backend, QueryResult
from vicinity.utils import normalize
from vicinity.utils import Metric, normalize


@dataclass
class AnnoyArgs(BaseArgs):
dim: int = 0
metric: Literal["dot", "euclidean", "cosine"] = "cosine"
metric: str = "cosine"
trees: int = 100
length: int | None = None


class AnnoyBackend(AbstractBackend[AnnoyArgs]):
argument_class = AnnoyArgs
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN, Metric.INNER_PRODUCT}
inverse_metric_mapping = {
Metric.COSINE: "dot",
Metric.EUCLIDEAN: "euclidean",
Metric.INNER_PRODUCT: "dot",
}

def __init__(
self,
Expand All @@ -40,25 +46,28 @@ def __init__(
def from_vectors(
cls: type[AnnoyBackend],
vectors: npt.NDArray,
metric: Literal["dot", "euclidean", "cosine"],
metric: Union[str, Metric],
trees: int,
**kwargs: Any,
) -> AnnoyBackend:
"""Create a new instance from vectors."""
dim = vectors.shape[1]
actual_metric: Literal["dot", "euclidean"]
if metric == "cosine":
actual_metric = "dot"
metric_enum = Metric.from_string(metric)

if metric_enum not in cls.supported_metrics:
raise ValueError(f"Metric '{metric_enum.value}' is not supported by AnnoyBackend.")

metric = cls._map_metric_to_string(metric_enum)

if metric == "dot":
vectors = normalize(vectors)
else:
actual_metric = metric

index = AnnoyIndex(f=dim, metric=actual_metric)
dim = vectors.shape[1]
index = AnnoyIndex(f=dim, metric=metric) # type: ignore
for i, vector in enumerate(vectors):
index.add_item(i, vector)
index.build(trees)

arguments = AnnoyArgs(dim=dim, trees=trees, metric=metric, length=len(vectors))
arguments = AnnoyArgs(dim=dim, metric=metric, trees=trees, length=len(vectors)) # type: ignore
return AnnoyBackend(index, arguments=arguments)

@property
Expand All @@ -80,11 +89,7 @@ def load(cls: type[AnnoyBackend], base_path: Path) -> AnnoyBackend:
"""Load the vectors from a path."""
path = Path(base_path) / "index.bin"
arguments = AnnoyArgs.load(base_path / "arguments.json")

metric = arguments.metric
actual_metric = "dot" if metric == "cosine" else metric

index = AnnoyIndex(arguments.dim, actual_metric)
index = AnnoyIndex(arguments.dim, arguments.metric) # type: ignore
index.load(str(path))

return cls(index, arguments=arguments)
Expand All @@ -93,36 +98,35 @@ def save(self, base_path: Path) -> None:
"""Save the vectors to a path."""
path = Path(base_path) / "index.bin"
self.index.save(str(path))
# NOTE: set the length before saving.
# Ensure the length is set before saving
self.arguments.length = len(self)
self.arguments.dump(base_path / "arguments.json")

def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
"""Query the backend."""
out = []
for vec in vectors:
if self.arguments.metric == "cosine":
if self.arguments.metric == "dot":
vec = normalize(vec)
indices, scores = self.index.get_nns_by_vector(vec, k, include_distances=True)
scores_array = np.asarray(scores)
if self.arguments.metric == "cosine":
# Turn cosine similarity into cosine distance.
if self.arguments.metric == "dot":
# Convert cosine similarity to cosine distance
scores_array = 1 - scores_array
out.append((np.asarray(indices), scores_array))
return out

def insert(self, vectors: npt.NDArray) -> None:
"""Insert vectors into the backend."""
raise NotImplementedError("Insertion is not supported in ANNOY backend.")
raise NotImplementedError("Insertion is not supported in Annoy backend.")

def delete(self, indices: list[int]) -> None:
"""Delete vectors from the backend."""
raise NotImplementedError("Deletion is not supported in ANNOY backend.")
raise NotImplementedError("Deletion is not supported in Annoy backend.")

def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
"""Threshold the backend."""
out: list[npt.NDArray] = []
for x, y in self.query(vectors, 100):
out.append(x[y < threshold])

return out
7 changes: 7 additions & 0 deletions vicinity/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from numpy import typing as npt

from vicinity import Metric
from vicinity.datatypes import Backend, QueryResult


Expand All @@ -34,6 +35,7 @@ def dict(self) -> dict[str, Any]:

class AbstractBackend(ABC, Generic[ArgType]):
argument_class: type[ArgType]
inverse_metric_mapping: dict[Metric, str] = {}

def __init__(self, arguments: ArgType, *args: Any, **kwargs: Any) -> None:
"""Initialize the backend with vectors."""
Expand Down Expand Up @@ -93,5 +95,10 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
"""Query the backend."""
raise NotImplementedError()

@classmethod
def _map_metric_to_string(cls, metric: Metric) -> str:
"""Map a Metric enum to a backend-specific metric string."""
return cls.inverse_metric_mapping.get(metric, metric.value)


BaseType = TypeVar("BaseType", bound=AbstractBackend)
4 changes: 2 additions & 2 deletions vicinity/backends/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from vicinity.backends.base import AbstractBackend, BaseArgs
from vicinity.datatypes import Backend, Matrix, QueryResult
from vicinity.utils import normalize, normalize_or_copy
from vicinity.utils import Metric, normalize, normalize_or_copy


@dataclass
Expand All @@ -21,6 +21,7 @@ class BasicArgs(BaseArgs):
class BasicBackend(AbstractBackend[BasicArgs], ABC):
argument_class = BasicArgs
_vectors: npt.NDArray
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN}

def __init__(self, arguments: BasicArgs) -> None:
"""Initialize the backend."""
Expand Down Expand Up @@ -116,7 +117,6 @@ def threshold(
indices = np.flatnonzero(dists <= threshold)
sorted_indices = indices[np.argsort(dists[indices])]
out.append(sorted_indices)

return out

def query(
Expand Down
84 changes: 33 additions & 51 deletions vicinity/backends/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from typing import Any, Union

import faiss
import numpy as np
from numpy import typing as npt

from vicinity.backends.base import AbstractBackend, BaseArgs
from vicinity.datatypes import Backend, QueryResult
from vicinity.utils import normalize
from vicinity.utils import Metric, normalize

logger = logging.getLogger(__name__)

# FAISS indexes that support range_search
RANGE_SEARCH_INDEXES = (faiss.IndexFlat, faiss.IndexIVFFlat, faiss.IndexScalarQuantizer, faiss.IndexIVFScalarQuantizer)
RANGE_SEARCH_INDEXES = (
faiss.IndexFlat,
faiss.IndexIVFFlat,
faiss.IndexScalarQuantizer,
faiss.IndexIVFScalarQuantizer,
)
# FAISS indexes that need to be trained before adding vectors
TRAINABLE_INDEXES = (
faiss.IndexIVFFlat,
Expand All @@ -31,8 +36,8 @@
@dataclass
class FaissArgs(BaseArgs):
dim: int = 0
index_type: Literal["flat", "ivf", "hnsw", "lsh", "scalar", "pq", "ivf_scalar", "ivfpq", "ivfpqr"] = "hnsw"
metric: Literal["cosine", "l2"] = "cosine"
index_type: str = "flat"
metric: str = "cosine"
nlist: int = 100
m: int = 8
nbits: int = 8
Expand All @@ -41,6 +46,11 @@ class FaissArgs(BaseArgs):

class FaissBackend(AbstractBackend[FaissArgs]):
argument_class = FaissArgs
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN}
inverse_metric_mapping = {
Metric.COSINE: faiss.METRIC_INNER_PRODUCT,
Metric.EUCLIDEAN: faiss.METRIC_L2,
}

def __init__(
self,
Expand All @@ -55,43 +65,29 @@ def __init__(
def from_vectors( # noqa: C901
cls: type[FaissBackend],
vectors: npt.NDArray,
index_type: Literal["flat", "ivf", "hnsw", "lsh", "scalar", "pq", "ivf_scalar", "ivfpq", "ivfpqr"] = "flat",
metric: Literal["cosine", "l2"] = "cosine",
index_type: str = "flat",
metric: Union[str, Metric] = "cosine",
nlist: int = 100,
m: int = 8,
nbits: int = 8,
refine_nbits: int = 8,
**kwargs: Any,
) -> FaissBackend:
"""
Create a new instance from vectors.
:param vectors: The vectors to index.
:param index_type: The type of FAISS index to use.
:param metric: The metric to use for similarity search.
:param nlist: The number of cells for IVF indexes.
:param m: The number of subquantizers for PQ and HNSW indexes.
:param nbits: The number of bits for LSH and PQ indexes.
:param refine_nbits: The number of bits for the refinement stage in IVFPQR indexes.
:param **kwargs: Additional arguments to pass to the backend.
:return: A new FaissBackend instance.
:raises ValueError: If an invalid index type is provided.
"""
dim = vectors.shape[1]
"""Create a new instance from vectors."""
metric_enum = Metric.from_string(metric)

if metric_enum not in cls.supported_metrics:
raise ValueError(f"Metric '{metric_enum.value}' is not supported by FaissBackend.")

# If using cosine, normalize vectors to unit length
if metric == "cosine":
faiss_metric = cls._map_metric_to_string(metric_enum)
if faiss_metric == faiss.METRIC_INNER_PRODUCT:
vectors = normalize(vectors)
faiss_metric = faiss.METRIC_INNER_PRODUCT
else:
faiss_metric = faiss.METRIC_L2

if index_type.startswith("ivf"):
# Create a quantizer for IVF indexes
quantizer = faiss.IndexFlatL2(dim) if faiss_metric == faiss.METRIC_L2 else faiss.IndexFlatIP(dim)
dim = vectors.shape[1]

# Handle index creation based on index_type
if index_type == "flat":
index = faiss.IndexFlatL2(dim) if faiss_metric == faiss.METRIC_L2 else faiss.IndexFlatIP(dim)
index = faiss.IndexFlat(dim, faiss_metric)
elif index_type == "hnsw":
index = faiss.IndexHNSWFlat(dim, m)
elif index_type == "lsh":
Expand All @@ -100,13 +96,11 @@ def from_vectors( # noqa: C901
index = faiss.IndexScalarQuantizer(dim, faiss.ScalarQuantizer.QT_8bit)
elif index_type == "pq":
if not (1 <= nbits <= 16):
# Log a warning and adjust nbits to the maximum supported value for PQ
logger.warning(f"Invalid nbits={nbits} for IndexPQ. Setting nbits to 16.")
nbits = 16
index = faiss.IndexPQ(dim, m, nbits)
elif index_type.startswith("ivf"):
# Create a quantizer for IVF indexes
quantizer = faiss.IndexFlatL2(dim) if faiss_metric == faiss.METRIC_L2 else faiss.IndexFlatIP(dim)
quantizer = faiss.IndexFlat(dim, faiss_metric)
if index_type == "ivf":
index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss_metric)
elif index_type == "ivf_scalar":
Expand All @@ -115,6 +109,8 @@ def from_vectors( # noqa: C901
index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, nbits)
elif index_type == "ivfpqr":
index = faiss.IndexIVFPQR(quantizer, dim, nlist, m, nbits, m, refine_nbits)
else:
raise ValueError(f"Unsupported FAISS index type: {index_type}")
else:
raise ValueError(f"Unsupported FAISS index type: {index_type}")

Expand All @@ -127,7 +123,7 @@ def from_vectors( # noqa: C901
arguments = FaissArgs(
dim=dim,
index_type=index_type,
metric=metric,
metric=metric_enum.value,
nlist=nlist,
m=m,
nbits=nbits,
Expand Down Expand Up @@ -171,39 +167,25 @@ def delete(self, indices: list[int]) -> None:
def threshold(self, vectors: npt.NDArray, threshold: float) -> list[npt.NDArray]:
"""Query vectors within a distance threshold, using range_search if supported."""
out: list[npt.NDArray] = []

# Normalize query vectors if using cosine similarity
if self.arguments.metric == "cosine":
vectors = normalize(vectors)

if isinstance(self.index, RANGE_SEARCH_INDEXES):
# Use range_search for supported indexes
radius = threshold
lims, D, I = self.index.range_search(vectors, radius)

for i in range(vectors.shape[0]):
start, end = lims[i], lims[i + 1]
idx = I[start:end]
dist = D[start:end]

# Convert dist for cosine if needed
if self.arguments.metric == "cosine":
dist = 1 - dist

# Only include idx within the threshold
within_threshold = idx[dist < threshold]
out.append(within_threshold)
out.append(idx[dist < threshold])
else:
# Fallback to search-based filtering for indexes that do not support range_search
distances, indices = self.index.search(vectors, 100)

for dist, idx in zip(distances, indices):
# Convert distances for cosine if needed
if self.arguments.metric == "cosine":
dist = 1 - dist
# Filter based on the threshold
within_threshold = idx[dist < threshold]
out.append(within_threshold)
out.append(idx[dist < threshold])

return out

Expand Down
Loading

0 comments on commit ba2bc67

Please sign in to comment.