From 5aefa2c84b9c215943509327d7e6661462252cb1 Mon Sep 17 00:00:00 2001 From: polarbear567 <269739606@qq.com> Date: Wed, 28 Aug 2024 15:29:11 +0800 Subject: [PATCH] fix `SimpleVectorStore.from_persist_dir()` behaviour (#15534) * fix 15451 * add test cases --------- Co-authored-by: Massimiliano Pippi --- .../llama_index/core/vector_stores/simple.py | 7 ++-- .../tests/vector_stores/test_simple.py | 32 ++++++++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/llama-index-core/llama_index/core/vector_stores/simple.py b/llama-index-core/llama_index/core/vector_stores/simple.py index 81643b8ca90d3..69ae8cd86e635 100644 --- a/llama-index-core/llama_index/core/vector_stores/simple.py +++ b/llama-index-core/llama_index/core/vector_stores/simple.py @@ -163,14 +163,11 @@ def __init__( def from_persist_dir( cls, persist_dir: str = DEFAULT_PERSIST_DIR, - namespace: Optional[str] = None, + namespace: str = DEFAULT_VECTOR_STORE, fs: Optional[fsspec.AbstractFileSystem] = None, ) -> "SimpleVectorStore": """Load from persist dir.""" - if namespace: - persist_fname = f"{namespace}{NAMESPACE_SEP}{DEFAULT_PERSIST_FNAME}" - else: - persist_fname = DEFAULT_PERSIST_FNAME + persist_fname = f"{namespace}{NAMESPACE_SEP}{DEFAULT_PERSIST_FNAME}" if fs is not None: persist_path = concat_dirs(persist_dir, persist_fname) diff --git a/llama-index-core/tests/vector_stores/test_simple.py b/llama-index-core/tests/vector_stores/test_simple.py index 119e740231aee..ba62fb7f0a646 100644 --- a/llama-index-core/tests/vector_stores/test_simple.py +++ b/llama-index-core/tests/vector_stores/test_simple.py @@ -1,7 +1,16 @@ import unittest +from pathlib import Path from typing import List -from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode +import pytest + +from llama_index.core import VectorStoreIndex, MockEmbedding +from llama_index.core.schema import ( + NodeRelationship, + RelatedNodeInfo, + TextNode, + Document, +) from llama_index.core.vector_stores import SimpleVectorStore from llama_index.core.vector_stores.types import ( ExactMatchFilter, @@ -17,6 +26,15 @@ _NODE_ID_WEIGHT_3_RANK_C = "452D24AB-F185-414C-A352-590B4B9EE51B" +@pytest.fixture() +def persist_dir(tmp_path: Path): + index = VectorStoreIndex.from_documents( + [Document(id_="1", text="1")], embed_model=MockEmbedding(embed_dim=1) + ) + index.storage_context.persist(str(tmp_path)) + return str(tmp_path) + + def _node_embeddings_for_test() -> List[TextNode]: return [ TextNode( @@ -434,3 +452,15 @@ def test_delete_nodes(self) -> None: query = VectorStoreQuery(query_embedding=[1.0, 1.0], similarity_top_k=3) result = simple_vector_store.query(query) self.assertEqual(result.ids, [_NODE_ID_WEIGHT_3_RANK_C]) + + +def test_from_persist_dir(persist_dir: str) -> None: + vector_store = SimpleVectorStore.from_persist_dir(persist_dir=persist_dir) + assert vector_store is not None + + +def test_from_namespaced_persist_dir(persist_dir: str) -> None: + vector_store = SimpleVectorStore.from_namespaced_persist_dir( + persist_dir=persist_dir + ) + assert vector_store is not None