From 7b6f4ef26422370e26b378dd945d2dac4fc8d090 Mon Sep 17 00:00:00 2001 From: willtai Date: Wed, 4 Dec 2024 16:21:30 +0000 Subject: [PATCH] Added new tests to improve test coverage for Neo4jVector (#17) --- .gitignore | 1 + .../vectorstores/neo4j_vector.py | 14 +- libs/neo4j/poetry.lock | 97 ++- libs/neo4j/pyproject.toml | 3 +- .../unit_tests/vectorstores/test_neo4j.py | 803 ++++++++++++++++++ 5 files changed, 909 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 704af65..be345ce 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ __pycache__ .idea .vscode **/.DS_Store +.coverage diff --git a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py index 0932b0d..e2beaef 100644 --- a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py +++ b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py @@ -345,11 +345,11 @@ def _handle_field_filter( query_snippet = f"toLower(n.`{field}`) CONTAINS $param_{param_number}" query_param = {f"param_{param_number}": filter_value.rstrip("%")} return (query_snippet, query_param) - else: - raise NotImplementedError() else: raise NotImplementedError() + raise NotImplementedError("Unhandled operator") + def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: """Construct a metadata filter. @@ -386,7 +386,7 @@ def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: and_ = combine_queries( [construct_metadata_filter(el) for el in value], "AND" ) - if len(and_) >= 1: + if len(and_[0]) >= 1: return and_ else: raise ValueError( @@ -397,7 +397,7 @@ def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: or_ = combine_queries( [construct_metadata_filter(el) for el in value], "OR" ) - if len(or_) >= 1: + if len(or_[0]) >= 1: return or_ else: raise ValueError( @@ -422,7 +422,7 @@ def construct_metadata_filter(filter: Dict[str, Any]) -> Tuple[str, Dict]: for index, (k, v) in enumerate(filter.items()) ] ) - if len(and_multiple) >= 1: + if len(and_multiple[0]) >= 1: return " AND ".join(and_multiple[0]), and_multiple[1] else: raise ValueError( @@ -862,7 +862,7 @@ def __from( embedding_dimension and not store.embedding_dimension == embedding_dimension ): raise ValueError( - f"Index with name {store.index_name} already exists." + f"Index with name {store.index_name} already exists. " "The provided embedding function and vector index " "dimensions do not match.\n" f"Embedding function dimension: {store.embedding_dimension}\n" @@ -1539,7 +1539,7 @@ def from_existing_graph( embedding_dimension and not store.embedding_dimension == embedding_dimension ): raise ValueError( - f"Index with name {store.index_name} already exists." + f"Index with name {store.index_name} already exists. " "The provided embedding function and vector index " "dimensions do not match.\n" f"Embedding function dimension: {store.embedding_dimension}\n" diff --git a/libs/neo4j/poetry.lock b/libs/neo4j/poetry.lock index bdb2f26..9af7616 100644 --- a/libs/neo4j/poetry.lock +++ b/libs/neo4j/poetry.lock @@ -339,6 +339,83 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coverage" +version = "7.6.8" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "coverage-7.6.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b39e6011cd06822eb964d038d5dff5da5d98652b81f5ecd439277b32361a3a50"}, + {file = "coverage-7.6.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:63c19702db10ad79151a059d2d6336fe0c470f2e18d0d4d1a57f7f9713875dcf"}, + {file = "coverage-7.6.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3985b9be361d8fb6b2d1adc9924d01dec575a1d7453a14cccd73225cb79243ee"}, + {file = "coverage-7.6.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:644ec81edec0f4ad17d51c838a7d01e42811054543b76d4ba2c5d6af741ce2a6"}, + {file = "coverage-7.6.8-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f188a2402f8359cf0c4b1fe89eea40dc13b52e7b4fd4812450da9fcd210181d"}, + {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e19122296822deafce89a0c5e8685704c067ae65d45e79718c92df7b3ec3d331"}, + {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:13618bed0c38acc418896005732e565b317aa9e98d855a0e9f211a7ffc2d6638"}, + {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:193e3bffca48ad74b8c764fb4492dd875038a2f9925530cb094db92bb5e47bed"}, + {file = "coverage-7.6.8-cp310-cp310-win32.whl", hash = "sha256:3988665ee376abce49613701336544041f2117de7b7fbfe91b93d8ff8b151c8e"}, + {file = "coverage-7.6.8-cp310-cp310-win_amd64.whl", hash = "sha256:f56f49b2553d7dd85fd86e029515a221e5c1f8cb3d9c38b470bc38bde7b8445a"}, + {file = "coverage-7.6.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86cffe9c6dfcfe22e28027069725c7f57f4b868a3f86e81d1c62462764dc46d4"}, + {file = "coverage-7.6.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d82ab6816c3277dc962cfcdc85b1efa0e5f50fb2c449432deaf2398a2928ab94"}, + {file = "coverage-7.6.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13690e923a3932e4fad4c0ebfb9cb5988e03d9dcb4c5150b5fcbf58fd8bddfc4"}, + {file = "coverage-7.6.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4be32da0c3827ac9132bb488d331cb32e8d9638dd41a0557c5569d57cf22c9c1"}, + {file = "coverage-7.6.8-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e6c85bbdc809383b509d732b06419fb4544dca29ebe18480379633623baafb"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:768939f7c4353c0fac2f7c37897e10b1414b571fd85dd9fc49e6a87e37a2e0d8"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e44961e36cb13c495806d4cac67640ac2866cb99044e210895b506c26ee63d3a"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3ea8bb1ab9558374c0ab591783808511d135a833c3ca64a18ec927f20c4030f0"}, + {file = "coverage-7.6.8-cp311-cp311-win32.whl", hash = "sha256:629a1ba2115dce8bf75a5cce9f2486ae483cb89c0145795603d6554bdc83e801"}, + {file = "coverage-7.6.8-cp311-cp311-win_amd64.whl", hash = "sha256:fb9fc32399dca861584d96eccd6c980b69bbcd7c228d06fb74fe53e007aa8ef9"}, + {file = "coverage-7.6.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e683e6ecc587643f8cde8f5da6768e9d165cd31edf39ee90ed7034f9ca0eefee"}, + {file = "coverage-7.6.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1defe91d41ce1bd44b40fabf071e6a01a5aa14de4a31b986aa9dfd1b3e3e414a"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7ad66e8e50225ebf4236368cc43c37f59d5e6728f15f6e258c8639fa0dd8e6d"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fe47da3e4fda5f1abb5709c156eca207eacf8007304ce3019eb001e7a7204cb"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:202a2d645c5a46b84992f55b0a3affe4f0ba6b4c611abec32ee88358db4bb649"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4674f0daa1823c295845b6a740d98a840d7a1c11df00d1fd62614545c1583787"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:74610105ebd6f33d7c10f8907afed696e79c59e3043c5f20eaa3a46fddf33b4c"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37cda8712145917105e07aab96388ae76e787270ec04bcb9d5cc786d7cbb8443"}, + {file = "coverage-7.6.8-cp312-cp312-win32.whl", hash = "sha256:9e89d5c8509fbd6c03d0dd1972925b22f50db0792ce06324ba069f10787429ad"}, + {file = "coverage-7.6.8-cp312-cp312-win_amd64.whl", hash = "sha256:379c111d3558272a2cae3d8e57e6b6e6f4fe652905692d54bad5ea0ca37c5ad4"}, + {file = "coverage-7.6.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b0c69f4f724c64dfbfe79f5dfb503b42fe6127b8d479b2677f2b227478db2eb"}, + {file = "coverage-7.6.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c15b32a7aca8038ed7644f854bf17b663bc38e1671b5d6f43f9a2b2bd0c46f63"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63068a11171e4276f6ece913bde059e77c713b48c3a848814a6537f35afb8365"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f4548c5ead23ad13fb7a2c8ea541357474ec13c2b736feb02e19a3085fac002"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b4b4299dd0d2c67caaaf286d58aef5e75b125b95615dda4542561a5a566a1e3"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c9ebfb2507751f7196995142f057d1324afdab56db1d9743aab7f50289abd022"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c1b4474beee02ede1eef86c25ad4600a424fe36cff01a6103cb4533c6bf0169e"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d9fd2547e6decdbf985d579cf3fc78e4c1d662b9b0ff7cc7862baaab71c9cc5b"}, + {file = "coverage-7.6.8-cp313-cp313-win32.whl", hash = "sha256:8aae5aea53cbfe024919715eca696b1a3201886ce83790537d1c3668459c7146"}, + {file = "coverage-7.6.8-cp313-cp313-win_amd64.whl", hash = "sha256:ae270e79f7e169ccfe23284ff5ea2d52a6f401dc01b337efb54b3783e2ce3f28"}, + {file = "coverage-7.6.8-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:de38add67a0af869b0d79c525d3e4588ac1ffa92f39116dbe0ed9753f26eba7d"}, + {file = "coverage-7.6.8-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b07c25d52b1c16ce5de088046cd2432b30f9ad5e224ff17c8f496d9cb7d1d451"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62a66ff235e4c2e37ed3b6104d8b478d767ff73838d1222132a7a026aa548764"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09b9f848b28081e7b975a3626e9081574a7b9196cde26604540582da60235fdf"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:093896e530c38c8e9c996901858ac63f3d4171268db2c9c8b373a228f459bbc5"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9a7b8ac36fd688c8361cbc7bf1cb5866977ece6e0b17c34aa0df58bda4fa18a4"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:38c51297b35b3ed91670e1e4efb702b790002e3245a28c76e627478aa3c10d83"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2e4e0f60cb4bd7396108823548e82fdab72d4d8a65e58e2c19bbbc2f1e2bfa4b"}, + {file = "coverage-7.6.8-cp313-cp313t-win32.whl", hash = "sha256:6535d996f6537ecb298b4e287a855f37deaf64ff007162ec0afb9ab8ba3b8b71"}, + {file = "coverage-7.6.8-cp313-cp313t-win_amd64.whl", hash = "sha256:c79c0685f142ca53256722a384540832420dff4ab15fec1863d7e5bc8691bdcc"}, + {file = "coverage-7.6.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ac47fa29d8d41059ea3df65bd3ade92f97ee4910ed638e87075b8e8ce69599e"}, + {file = "coverage-7.6.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:24eda3a24a38157eee639ca9afe45eefa8d2420d49468819ac5f88b10de84f4c"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4c81ed2820b9023a9a90717020315e63b17b18c274a332e3b6437d7ff70abe0"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd55f8fc8fa494958772a2a7302b0354ab16e0b9272b3c3d83cdb5bec5bd1779"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f39e2f3530ed1626c66e7493be7a8423b023ca852aacdc91fb30162c350d2a92"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:716a78a342679cd1177bc8c2fe957e0ab91405bd43a17094324845200b2fddf4"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:177f01eeaa3aee4a5ffb0d1439c5952b53d5010f86e9d2667963e632e30082cc"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:912e95017ff51dc3d7b6e2be158dedc889d9a5cc3382445589ce554f1a34c0ea"}, + {file = "coverage-7.6.8-cp39-cp39-win32.whl", hash = "sha256:4db3ed6a907b555e57cc2e6f14dc3a4c2458cdad8919e40b5357ab9b6db6c43e"}, + {file = "coverage-7.6.8-cp39-cp39-win_amd64.whl", hash = "sha256:428ac484592f780e8cd7b6b14eb568f7c85460c92e2a37cb0c0e5186e1a0d076"}, + {file = "coverage-7.6.8-pp39.pp310-none-any.whl", hash = "sha256:5c52a036535d12590c32c49209e79cabaad9f9ad8aa4cbd875b68c4d67a9cbce"}, + {file = "coverage-7.6.8.tar.gz", hash = "sha256:8b2b8503edb06822c86d82fa64a4a5cb0760bb8f31f26e138ec743f422f37cfc"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1341,6 +1418,24 @@ pytest = ">=7.0.0,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-cov" +version = "6.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0"}, + {file = "pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35"}, +] + +[package.dependencies] +coverage = {version = ">=7.5", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] + [[package]] name = "pytest-socket" version = "0.7.0" @@ -1817,4 +1912,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "467fbef71f25acc67e0ea58075128fd0a1fa822f05f9c4ea3f953040aac42458" +content-hash = "3a6a2a86b0b2af7e6d6d947711f12afe71ee582c4b753359be920d27c047e958" diff --git a/libs/neo4j/pyproject.toml b/libs/neo4j/pyproject.toml index c141520..d1393a2 100644 --- a/libs/neo4j/pyproject.toml +++ b/libs/neo4j/pyproject.toml @@ -26,6 +26,7 @@ pytest-asyncio = "^0.23.2" pytest-socket = "^0.7.0" pytest-watcher = "^0.3.4" langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} +pytest-cov = "^6.0.0" [tool.poetry.group.codespell] optional = true @@ -82,7 +83,7 @@ build-backend = "poetry.core.masonry.api" # section of the configuration file raise errors. # # https://github.com/tophat/syrupy -addopts = "--strict-markers --strict-config --durations=5" +addopts = "--strict-markers --strict-config --durations=5 --cov=langchain_neo4j --cov-report=html --cov-report=term" # Registering custom markers. # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ diff --git a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py index c6d2aed..82b0153 100644 --- a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py +++ b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py @@ -1,17 +1,23 @@ """Test Neo4j functionality.""" +from typing import Any, Optional, Type from unittest.mock import MagicMock, patch import pytest from langchain_neo4j.vectorstores.neo4j_vector import ( + LOGICAL_OPERATORS, IndexType, Neo4jVector, SearchType, _get_search_index_query, + _handle_field_filter, + check_if_not_null, + construct_metadata_filter, dict_to_yaml_str, remove_lucene_chars, ) +from langchain_neo4j.vectorstores.utils import DistanceStrategy @pytest.fixture @@ -44,6 +50,95 @@ def mock_vector_store() -> Neo4jVector: return vector_store +@pytest.fixture +def neo4j_vector_factory() -> Any: + def _create_vector_store( + method: Optional[str] = None, + texts: Optional[list[str]] = None, + text_embeddings: Optional[list[tuple[str, list[float]]]] = None, + query_return_value: Optional[dict] = None, + verify_connectivity_side_effect: Optional[Exception] = None, + auth_error_class: Type[Exception] = Exception, + service_unavailable_class: Type[Exception] = Exception, + search_type: SearchType = SearchType.VECTOR, + **kwargs: Any, + ) -> Any: + mock_neo4j = MagicMock() + mock_driver_instance = MagicMock() + + # Configure verify_connectivity + if verify_connectivity_side_effect: + mock_driver_instance.verify_connectivity.side_effect = ( + verify_connectivity_side_effect + ) + else: + mock_driver_instance.verify_connectivity.return_value = None + + # Configure execute_query + if query_return_value is not None: + mock_driver_instance.execute_query.return_value = ( + [MagicMock(data=lambda: query_return_value)], + None, + None, + ) + else: + mock_driver_instance.execute_query.return_value = ( + [ + MagicMock( + data=lambda: {"versions": ["5.23.0"], "edition": "enterprise"} + ) + ], + None, + None, + ) + + # Assign the mocked driver to GraphDatabase.driver + mock_neo4j.GraphDatabase.driver.return_value = mock_driver_instance + mock_neo4j.exceptions.ServiceUnavailable = service_unavailable_class + mock_neo4j.exceptions.AuthError = auth_error_class + + with patch.dict("sys.modules", {"neo4j": mock_neo4j}): + query_return = ( + [query_return_value] + if query_return_value + else [{"versions": ["5.23.0"], "edition": "enterprise"}] + ) + with patch.object(Neo4jVector, "query", return_value=query_return): + embedding = kwargs.pop("embedding", MagicMock()) + common_kwargs = { + "embedding": embedding, + "url": "bolt://localhost:7687", + "username": "neo4j", + "password": "password", + "search_type": search_type, + **kwargs, + } + + if texts and method == "from_texts": + vector_store = Neo4jVector.from_texts(texts=texts, **common_kwargs) + elif text_embeddings and method == "from_embeddings": + vector_store = Neo4jVector.from_embeddings( + text_embeddings=text_embeddings, **common_kwargs + ) + elif method == "from_existing_index": + vector_store = Neo4jVector.from_existing_index(**common_kwargs) + elif method == "from_existing_relationship_index": + vector_store = Neo4jVector.from_existing_relationship_index( + **common_kwargs + ) + elif method == "from_existing_graph": + vector_store = Neo4jVector.from_existing_graph(**common_kwargs) + else: + vector_store = Neo4jVector(**common_kwargs) + + vector_store.node_label = "Chunk" + vector_store.embedding_node_property = "embedding" + vector_store.text_node_property = "text" + return vector_store + + return _create_vector_store + + def test_escaping_lucene() -> None: """Test escaping lucene characters""" assert remove_lucene_chars("Hello+World") == "Hello World" @@ -217,3 +312,711 @@ def test_build_delete_query_version_below_5_23(mock_vector_store: Neo4jVector) - actual_query = mock_vector_store._build_delete_query() assert actual_query == expected_query + + +def test_get_search_index_query_invalid_search_type() -> None: + invalid_search_type = "INVALID_TYPE" + + with pytest.raises(ValueError) as exc_info: + _get_search_index_query( + search_type=invalid_search_type, # type: ignore + index_type=IndexType.NODE, + ) + + assert "Unsupported SearchType" in str(exc_info.value) + + +def test_check_if_not_null_happy_case() -> None: + props = ["prop1", "prop2", "prop3"] + values = ["value1", 123, True] + check_if_not_null(props, values) + + +def test_check_if_not_null_with_empty_string() -> None: + props = ["prop1", "prop2", "prop3"] + values = ["valid", "valid", ""] + + with pytest.raises(ValueError) as exc_info: + check_if_not_null(props, values) + + assert "must not be None or empty string" in str(exc_info.value) + + +def test_check_if_not_null_with_none_value() -> None: + props = ["prop1", "prop2", "prop3"] + values = ["valid", None, "valid"] + + with pytest.raises(ValueError) as exc_info: + check_if_not_null(props, values) + + assert "must not be None or empty string" in str(exc_info.value) + + +def test_handle_field_filter_invalid_field_type() -> None: + with pytest.raises(ValueError) as exc_info: + _handle_field_filter(field=123, value="some_value") # type: ignore + assert "field should be a string" in str(exc_info.value) + + +def test_handle_field_filter_field_starts_with_dollar() -> None: + with pytest.raises(ValueError) as exc_info: + _handle_field_filter(field="$invalid_field", value="some_value") + assert "Invalid filter condition" in str(exc_info.value) + + +def test_handle_field_filter_invalid_field_name() -> None: + with pytest.raises(ValueError) as exc_info: + _handle_field_filter(field="invalid-field!", value="some_value") + assert "Invalid field name" in str(exc_info.value) + + +def test_handle_field_filter_multiple_keys_in_filter() -> None: + with pytest.raises(ValueError) as exc_info: + _handle_field_filter(field="age", value={"$gt": 30, "$lt": 40}) + assert "Invalid filter condition" in str(exc_info.value) + + +def test_handle_field_filter_invalid_operator() -> None: + with pytest.raises(ValueError) as exc_info: + _handle_field_filter(field="age", value={"$unknown": 30}) + assert "Invalid operator" in str(exc_info.value) + + +@pytest.mark.parametrize("operator", LOGICAL_OPERATORS) +def test_handle_field_filter_logical_operators(operator: str) -> None: + with pytest.raises(NotImplementedError): + _handle_field_filter(field="age", value={operator: {"$gt": 30, "$lt": 40}}) + + +def test_handle_field_filter_nin_operator() -> None: + field = "description" + value = ["sandworm", "spice"] + string, params = _handle_field_filter(field, {"$nin": value}, param_number=1) + expected_string = "n.`description` NOT IN $param_1" + expected_params = {"param_1": value} + assert string == expected_string + assert params == expected_params + + +def test_handle_field_filter_like_operator() -> None: + field = "description" + value = "spice%" + string, params = _handle_field_filter(field, {"$like": value}, param_number=2) + expected_string = "n.`description` CONTAINS $param_2" + expected_params = {"param_2": "spice"} + assert string == expected_string + assert params == expected_params + + +def test_handle_field_filter_ilike_operator() -> None: + field = "description" + value = "spice%" + string, params = _handle_field_filter(field, {"$ilike": value}, param_number=3) + expected_string = "toLower(n.`description`) CONTAINS $param_3" + expected_params = {"param_3": "spice"} + assert string == expected_string + assert params == expected_params + + +def test_handle_field_filter_in_operator_with_unsupported_types() -> None: + field = "tags" + value = {"$in": ["spice", {"unsupported": "type"}]} + with pytest.raises(NotImplementedError) as exc_info: + _handle_field_filter(field, value, param_number=1) + assert "Unsupported type" in str(exc_info.value) + + +def test_handle_field_filter_nin_operator_with_unsupported_types() -> None: + field = "tags" + value = {"$nin": ["spice", {"unsupported": "type"}]} + with pytest.raises(NotImplementedError) as exc_info: + _handle_field_filter(field, value, param_number=2) + assert "Unsupported type" in str(exc_info.value) + + +def test_construct_metadata_filter_invalid_top_level_operator() -> None: + filter_dict = {"$invalid": "value"} + with pytest.raises(ValueError) as exc_info: + construct_metadata_filter(filter_dict) + assert "Expected $and or $or but got: $invalid" in str(exc_info.value) + + +def test_construct_metadata_filter_logical_operator_with_non_list() -> None: + filter_dict = {"$and": {"id": 1}} + with pytest.raises(ValueError) as exc_info: + construct_metadata_filter(filter_dict) + assert "Expected a list" in str(exc_info.value) + + +def test_construct_metadata_filter_logical_operator_with_empty_list_and_operator() -> ( + None +): + filter_dict: dict = {"$and": []} + with pytest.raises(ValueError) as exc_info: + construct_metadata_filter(filter_dict) + assert ( + "Invalid filter condition. Expected a dictionary but got an empty dictionary" + in str(exc_info.value) + ) + + +def test_construct_metadata_filter_logical_operator_with_empty_list_or_operator() -> ( + None +): + filter_dict: dict = {"$or": []} + with pytest.raises(ValueError) as exc_info: + construct_metadata_filter(filter_dict) + assert ( + "Invalid filter condition. Expected a dictionary but got an empty dictionary" + in str(exc_info.value) + ) + + +def test_construct_metadata_filter_multiple_keys_with_operator() -> None: + filter_dict = {"id": 1, "$and": [{"name": "foo"}]} + with pytest.raises(ValueError) as exc_info: + construct_metadata_filter(filter_dict) + assert "Expected a field but got: $and" in str(exc_info.value) + + +def test_construct_metadata_filter_empty_filter() -> None: + filter_dict: dict = {} + with pytest.raises(ValueError) as exc_info: + construct_metadata_filter(filter_dict) + assert "Got an empty dictionary for filters." in str(exc_info.value) + + +def test_construct_metadata_filter_happy_case() -> None: + filter_dict = {"height": {"$gt": 5.0}} + string, params = construct_metadata_filter(filter_dict) + + expected_string = "n.`height` > $param_1" + expected_params = {"param_1": 5.0} + + assert string == expected_string + assert params == expected_params + + +def test_construct_metadata_filter_logical_operator_empty_collect_params() -> None: + filter_dict = {"id": 1, "name": "foo"} + with patch( + "langchain_neo4j.vectorstores.neo4j_vector.collect_params", + return_value=([], {}), + ): + with pytest.raises(ValueError) as exc_info: + construct_metadata_filter(filter_dict) + assert ( + "Invalid filter condition. Expected a dictionary but got an empty dictionary" + in str(exc_info.value) + ) + + +def test_neo4jvector_import_error() -> None: + with patch.dict("sys.modules", {"neo4j": None}): + with pytest.raises(ImportError) as exc_info: + Neo4jVector( + embedding=MagicMock(), + url="bolt://localhost:7687", + username="neo4j", + password="password", + ) + assert ( + "Could not import neo4j python package. Please install it with " + "`pip install neo4j`." in str(exc_info.value) + ) + + +def test_neo4jvector_invalid_distance_strategy() -> None: + with pytest.raises(ValueError) as exc_info: + Neo4jVector( + embedding=MagicMock(), + url="bolt://localhost:7687", + username="neo4j", + password="password", + distance_strategy="INVALID_STRATEGY", # type: ignore + ) + assert "distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'" in str( + exc_info.value + ) + + +def test_neo4jvector_service_unavailable() -> None: + mock_neo4j = MagicMock() + mock_neo4j.exceptions.ServiceUnavailable = Exception + + mock_driver_instance = MagicMock() + mock_driver_instance.verify_connectivity.side_effect = ( + mock_neo4j.exceptions.ServiceUnavailable + ) + mock_neo4j.GraphDatabase.driver.return_value = mock_driver_instance + + with patch.dict("sys.modules", {"neo4j": mock_neo4j}): + with pytest.raises(ValueError) as exc_info: + Neo4jVector( + embedding=MagicMock(), + url="bolt://invalid_host:7687", + username="neo4j", + password="password", + ) + assert ( + "Could not connect to Neo4j database. Please ensure that the url is correct" + in str(exc_info.value) + ) + + +def test_neo4jvector_auth_error(neo4j_vector_factory: Any) -> None: + class MockAuthError(Exception): + pass + + class MockServiceUnavailable(Exception): + pass + + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + verify_connectivity_side_effect=MockAuthError("Authentication Failed"), + auth_error_class=MockAuthError, + service_unavailable_class=MockServiceUnavailable, + ) + + assert ( + "Could not connect to Neo4j database. Please ensure that the username " + "and password are correct" in str(exc_info.value) + ) + + +def test_neo4jvector_version_with_aura(neo4j_vector_factory: Any) -> None: + aura_version_response = {"versions": ["5.11.0-aura"], "edition": "enterprise"} + vector_store = neo4j_vector_factory(query_return_value=aura_version_response) + assert not vector_store.neo4j_version_is_5_23_or_above + + +def test_neo4jvector_version_too_low(neo4j_vector_factory: Any) -> None: + low_version_response = {"versions": ["5.10.0"], "edition": "enterprise"} + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory(query_return_value=low_version_response) + assert "Version index is only supported in Neo4j version 5.11 or greater" in str( + exc_info.value + ) + + +def test_neo4jvector_metadata_filter_version(neo4j_vector_factory: Any) -> None: + version_response = {"versions": ["5.17.0"], "edition": "enterprise"} + vector_store = neo4j_vector_factory(query_return_value=version_response) + assert vector_store.support_metadata_filter is False + + +def test_neo4jvector_relationship_index_error(neo4j_vector_factory: Any) -> None: + texts = ["text1", "text2"] + + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(None, "RELATIONSHIP") + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_texts", texts=texts, search_type=SearchType.VECTOR + ) + assert "Data ingestion is not supported with relationship vector index." in str( + exc_info.value + ) + + +def test_neo4jvector_embedding_dimension_mismatch(neo4j_vector_factory: Any) -> None: + texts = ["text1", "text2"] + + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * 64 + + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(128, "NODE") + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_texts", + texts=texts, + embedding=mock_embedding, + search_type=SearchType.VECTOR, + ) + assert ( + "The provided embedding function and vector index dimensions do not match." + in str(exc_info.value) + ) + + +def test_neo4jvector_fts_vector_node_label_mismatch(neo4j_vector_factory: Any) -> None: + texts = ["text1", "text2"] + embedding_dimension = 64 + + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * embedding_dimension + + with patch.object( + Neo4jVector, + "retrieve_existing_index", + return_value=(embedding_dimension, "NODE"), + ), patch.object( + Neo4jVector, "retrieve_existing_fts_index", return_value="DifferentNodeLabel" + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_texts", + texts=texts, + embedding=mock_embedding, + search_type=SearchType.HYBRID, + node_label="TestLabel", + keyword_index_name="keyword_index", + ) + assert "Vector and keyword index don't index the same node label" in str( + exc_info.value + ) + + +def test_similarity_search_by_vector_metadata_filter_unsupported( + neo4j_vector_factory: Any, +) -> None: + """ + Test that similarity_search_by_vector raises ValueError when metadata + filtering is unsupported. + """ + vector_store = neo4j_vector_factory() + vector_store.support_metadata_filter = False + vector_store.search_type = SearchType.VECTOR + vector_store.embedding_dimension = 64 + + with pytest.raises(ValueError) as exc_info: + vector_store.similarity_search_by_vector( + embedding=[0] * 64, + filter={"field": "value"}, + ) + assert ( + "Metadata filtering is only supported in Neo4j version 5.18 or greater" + in str(exc_info.value) + ) + + +def test_similarity_search_by_vector_metadata_filter_hybrid( + neo4j_vector_factory: Any, +) -> None: + vector_store = neo4j_vector_factory() + + vector_store.support_metadata_filter = True + vector_store.search_type = SearchType.HYBRID + vector_store.embedding_dimension = 64 + + with pytest.raises(ValueError) as exc_info: + vector_store.similarity_search_by_vector( + embedding=[0] * 64, + filter={"field": "value"}, + ) + assert ( + "Metadata filtering can't be use in combination with a hybrid search approach" + in str(exc_info.value) + ) + + +def test_from_existing_index_relationship_index_error( + neo4j_vector_factory: Any, +) -> None: + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(64, "RELATIONSHIP") + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_index", + index_name="test_index", + search_type=SearchType.VECTOR, + ) + assert ( + "Relationship vector index is not supported with `from_existing_index` " + "method." in str(exc_info.value) + ) + + +def test_from_existing_index_index_not_found(neo4j_vector_factory: Any) -> None: + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(None, None) + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_index", + embedding=MagicMock(), + index_name="non_existent_index", + ) + assert "The specified vector index name does not exist." in str(exc_info.value) + + +def test_from_existing_index_fts_vector_node_label_mismatch( + neo4j_vector_factory: Any, +) -> None: + embedding_dimension = 64 + + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * embedding_dimension + + with patch.object( + Neo4jVector, + "retrieve_existing_index", + return_value=(embedding_dimension, "NODE"), + ), patch.object( + Neo4jVector, "retrieve_existing_fts_index", return_value="DifferentNodeLabel" + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_index", + embedding=mock_embedding, + index_name="test_index", + search_type=SearchType.HYBRID, + keyword_index_name="keyword_index", + ) + + assert "Vector and keyword index don't index the same node label" in str( + exc_info.value + ) + + +def test_from_existing_relationship_index_hybrid_not_supported() -> None: + with pytest.raises(ValueError) as exc_info: + Neo4jVector.from_existing_relationship_index( + embedding=MagicMock(), + index_name="test_index", + search_type=SearchType.HYBRID, + ) + assert ( + "Hybrid search is not supported in combination with relationship vector index" + in str(exc_info.value) + ) + + +def test_from_existing_relationship_index_index_not_found( + neo4j_vector_factory: Any, +) -> None: + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(None, None) + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_relationship_index", + index_name="non_existent_index", + ) + assert "The specified vector index name does not exist" in str(exc_info.value) + + +def test_from_existing_relationship_index_node_index_error() -> None: + with patch.object(Neo4jVector, "__init__", return_value=None): + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(64, "NODE") + ): + with pytest.raises(ValueError) as exc_info: + Neo4jVector.from_existing_relationship_index( + embedding=MagicMock(), + index_name="test_index", + ) + assert ( + "Node vector index is not supported with " + "`from_existing_relationship_index` method" in str(exc_info.value) + ) + + +def test_from_existing_relationship_index_embedding_dimension_mismatch( + neo4j_vector_factory: Any, +) -> None: + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * 64 + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(128, "RELATIONSHIP") + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_relationship_index", + embedding=mock_embedding, + index_name="test_index", + search_type=SearchType.VECTOR, + ) + + assert ( + "The provided embedding function and vector index dimensions do not match" + in str(exc_info.value) + ) + + +def test_from_existing_graph_empty_text_node_properties() -> None: + with pytest.raises(ValueError) as exc_info: + Neo4jVector.from_existing_graph( + embedding=MagicMock(), + node_label="TestLabel", + embedding_node_property="embedding", + text_node_properties=[], + ) + assert "Parameter `text_node_properties` must not be an empty list" in str( + exc_info.value + ) + + +def test_from_existing_graph_relationship_index_error( + neo4j_vector_factory: Any, +) -> None: + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(64, "RELATIONSHIP") + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_graph", + embedding=MagicMock(), + node_label="TestLabel", + embedding_node_property="embedding", + text_node_properties=["text_property"], + search_type=SearchType.HYBRID, + keyword_index_name="keyword_index", + ) + + assert ( + "`from_existing_graph` method does not support existing relationship " + "vector index. Please use `from_existing_relationship_index` method" + in str(exc_info.value) + ) + + +def test_from_existing_graph_embedding_dimension_mismatch( + neo4j_vector_factory: Any, +) -> None: + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * 64 + + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(128, "NODE") + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_graph", + embedding=mock_embedding, + node_label="TestLabel", + embedding_node_property="embedding", + text_node_properties=["text_property"], + search_type=SearchType.VECTOR, + ) + + assert ( + "The provided embedding function and vector index dimensions do not match" + in str(exc_info.value) + ) + + +def test_from_existing_graph_fts_vector_node_label_mismatch( + neo4j_vector_factory: Any, +) -> None: + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * 64 + with patch.object( + Neo4jVector, "retrieve_existing_index", return_value=(64, "NODE") + ), patch.object( + Neo4jVector, "retrieve_existing_fts_index", return_value="DifferentNodeLabel" + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_graph", + embedding=mock_embedding, + node_label="TestLabel", + embedding_node_property="embedding", + text_node_properties=["text_property"], + search_type=SearchType.HYBRID, + keyword_index_name="keyword_index", + ) + + assert "Vector and keyword index don't index the same node label" in str( + exc_info.value + ) + + +def test_select_relevance_score_fn_override(neo4j_vector_factory: Any) -> None: + def override_fn(x: int) -> int: + return x * 2 + + vector_store = neo4j_vector_factory( + embedding=MagicMock(), + search_type=SearchType.VECTOR, + relevance_score_fn=override_fn, + ) + fn = vector_store._select_relevance_score_fn() + + assert fn(2) == 4 + + +def test_select_relevance_score_fn_invalid_distance_strategy( + neo4j_vector_factory: Any, +) -> None: + vector_store = neo4j_vector_factory( + embedding=MagicMock(), search_type=SearchType.VECTOR + ) + vector_store._distance_strategy = "INVALID_STRATEGY" + + with pytest.raises(ValueError) as exc_info: + vector_store._select_relevance_score_fn() + + assert ( + "No supported normalization function for distance_strategy of INVALID_STRATEGY" + in str(exc_info.value) + ) + + +def test_select_relevance_score_fn_euclidean_distance( + neo4j_vector_factory: Any, +) -> None: + vector_store = neo4j_vector_factory( + embedding=MagicMock(), distance_strategy=DistanceStrategy.EUCLIDEAN_DISTANCE + ) + + assert vector_store._distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE + + +def test_select_relevance_score_fn_cosine(neo4j_vector_factory: Any) -> None: + vector_store = neo4j_vector_factory( + embedding=MagicMock(), distance_strategy=DistanceStrategy.COSINE + ) + + assert vector_store._distance_strategy == DistanceStrategy.COSINE + + +def test_from_existing_index_keyword_index_not_exist(neo4j_vector_factory: Any) -> None: + mock_embedding = MagicMock() + mock_embedding.embed_query.return_value = [0.1] * 64 + + with ( + patch.object(Neo4jVector, "retrieve_existing_index", return_value=(64, "NODE")), + patch.object(Neo4jVector, "retrieve_existing_fts_index", return_value=None), + ): + with pytest.raises(ValueError) as exc_info: + neo4j_vector_factory( + method="from_existing_index", + embedding=mock_embedding, + index_name="vector_index", + search_type=SearchType.HYBRID, + keyword_index_name="nonexistent_keyword_index", + ) + expected_message = ( + "The specified keyword index name does not exist. " + "Make sure to check if you spelled it correctly" + ) + assert expected_message in str(exc_info.value) + + +def test_select_relevance_score_fn_unsupported_strategy( + neo4j_vector_factory: Any, +) -> None: + vector_store = neo4j_vector_factory( + embedding=MagicMock(), distance_strategy=DistanceStrategy.COSINE + ) + + vector_store._distance_strategy = "UNSUPPORTED_STRATEGY" + + with pytest.raises(ValueError) as exc_info: + vector_store._select_relevance_score_fn() + + expected_message = ( + "No supported normalization function for distance_strategy " + "of UNSUPPORTED_STRATEGY." + "Consider providing relevance_score_fn to PGVector constructor." + ) + + assert expected_message in str(exc_info.value), ( + f"Expected error message to contain '{expected_message}' " + f"but got '{str(exc_info.value)}'" + )