Skip to content

Commit

Permalink
Neo4jGraph 100% coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Dec 12, 2024
1 parent 9912137 commit 551b03f
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 10 deletions.
2 changes: 1 addition & 1 deletion libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
formatted_rel_props.append(f"- **{rel_type}**")
for prop in properties:
example = ""
if prop["type"] == "STRING":
if prop["type"] == "STRING" and prop.get("values"):
if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT:
example = (
f'Example: "{clean_string_values(prop["values"][0])}"'
Expand Down
41 changes: 36 additions & 5 deletions libs/neo4j/tests/integration_tests/graphs/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
source=Node(id="foo", type="foo"),
target=Node(id="bar", type="bar"),
type="REL",
properties={"key": "val"},
)
],
source=Document(page_content="source document"),
Expand Down Expand Up @@ -349,7 +350,16 @@ def test_enhanced_schema() -> None:
}
],
},
"rel_props": {},
"rel_props": {
"REL": [
{
"distinct_count": 1,
"property": "key",
"type": "STRING",
"values": ["val"],
}
]
},
"relationships": [{"start": "foo", "type": "REL", "end": "bar"}],
}
# remove metadata portion of schema
Expand All @@ -367,16 +377,37 @@ def test_enhanced_schema_exception() -> None:
assert password is not None

graph = Neo4jGraph(
url=url, username=username, password=password, enhanced_schema=True
url=url,
username=username,
password=password,
enhanced_schema=True,
refresh_schema=False,
)
graph.query("MATCH (n) DETACH DELETE n")
graph.query("CREATE (:Node {foo:'bar'})," "(:Node {foo: 1}), (:Node {foo: [1,2]})")
graph.query(
"CREATE (:Node {foo: 'bar'}), (:Node {foo: 1}), (:Node {foo: [1,2]}), "
"(: EmptyNode)"
)
graph.query(
"MATCH (a:Node {foo: 'bar'}), (b:Node {foo: 1}), "
"(c:Node {foo: [1,2]}), (d: EmptyNode) "
"CREATE (a)-[:REL {foo: 'bar'}]->(b), (b)-[:REL {foo: 1}]->(c), "
"(c)-[:REL {foo: [1,2]}]->(a), (d)-[:EMPTY_REL {}]->(d)"
)
graph.refresh_schema()
expected_output = {
"node_props": {"Node": [{"property": "foo", "type": "STRING"}]},
"rel_props": {},
"relationships": [],
"rel_props": {"REL": [{"property": "foo", "type": "STRING"}]},
"relationships": [
{
"end": "Node",
"start": "Node",
"type": "REL",
},
{"end": "EmptyNode", "start": "EmptyNode", "type": "EMPTY_REL"},
],
}

# remove metadata portion of schema
del graph.structured_schema["metadata"]
assert graph.structured_schema == expected_output
Expand Down
151 changes: 147 additions & 4 deletions libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from types import ModuleType
from typing import Generator, Mapping, Sequence, Union
from typing import Any, Dict, Generator, Mapping, Sequence, Union
from unittest.mock import MagicMock, patch

import pytest
from neo4j.exceptions import ClientError, Neo4jError

