Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul committed Dec 8, 2024
1 parent 2e05a03 commit 8859d81
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 8 deletions.
36 changes: 29 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ def query_vector() -> np.ndarray:


# Create human-readable ids for each backend type
BACKEND_IDS = [
f"{backend.name}-{index_type}" if index_type else backend.name
for backend, index_type in BACKEND_PARAMS
]
BACKEND_IDS = [f"{backend.name}-{index_type}" if index_type else backend.name for backend, index_type in BACKEND_PARAMS]


@pytest.fixture(params=BACKEND_PARAMS)
Expand All @@ -63,9 +60,7 @@ def backend_type(request: pytest.FixtureRequest) -> Backend:


@pytest.fixture(params=BACKEND_PARAMS, ids=BACKEND_IDS)
def vicinity_instance(
request: pytest.FixtureRequest, items: list[str], vectors: np.ndarray
) -> Vicinity:
def vicinity_instance(request: pytest.FixtureRequest, items: list[str], vectors: np.ndarray) -> Vicinity:
"""Fixture providing a Vicinity instance for each backend type."""
backend_type, index_type = request.param
# Handle FAISS backend with specific FAISS index types
Expand All @@ -91,3 +86,30 @@ def vicinity_instance(
)

return Vicinity.from_vectors_and_items(vectors, items, backend_type=backend_type)


@pytest.fixture(params=BACKEND_PARAMS, ids=BACKEND_IDS)
def vicinity_instance_with_stored_vectors(
request: pytest.FixtureRequest, items: list[str], vectors: np.ndarray
) -> Vicinity:
"""Fixture providing a Vicinity instance for each backend type."""
backend_type, index_type = request.param
# Handle FAISS backend with specific FAISS index types
if backend_type == Backend.FAISS:
if index_type in ("pq", "ivfpq", "ivfpqr"):
# Use smaller values for pq indexes since the dataset is small
return Vicinity.from_vectors_and_items(
vectors, items, backend_type=backend_type, index_type=index_type, m=2, nbits=4, store_vectors=True
)
else:
return Vicinity.from_vectors_and_items(
vectors, items, backend_type=backend_type, index_type=index_type, nlist=2, nbits=32, store_vectors=True
)

return Vicinity.from_vectors_and_items(vectors, items, backend_type=backend_type, store_vectors=True)


@pytest.fixture()
def vicinity_with_basic_backend(vectors: np.ndarray, items: list[str]) -> Vicinity:
"""Fixture providing a BasicBackend instance."""
return Vicinity.from_vectors_and_items(vectors, items, backend_type=Backend.BASIC, store_vectors=True)
36 changes: 35 additions & 1 deletion tests/test_vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,42 @@ def test_vicinity_save_and_load(tmp_path: Path, vicinity_instance: Vicinity) ->
"""
save_path = tmp_path / "vicinity_data"
vicinity_instance.save(save_path)
assert vicinity_instance.vector_store is None

Vicinity.load(save_path)
v = Vicinity.load(save_path)
assert v.vector_store is None


def test_vicinity_save_and_load_vector_store(tmp_path: Path, vicinity_instance_with_stored_vectors: Vicinity) -> None:
"""
Test Vicinity.save and Vicinity.load.
:param tmp_path: Temporary directory provided by pytest.
:param vicinity_instance: A Vicinity instance.
"""
save_path = tmp_path / "vicinity_data"
vicinity_instance_with_stored_vectors.save(save_path)

assert (save_path / "store").exists()
assert (save_path / "store" / "vectors.npy").exists()

v = Vicinity.load(save_path)
assert v.vector_store is not None


def test_index_vector_store(vicinity_with_basic_backend: Vicinity, vectors: np.ndarray) -> None:
"""
Index vectors in the Vicinity instance.
:param vicinity_instance: A Vicinity instance.
:param vectors: Array of vectors to index.
"""
v = vicinity_with_basic_backend.get_vector_by_index(0)
assert np.allclose(v, vectors[0])

idx = [0, 1, 2, 3, 4, 10]
v = vicinity_with_basic_backend.get_vector_by_index(idx)
assert np.allclose(v, vectors[idx])


def test_vicinity_insert_duplicate(vicinity_instance: Vicinity, query_vector: np.ndarray) -> None:
Expand Down

0 comments on commit 8859d81

Please sign in to comment.