diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index e84a4df..514870f 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -346,7 +346,7 @@ def from_llm( if validate_cypher: corrector_schema = [ Schema(el["start"], el["type"], el["end"]) - for el in kwargs["graph"].structured_schema.get("relationships") + for el in kwargs["graph"].get_structured_schema.get("relationships", []) ] cypher_query_corrector = CypherQueryCorrector(corrector_schema) diff --git a/libs/neo4j/tests/unit_tests/llms/__init__.py b/libs/neo4j/tests/integration_tests/chains/__init__.py similarity index 100% rename from libs/neo4j/tests/unit_tests/llms/__init__.py rename to libs/neo4j/tests/integration_tests/chains/__init__.py diff --git a/libs/neo4j/tests/integration_tests/chains/test_graph_database.py b/libs/neo4j/tests/integration_tests/chains/test_graph_database.py index 56ecc1b..eda5683 100644 --- a/libs/neo4j/tests/integration_tests/chains/test_graph_database.py +++ b/libs/neo4j/tests/integration_tests/chains/test_graph_database.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock from langchain_core.language_models import BaseLanguageModel -from langchain_core.language_models.fake import FakeListLLM from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph +from tests.llms.fake_llm import FakeLLM def test_connect_neo4j() -> None: @@ -71,10 +71,13 @@ def test_cypher_generating_run() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query, "Bruce Willis"]) + llm = FakeLLM( + queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True + ) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, + validate_cypher=True, allow_dangerous_requests=True, ) output = chain.run("Who starred in Pulp Fiction?") @@ -111,7 +114,7 @@ def test_cypher_top_k() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query]) + llm = FakeLLM(queries={"query": query}, sequential_responses=True) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -149,7 +152,9 @@ def test_cypher_intermediate_steps() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query, "Bruce Willis"]) + llm = FakeLLM( + queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True + ) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -194,7 +199,7 @@ def test_cypher_return_direct() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = FakeListLLM(responses=[query]) + llm = FakeLLM(queries={"query": query}, sequential_responses=True) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -206,6 +211,46 @@ def test_cypher_return_direct() -> None: assert output == expected_output +def test_function_response() -> None: + """Test returning a function response.""" + url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") + username = os.environ.get("NEO4J_USERNAME", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "pleaseletmein") + + graph = Neo4jGraph( + url=url, + username=username, + password=password, + ) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Create two nodes and a relationship + graph.query( + "CREATE (a:Actor {name:'Bruce Willis'})" + "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})" + ) + # Refresh schema information + graph.refresh_schema() + + query = ( + "MATCH (a:Actor)-[:ACTED_IN]->(m:Movie) " + "WHERE m.title = 'Pulp Fiction' " + "RETURN a.name" + ) + llm = FakeLLM( + queries={"query": query, "response": "Bruce Willis"}, sequential_responses=True + ) + chain = GraphCypherQAChain.from_llm( + llm=llm, + graph=graph, + allow_dangerous_requests=True, + use_function_response=True, + ) + output = chain.run("Who starred in Pulp Fiction?") + expected_output = "Bruce Willis" + assert output == expected_output + + def test_exclude_types() -> None: """Test exclude types from schema.""" url = os.environ.get("NEO4J_URI", "bolt://localhost:7687") diff --git a/libs/neo4j/tests/llms/__init__.py b/libs/neo4j/tests/llms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/neo4j/tests/unit_tests/llms/fake_llm.py b/libs/neo4j/tests/llms/fake_llm.py similarity index 100% rename from libs/neo4j/tests/unit_tests/llms/fake_llm.py rename to libs/neo4j/tests/llms/fake_llm.py diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py index a857fb4..34d1f38 100644 --- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py @@ -1,9 +1,16 @@ import pathlib from csv import DictReader from typing import Any, Dict, List +from unittest.mock import MagicMock, patch +import pytest from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory -from langchain_core.messages import SystemMessage +from langchain_core.language_models.llms import LLM +from langchain_core.messages import ( + AIMessage, + SystemMessage, + ToolMessage, +) from langchain_core.prompts import ( ChatPromptTemplate, HumanMessagePromptTemplate, @@ -15,6 +22,7 @@ GraphCypherQAChain, construct_schema, extract_cypher, + get_function_response, ) from langchain_neo4j.chains.graph_qa.cypher_utils import ( CypherQueryCorrector, @@ -26,7 +34,7 @@ ) from langchain_neo4j.graphs.graph_document import GraphDocument from langchain_neo4j.graphs.graph_store import GraphStore -from tests.unit_tests.llms.fake_llm import FakeLLM +from tests.llms.fake_llm import FakeLLM class FakeGraphStore(GraphStore): @@ -141,21 +149,34 @@ def test_graph_cypher_qa_chain_prompt_selection_5() -> None: readonlymemory = ReadOnlySharedMemory(memory=memory) qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[]) cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[]) - try: + with pytest.raises(ValueError) as exc_info: GraphCypherQAChain.from_llm( llm=FakeLLM(), graph=FakeGraphStore(), verbose=True, return_intermediate_steps=False, - qa_prompt=qa_prompt, cypher_prompt=cypher_prompt, cypher_llm_kwargs={"memory": readonlymemory}, + allow_dangerous_requests=True, + ) + assert ( + "Specifying cypher_prompt and cypher_llm_kwargs together is" + " not allowed. Please pass prompt via cypher_llm_kwargs." + ) == str(exc_info.value) + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + verbose=True, + return_intermediate_steps=False, + qa_prompt=qa_prompt, qa_llm_kwargs={"memory": readonlymemory}, allow_dangerous_requests=True, ) - assert False - except ValueError: - assert True + assert ( + "Specifying qa_prompt and qa_llm_kwargs together is" + " not allowed. Please pass prompt via qa_llm_kwargs." + ) == str(exc_info.value) def test_graph_cypher_qa_chain_prompt_selection_6() -> None: @@ -182,6 +203,53 @@ def test_graph_cypher_qa_chain_prompt_selection_6() -> None: assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT +def test_graph_cypher_qa_chain_prompt_selection_7() -> None: + # Pass prompts which do not inherit from BasePromptTemplate + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + cypher_llm_kwargs={"prompt": None}, + allow_dangerous_requests=True, + ) + assert "The cypher_llm_kwargs `prompt` must inherit from BasePromptTemplate" == str( + exc_info.value + ) + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + qa_llm_kwargs={"prompt": None}, + allow_dangerous_requests=True, + ) + assert "The qa_llm_kwargs `prompt` must inherit from BasePromptTemplate" == str( + exc_info.value + ) + + +def test_validate_cypher() -> None: + with patch( + "langchain_neo4j.chains.graph_qa.cypher.CypherQueryCorrector", + autospec=True, + ) as cypher_query_corrector_mock: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + validate_cypher=True, + allow_dangerous_requests=True, + ) + cypher_query_corrector_mock.assert_called_once_with([]) + + +def test_chain_type() -> None: + chain = GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + assert chain._chain_type == "graph_cypher_chain" + + def test_graph_cypher_qa_chain() -> None: template = """You are a nice chatbot having a conversation with a human. @@ -236,6 +304,19 @@ def test_graph_cypher_qa_chain() -> None: assert True +def test_cypher_generation_failure() -> None: + """Test the chain doesn't fail if the Cypher query fails to be generated.""" + llm = FakeLLM(queries={"query": ""}, sequential_responses=True) + chain = GraphCypherQAChain.from_llm( + llm=llm, + graph=FakeGraphStore(), + allow_dangerous_requests=True, + return_direct=True, + ) + response = chain.run("Test question") + assert response == [] + + def test_no_backticks() -> None: """Test if there are no backticks, so the original text should be returned.""" query = "MATCH (n) RETURN n" @@ -257,7 +338,7 @@ def test_exclude_types() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -268,7 +349,8 @@ def test_exclude_types() -> None: expected_schema = ( "Node properties are the following:\n" "Movie {title: STRING},Actor {name: STRING}\n" - "Relationship properties are the following:\n\n" + "Relationship properties are the following:\n" + "ACTED_IN {role: STRING}\n" "The relationships are the following:\n" "(:Actor)-[:ACTED_IN]->(:Movie)" ) @@ -282,7 +364,7 @@ def test_include_types() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -293,7 +375,8 @@ def test_include_types() -> None: expected_schema = ( "Node properties are the following:\n" "Movie {title: STRING},Actor {name: STRING}\n" - "Relationship properties are the following:\n\n" + "Relationship properties are the following:\n" + "ACTED_IN {role: STRING}\n" "The relationships are the following:\n" "(:Actor)-[:ACTED_IN]->(:Movie)" ) @@ -307,7 +390,7 @@ def test_include_types2() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -331,7 +414,7 @@ def test_include_types3() -> None: "Actor": [{"property": "name", "type": "STRING"}], "Person": [{"property": "name", "type": "STRING"}], }, - "rel_props": {}, + "rel_props": {"ACTED_IN": [{"property": "role", "type": "STRING"}]}, "relationships": [ {"start": "Actor", "end": "Movie", "type": "ACTED_IN"}, {"start": "Person", "end": "Movie", "type": "DIRECTED"}, @@ -342,13 +425,136 @@ def test_include_types3() -> None: expected_schema = ( "Node properties are the following:\n" "Movie {title: STRING},Actor {name: STRING}\n" - "Relationship properties are the following:\n\n" + "Relationship properties are the following:\n" + "ACTED_IN {role: STRING}\n" "The relationships are the following:\n" "(:Actor)-[:ACTED_IN]->(:Movie)" ) assert output == expected_schema +def test_include_exclude_types_err() -> None: + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + include_types=["Movie", "Actor"], + exclude_types=["Person", "DIRECTED"], + allow_dangerous_requests=True, + ) + assert ( + "Either `exclude_types` or `include_types` can be provided, but not both" + == str(exc_info.value) + ) + + +def test_get_function_response() -> None: + question = "Who directed Dune?" + context = [{"director": "Denis Villeneuve"}] + messages = get_function_response(question, context) + assert len(messages) == 2 + # Validate AIMessage + ai_message = messages[0] + assert isinstance(ai_message, AIMessage) + assert ai_message.content == "" + assert "tool_calls" in ai_message.additional_kwargs + tool_call = ai_message.additional_kwargs["tool_calls"][0] + assert tool_call["function"]["arguments"] == f'{{"question":"{question}"}}' + # Validate ToolMessage + tool_message = messages[1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == str(context) + + +def test_allow_dangerous_requests_err() -> None: + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + ) + assert ( + "In order to use this chain, you must acknowledge that it can make " + "dangerous requests by setting `allow_dangerous_requests` to `True`." + ) in str(exc_info.value) + + +def test_llm_arg_combinations() -> None: + # No llm + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + graph=FakeGraphStore(), allow_dangerous_requests=True + ) + assert "At least one LLM must be provided" == str(exc_info.value) + # llm only + GraphCypherQAChain.from_llm( + llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True + ) + # qa_llm only + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + qa_llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True + ) + assert ( + "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be provided." + == str(exc_info.value) + ) + # cypher_llm only + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + cypher_llm=FakeLLM(), graph=FakeGraphStore(), allow_dangerous_requests=True + ) + assert ( + "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be provided." + == str(exc_info.value) + ) + # llm + qa_llm + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + qa_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + # llm + cypher_llm + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + cypher_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + # qa_llm + cypher_llm + GraphCypherQAChain.from_llm( + qa_llm=FakeLLM(), + cypher_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + # llm + qa_llm + cypher_llm + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=FakeLLM(), + qa_llm=FakeLLM(), + cypher_llm=FakeLLM(), + graph=FakeGraphStore(), + allow_dangerous_requests=True, + ) + assert ( + "You can specify up to two of 'cypher_llm', 'qa_llm'" + ", and 'llm', but not all three simultaneously." + ) == str(exc_info.value) + + +def test_use_function_response_err() -> None: + llm = MagicMock(spec=LLM) + with pytest.raises(ValueError) as exc_info: + GraphCypherQAChain.from_llm( + llm=llm, + graph=FakeGraphStore(), + allow_dangerous_requests=True, + use_function_response=True, + ) + assert "Provided LLM does not support native tools/functions" == str(exc_info.value) + + HERE = pathlib.Path(__file__).parent UNIT_TESTS_ROOT = HERE.parent