Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include variable scope clause in deprecated Cypher query #6

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ __pycache__
.mypy_cache_test
.env
.venv*
.idea
49 changes: 34 additions & 15 deletions libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,34 +84,39 @@ class IndexType(str, enum.Enum):


def _get_search_index_query(
search_type: SearchType, index_type: IndexType = DEFAULT_INDEX_TYPE
search_type: SearchType,
index_type: IndexType = DEFAULT_INDEX_TYPE,
neo4j_version_is_5_23_or_above: bool = False,
) -> str:
if index_type == IndexType.NODE:
type_to_query_map = {
SearchType.VECTOR: (
if search_type == SearchType.VECTOR:
return (
"CALL db.index.vector.queryNodes($index, $k, $embedding) "
"YIELD node, score "
),
SearchType.HYBRID: (
"CALL { "
)
elif search_type == SearchType.HYBRID:
call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { "

query_body = (
"CALL db.index.vector.queryNodes($index, $k, $embedding) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
# We use 0 as min
"RETURN n.node AS node, (n.score / max) AS score UNION "
"CALL db.index.fulltext.queryNodes($keyword_index, $query, "
"{limit: $k}) YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS max "
"UNWIND nodes AS n "
# We use 0 as min
"RETURN n.node AS node, (n.score / max) AS score "
"} "
# dedup
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
),
}
return type_to_query_map[search_type]
)

call_suffix = (
"} WITH node, max(score) AS score ORDER BY score DESC LIMIT $k "
)

return call_prefix + query_body + call_suffix
else:
raise ValueError(f"Unsupported SearchType: {search_type}")
else:
return (
"CALL db.index.vector.queryRelationships($index, $k, $embedding) "
Expand Down Expand Up @@ -666,6 +671,10 @@ def verify_version(self) -> None:
else:
version_tuple = tuple(map(int, version.split(".")))

self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above(
version_tuple
)

target_version = (5, 11, 0)

if version_tuple < target_version:
Expand All @@ -682,6 +691,14 @@ def verify_version(self) -> None:
# Flag for enterprise
self._is_enterprise = True if db_data[0]["edition"] == "enterprise" else False

def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool:
"""
Check if the connected Neo4j database version supports the required features.

Sets a flag if the connected Neo4j version is 5.23 or above.
"""
return version_tuple >= (5, 23, 0)

def retrieve_existing_index(self) -> Tuple[Optional[int], Optional[str]]:
"""
Check if the vector index exists in the Neo4j database
Expand Down Expand Up @@ -1064,7 +1081,9 @@ def similarity_search_with_score_by_vector(
index_query = base_index_query + filter_snippets + base_cosine_query

else:
index_query = _get_search_index_query(self.search_type, self._index_type)
index_query = _get_search_index_query(
self.search_type, self._index_type, self.neo4j_version_is_5_23_or_above
)
filter_params = {}

if self._index_type == IndexType.RELATIONSHIP:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
texts = ["foo", "bar", "baz", "It is the end of the world. Take shelter!"]

"""
cd tests/integration_tests/vectorstores/docker-compose
cd tests/integration_tests/docker-compose
docker-compose -f neo4j.yml up
"""

Expand Down
47 changes: 47 additions & 0 deletions libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Test Neo4j functionality."""

from langchain_neo4j.vectorstores.neo4j_vector import (
IndexType,
SearchType,
_get_search_index_query,
dict_to_yaml_str,
remove_lucene_chars,
)
Expand Down Expand Up @@ -65,3 +68,47 @@ def test_converting_to_yaml() -> None:
)

assert yaml_str == expected_output


def test_get_search_index_query_hybrid_node_neo4j_5_23_above() -> None:
    """Hybrid node query built with the 5.23-or-above flag set opens with `CALL () {`."""
    # Fragments are joined without a separator; each one carries its own
    # trailing space where the Cypher text needs it.
    fragments = [
        "CALL () { ",
        "CALL db.index.vector.queryNodes($index, $k, $embedding) ",
        "YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score UNION ",
        "CALL db.index.fulltext.queryNodes($keyword_index, $query, ",
        "{limit: $k}) YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score ",
        "} ",
        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k ",
    ]
    wanted = "".join(fragments)

    generated = _get_search_index_query(SearchType.HYBRID, IndexType.NODE, True)

    assert generated == wanted


def test_get_search_index_query_hybrid_node_neo4j_5_23_below() -> None:
    """Hybrid node query built with the 5.23-or-above flag unset opens with `CALL {`."""
    # Fragments are joined without a separator; each one carries its own
    # trailing space where the Cypher text needs it.
    fragments = [
        "CALL { ",
        "CALL db.index.vector.queryNodes($index, $k, $embedding) ",
        "YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score UNION ",
        "CALL db.index.fulltext.queryNodes($keyword_index, $query, ",
        "{limit: $k}) YIELD node, score ",
        "WITH collect({node:node, score:score}) AS nodes, max(score) AS max ",
        "UNWIND nodes AS n ",
        "RETURN n.node AS node, (n.score / max) AS score ",
        "} ",
        "WITH node, max(score) AS score ORDER BY score DESC LIMIT $k ",
    ]
    wanted = "".join(fragments)

    generated = _get_search_index_query(SearchType.HYBRID, IndexType.NODE, False)

    assert generated == wanted
Loading