from langchain_neo4j.graphs.neo4j_graph import (
LIST_LIMIT,
Expand Down Expand Up @@ -71,7 +72,10 @@ def mock_neo4j_driver() -> Generator[MagicMock, None, None]:
),
],
)
def test_value_sanitize(description, input_value, expected_output):
def test_value_sanitize(
description: str, input_value: Dict[str, Any], expected_output: Any
) -> None:
"""Test the value_sanitize function."""
assert (
value_sanitize(input_value) == expected_output
), f"Failed test case: {description}"
Expand Down Expand Up @@ -188,6 +192,143 @@ def mock_import(
assert "Could not import neo4j python package." in str(exc_info.value)


def test_neo4j_graph_init_with_empty_credentials() -> None:
"""Test the __init__ method when no credentials have been provided."""
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
mock_driver_instance.verify_connectivity.return_value = None
Neo4jGraph(
url="bolt://localhost:7687", username="", password="", refresh_schema=False
)
mock_driver.assert_called_with("bolt://localhost:7687", auth=None)


def test_init_apoc_procedure_not_found(
mock_neo4j_driver: MagicMock,
) -> None:
"""Test an error is raised when APOC is not installed."""
with patch("langchain_neo4j.Neo4jGraph.refresh_schema") as mock_refresh_schema:
err = ClientError()
err.code = "Neo.ClientError.Procedure.ProcedureNotFound"
mock_refresh_schema.side_effect = err
with pytest.raises(ValueError) as exc_info:
Neo4jGraph(url="bolt://localhost:7687", username="", password="")
assert "Could not use APOC procedures." in str(exc_info.value)


def test_init_refresh_schema_other_err(
mock_neo4j_driver: MagicMock,
) -> None:
"""Test any other ClientErrors raised when calling refresh_schema in __init__ are
re-raised."""
with patch("langchain_neo4j.Neo4jGraph.refresh_schema") as mock_refresh_schema:
err = ClientError()
err.code = "other_error"
mock_refresh_schema.side_effect = err
with pytest.raises(ClientError) as exc_info:
Neo4jGraph(url="bolt://localhost:7687", username="", password="")
assert exc_info.value == err


def test_query_fallback_execution(mock_neo4j_driver: MagicMock) -> None:
"""Test the fallback to allow for implicit transactions in query."""
err = Neo4jError()
err.code = "Neo.DatabaseError.Statement.ExecutionFailed"
err.message = "in an implicit transaction"
mock_neo4j_driver.execute_query.side_effect = err
graph = Neo4jGraph(
url="bolt://localhost:7687",
username="neo4j",
password="password",
database="test_db",
sanitize=True,
)
mock_session = MagicMock()
mock_result = MagicMock()
mock_result.data.return_value = {
"key1": "value1",
"oversized_list": list(range(LIST_LIMIT + 1)),
}
mock_session.run.return_value = [mock_result]
mock_neo4j_driver.session.return_value.__enter__.return_value = mock_session
mock_neo4j_driver.session.return_value.__exit__.return_value = None
query = "MATCH (n) RETURN n;"
params = {"param1": "value1"}
json_data = graph.query(query, params)
mock_neo4j_driver.session.assert_called_with(database="test_db")
called_args, _ = mock_session.run.call_args
called_query = called_args[0]
assert called_query.text == query
assert called_query.timeout == graph.timeout
assert called_args[1] == params
assert json_data == [{"key1": "value1"}]


def test_refresh_schema_handles_client_error(mock_neo4j_driver: MagicMock) -> None:
"""Test refresh schema handles a client error which might arise due to a user
not having access to schema information"""

graph = Neo4jGraph(
url="bolt://localhost:7687",
username="neo4j",
password="password",
database="test_db",
)
node_properties = [
{
"output": {
"properties": [{"property": "property_a", "type": "STRING"}],
"labels": "LabelA",
}
}
]
relationships_properties = [
{
"output": {
"type": "REL_TYPE",
"properties": [{"property": "rel_prop", "type": "STRING"}],
}
}
]
relationships = [
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}},
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}},
]

# Mock the query method to raise ClientError for constraint and index queries
graph.query = MagicMock( # type: ignore[method-assign]
side_effect=[
node_properties,
relationships_properties,
relationships,
ClientError("Mock ClientError"),
]
)
graph.refresh_schema()

# Assertions
# Ensure constraints and indexes are empty due to the ClientError
assert graph.structured_schema["metadata"]["constraint"] == []
assert graph.structured_schema["metadata"]["index"] == []

# Ensure the query method was called as expected
assert graph.query.call_count == 4
graph.query.assert_any_call("SHOW CONSTRAINTS")


def test_get_schema(mock_neo4j_driver) -> None:
"""Tests the get_schema property."""
graph = Neo4jGraph(
url="bolt://localhost:7687",
username="neo4j",
password="password",
refresh_schema=False
)
graph.schema = "test"
assert graph.get_schema == "test"


@pytest.mark.parametrize(
"description, schema, is_enhanced, expected_output",
[
Expand Down Expand Up @@ -547,7 +688,9 @@ def mock_import(
),
],
)
def test_format_schema(description, schema, is_enhanced, expected_output):
def test_format_schema(
description: str, schema: Dict, is_enhanced: bool, expected_output: str
) -> None:
result = _format_schema(schema, is_enhanced)
assert result == expected_output, f"Failed test case: {description}"

Expand Down Expand Up @@ -685,7 +828,7 @@ def test_enhanced_schema_cypher_string_exhaustive_false_with_index(
]
}
}
graph.query = MagicMock(return_value=[{"value": ["Single", "Married", "Divorced"]}])
graph.query = MagicMock(return_value=[{"value": ["Single", "Married", "Divorced"]}]) # type: ignore[method-assign]
properties = [{"property": "status", "type": "STRING"}]
query = graph._enhanced_schema_cypher("Person", properties, exhaustive=False)
assert "values: ['Single', 'Married', 'Divorced'], distinct_count: 3" in query
Expand Down

0 comments on commit 551b03f

Please sign in to comment.