Skip to content

Commit

Permalink
Adds tests so that Neo4jGraph class now has 100% coverage (langchain-…
Browse files Browse the repository at this point in the history
…ai#21)

* Added tests for _format_schema helper function

* Added mock Neo4j driver to Neo4jGraph tests

* Added _enhanced_schema_cypher tests

* Parametrised _format_schema tests

* Parametrised value_sanitize tests

* test_format_schema refactoring

* More refactoring + minor tests added

* Neo4jGraph 100% coverage

* Fixed linting issues
  • Loading branch information
alexthomas93 authored Dec 13, 2024
1 parent 6b31b1f commit cfa583a
Show file tree
Hide file tree
Showing 3 changed files with 948 additions and 127 deletions.
10 changes: 5 additions & 5 deletions libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
"DATE_TIME",
"LOCAL_DATE_TIME",
]:
if prop.get("min") is not None:
if prop.get("min") and prop.get("max"):
example = f'Min: {prop["min"]}, Max: {prop["max"]}'
else:
example = (
Expand All @@ -215,7 +215,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
formatted_rel_props.append(f"- **{rel_type}**")
for prop in properties:
example = ""
if prop["type"] == "STRING":
if prop["type"] == "STRING" and prop.get("values"):
if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT:
example = (
f'Example: "{clean_string_values(prop["values"][0])}"'
Expand All @@ -238,8 +238,8 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
"DATE_TIME",
"LOCAL_DATE_TIME",
]:
if prop.get("min"): # If we have min/max
example = f'Min: {prop["min"]}, Max: {prop["max"]}'
if prop.get("min") and prop.get("max"): # If we have min/max
example = f'Min: {prop["min"]}, Max: {prop["max"]}'
else: # return a single value
example = (
f'Example: "{prop["values"][0]}"' if prop["values"] else ""
Expand All @@ -252,7 +252,7 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str:
f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
)
formatted_rel_props.append(
f" - `{prop['property']}: {prop['type']}` {example}"
f" - `{prop['property']}`: {prop['type']} {example}"
)
else:
# Format node properties
Expand Down
89 changes: 83 additions & 6 deletions libs/neo4j/tests/integration_tests/graphs/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import urllib

import pytest
from langchain_core.documents import Document

from langchain_neo4j import Neo4jGraph
Expand All @@ -19,6 +21,7 @@
source=Node(id="foo", type="foo"),
target=Node(id="bar", type="bar"),
type="REL",
properties={"key": "val"},
)
],
source=Document(page_content="source document"),
Expand Down Expand Up @@ -130,7 +133,7 @@ def test_neo4j_timeout() -> None:


def test_neo4j_sanitize_values() -> None:
"""Test that neo4j uses the timeout correctly."""
"""Test that lists with more than 128 elements are removed from the results."""
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")
Expand Down Expand Up @@ -347,7 +350,16 @@ def test_enhanced_schema() -> None:
}
],
},
"rel_props": {},
"rel_props": {
"REL": [
{
"distinct_count": 1,
"property": "key",
"type": "STRING",
"values": ["val"],
}
]
},
"relationships": [{"start": "foo", "type": "REL", "end": "bar"}],
}
# remove metadata portion of schema
Expand All @@ -365,16 +377,37 @@ def test_enhanced_schema_exception() -> None:
assert password is not None

graph = Neo4jGraph(
url=url, username=username, password=password, enhanced_schema=True
url=url,
username=username,
password=password,
enhanced_schema=True,
refresh_schema=False,
)
graph.query("MATCH (n) DETACH DELETE n")
graph.query("CREATE (:Node {foo:'bar'})," "(:Node {foo: 1}), (:Node {foo: [1,2]})")
graph.query(
"CREATE (:Node {foo: 'bar'}), (:Node {foo: 1}), (:Node {foo: [1,2]}), "
"(: EmptyNode)"
)
graph.query(
"MATCH (a:Node {foo: 'bar'}), (b:Node {foo: 1}), "
"(c:Node {foo: [1,2]}), (d: EmptyNode) "
"CREATE (a)-[:REL {foo: 'bar'}]->(b), (b)-[:REL {foo: 1}]->(c), "
"(c)-[:REL {foo: [1,2]}]->(a), (d)-[:EMPTY_REL {}]->(d)"
)
graph.refresh_schema()
expected_output = {
"node_props": {"Node": [{"property": "foo", "type": "STRING"}]},
"rel_props": {},
"relationships": [],
"rel_props": {"REL": [{"property": "foo", "type": "STRING"}]},
"relationships": [
{
"end": "Node",
"start": "Node",
"type": "REL",
},
{"end": "EmptyNode", "start": "EmptyNode", "type": "EMPTY_REL"},
],
}

# remove metadata portion of schema
del graph.structured_schema["metadata"]
assert graph.structured_schema == expected_output
Expand Down Expand Up @@ -558,3 +591,47 @@ def test_neo4j_multiple_close() -> None:
# Test that multiple closes don't raise errors
graph.close()
graph.close() # This should not raise an error


def test_invalid_url() -> None:
"""Test initializing with invalid credentials raises ValueError."""
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein")
assert url is not None
assert username is not None
assert password is not None

# Parse the original URL
parsed_url = urllib.parse.urlparse(url)
# Increment the port number by 1 and wrap around if necessary
original_port = parsed_url.port or 7687
new_port = (original_port + 1) % 65535 or 1
# Reconstruct the netloc (hostname:port)
new_netloc = f"{parsed_url.hostname}:{new_port}"
# Rebuild the URL with the new netloc
new_url = parsed_url._replace(netloc=new_netloc).geturl()

with pytest.raises(ValueError) as exc_info:
Neo4jGraph(
url=new_url,
username=username,
password=password,
)
assert "Please ensure that the url is correct" in str(exc_info.value)


def test_invalid_credentials() -> None:
"""Test initializing with invalid credentials raises ValueError."""
url = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
assert url is not None

with pytest.raises(ValueError) as exc_info:
Neo4jGraph(
url=url,
username="invalid_username",
password="invalid_password",
)
assert "Please ensure that the username and password are correct" in str(
exc_info.value
)
Loading

0 comments on commit cfa583a

Please sign in to comment.