diff --git a/CHANGELOG.md b/CHANGELOG.md index 03b629e..f5affc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,14 @@ ### Added -- Enhanced Neo4j driver connection management with more robust error handling -- Simplified connection state checking in Neo4jGraph +- Enhanced Neo4j driver connection management with more robust error handling. +- Simplified connection state checking in Neo4jGraph. - Introduced `effective_search_ratio` parameter in Neo4jVector to enhance query accuracy by adjusting the candidate pool size during similarity searches. +### Fixed + +- Removed deprecated LLMChain from GraphCypherQAChain to resolve instantiation issues with the use_function_response parameter. + ## 0.1.1 ### Changed diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index 48b0843..69ff000 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Union from langchain.chains.base import Chain -from langchain.chains.llm import LLMChain from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import ( @@ -165,8 +164,8 @@ class GraphCypherQAChain(Chain): """ graph: GraphStore = Field(exclude=True) - cypher_generation_chain: LLMChain - qa_chain: Union[LLMChain, Runnable] + cypher_generation_chain: Runnable + qa_chain: Runnable graph_schema: str input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -261,28 +260,34 @@ def from_llm( "You can specify up to two of 'cypher_llm', 'qa_llm'" ", and 'llm', but not all three simultaneously." ) - if cypher_prompt and cypher_llm_kwargs: - raise ValueError( - "Specifying cypher_prompt and cypher_llm_kwargs together is" - " not allowed. Please pass prompt via cypher_llm_kwargs." - ) - if qa_prompt and qa_llm_kwargs: - raise ValueError( - "Specifying qa_prompt and qa_llm_kwargs together is" - " not allowed. Please pass prompt via qa_llm_kwargs." - ) + if cypher_prompt: + if cypher_llm_kwargs: + raise ValueError( + "Specifying cypher_prompt and cypher_llm_kwargs together is" + " not allowed. Please pass prompt via cypher_llm_kwargs." + ) + else: + if cypher_llm_kwargs: + cypher_prompt = cypher_llm_kwargs.pop( + "prompt", CYPHER_GENERATION_PROMPT + ) + else: + cypher_prompt = CYPHER_GENERATION_PROMPT + if qa_prompt: + if qa_llm_kwargs: + raise ValueError( + "Specifying qa_prompt and qa_llm_kwargs together is" + " not allowed. Please pass prompt via qa_llm_kwargs." + ) + else: + if qa_llm_kwargs: + qa_prompt = qa_llm_kwargs.pop("prompt", CYPHER_QA_PROMPT) + else: + qa_prompt = CYPHER_QA_PROMPT use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {} use_cypher_llm_kwargs = ( cypher_llm_kwargs if cypher_llm_kwargs is not None else {} ) - if "prompt" not in use_qa_llm_kwargs: - use_qa_llm_kwargs["prompt"] = ( - qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT - ) - if "prompt" not in use_cypher_llm_kwargs: - use_cypher_llm_kwargs["prompt"] = ( - cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT - ) qa_llm = qa_llm or llm if use_function_response: @@ -299,11 +304,11 @@ def from_llm( except (NotImplementedError, AttributeError): raise ValueError("Provided LLM does not support native tools/functions") else: - qa_chain = LLMChain(llm=qa_llm, **use_qa_llm_kwargs) # type: ignore[arg-type] + qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() # type: ignore - cypher_generation_chain = LLMChain( - llm=cypher_llm or llm, # type: ignore[arg-type] - **use_cypher_llm_kwargs, # type: ignore[arg-type] + cypher_llm = cypher_llm or llm + cypher_generation_chain = ( + cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() # type: ignore ) if exclude_types and include_types: @@ -349,7 +354,9 @@ def _call( intermediate_steps: List = [] - generated_cypher = self.cypher_generation_chain.run(args, callbacks=callbacks) + generated_cypher = self.cypher_generation_chain.invoke( + args, callbacks=callbacks + ) # Extract Cypher code if it is wrapped in backticks generated_cypher = extract_cypher(generated_cypher) @@ -391,7 +398,7 @@ def _call( {"question": question, "context": context}, callbacks=callbacks, ) - final_result = result[self.qa_chain.output_key] # type: ignore + final_result = result # type: ignore chain_result: Dict[str, Any] = {self.output_key: final_result} if self.return_intermediate_steps: diff --git a/libs/neo4j/tests/integration_tests/chains/test_graph_database.py b/libs/neo4j/tests/integration_tests/chains/test_graph_database.py index 7fa0b62..56ecc1b 100644 --- a/libs/neo4j/tests/integration_tests/chains/test_graph_database.py +++ b/libs/neo4j/tests/integration_tests/chains/test_graph_database.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from langchain_core.language_models import BaseLanguageModel -from langchain_core.outputs import Generation, LLMResult +from langchain_core.language_models.fake import FakeListLLM from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph @@ -71,11 +71,7 @@ def test_cypher_generating_run() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = MagicMock(spec=BaseLanguageModel) - llm.generate_prompt.side_effect = [ - LLMResult(generations=[[Generation(text=query)]]), - LLMResult(generations=[[Generation(text="Bruce Willis")]]), - ] + llm = FakeListLLM(responses=[query, "Bruce Willis"]) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -115,10 +111,7 @@ def test_cypher_top_k() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = MagicMock(spec=BaseLanguageModel) - llm.generate_prompt.side_effect = [ - LLMResult(generations=[[Generation(text=query)]]) - ] + llm = FakeListLLM(responses=[query]) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -156,11 +149,7 @@ def test_cypher_intermediate_steps() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = MagicMock(spec=BaseLanguageModel) - llm.generate_prompt.side_effect = [ - LLMResult(generations=[[Generation(text=query)]]), - LLMResult(generations=[[Generation(text="Bruce Willis")]]), - ] + llm = FakeListLLM(responses=[query, "Bruce Willis"]) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, @@ -205,10 +194,7 @@ def test_cypher_return_direct() -> None: "WHERE m.title = 'Pulp Fiction' " "RETURN a.name" ) - llm = MagicMock(spec=BaseLanguageModel) - llm.generate_prompt.side_effect = [ - LLMResult(generations=[[Generation(text=query)]]) - ] + llm = FakeListLLM(responses=[query]) chain = GraphCypherQAChain.from_llm( llm=llm, graph=graph, diff --git a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py index c6c08af..a857fb4 100644 --- a/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/neo4j/tests/unit_tests/chains/test_graph_qa.py @@ -3,7 +3,13 @@ from typing import Any, Dict, List from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory -from langchain_core.prompts import PromptTemplate +from langchain_core.messages import SystemMessage +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, + PromptTemplate, +) from langchain_neo4j.chains.graph_qa.cypher import ( GraphCypherQAChain, @@ -64,8 +70,10 @@ def test_graph_cypher_qa_chain_prompt_selection_1() -> None: cypher_prompt=cypher_prompt, allow_dangerous_requests=True, ) - assert chain.qa_chain.prompt == qa_prompt # type: ignore[union-attr] - assert chain.cypher_generation_chain.prompt == cypher_prompt + assert hasattr(chain.qa_chain, "first") + assert chain.qa_chain.first == qa_prompt + assert hasattr(chain.cypher_generation_chain, "first") + assert chain.cypher_generation_chain.first == cypher_prompt def test_graph_cypher_qa_chain_prompt_selection_2() -> None: @@ -77,8 +85,10 @@ def test_graph_cypher_qa_chain_prompt_selection_2() -> None: return_intermediate_steps=False, allow_dangerous_requests=True, ) - assert chain.qa_chain.prompt == CYPHER_QA_PROMPT # type: ignore[union-attr] - assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT + assert hasattr(chain.qa_chain, "first") + assert chain.qa_chain.first == CYPHER_QA_PROMPT + assert hasattr(chain.cypher_generation_chain, "first") + assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT def test_graph_cypher_qa_chain_prompt_selection_3() -> None: @@ -94,8 +104,10 @@ def test_graph_cypher_qa_chain_prompt_selection_3() -> None: qa_llm_kwargs={"memory": readonlymemory}, allow_dangerous_requests=True, ) - assert chain.qa_chain.prompt == CYPHER_QA_PROMPT # type: ignore[union-attr] - assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT + assert hasattr(chain.qa_chain, "first") + assert chain.qa_chain.first == CYPHER_QA_PROMPT + assert hasattr(chain.cypher_generation_chain, "first") + assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT def test_graph_cypher_qa_chain_prompt_selection_4() -> None: @@ -115,8 +127,10 @@ def test_graph_cypher_qa_chain_prompt_selection_4() -> None: qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory}, allow_dangerous_requests=True, ) - assert chain.qa_chain.prompt == qa_prompt # type: ignore[union-attr] - assert chain.cypher_generation_chain.prompt == cypher_prompt + assert hasattr(chain.qa_chain, "first") + assert chain.qa_chain.first == qa_prompt + assert hasattr(chain.cypher_generation_chain, "first") + assert chain.cypher_generation_chain.first == cypher_prompt def test_graph_cypher_qa_chain_prompt_selection_5() -> None: @@ -144,6 +158,30 @@ def test_graph_cypher_qa_chain_prompt_selection_5() -> None: assert True +def test_graph_cypher_qa_chain_prompt_selection_6() -> None: + # Test function response prompt + function_response_system = "Respond as a pirate!" + response_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage(content=function_response_system), + HumanMessagePromptTemplate.from_template("{question}"), + MessagesPlaceholder(variable_name="function_response"), + ] + ) + chain = GraphCypherQAChain.from_llm( + llm=FakeLLM(), + graph=FakeGraphStore(), + verbose=True, + use_function_response=True, + function_response_system=function_response_system, + allow_dangerous_requests=True, + ) + assert hasattr(chain.qa_chain, "first") + assert chain.qa_chain.first == response_prompt + assert hasattr(chain.cypher_generation_chain, "first") + assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT + + def test_graph_cypher_qa_chain() -> None: template = """You are a nice chatbot having a conversation with a human. diff --git a/libs/neo4j/tests/unit_tests/llms/fake_llm.py b/libs/neo4j/tests/unit_tests/llms/fake_llm.py index a8e9eb6..6958db0 100644 --- a/libs/neo4j/tests/unit_tests/llms/fake_llm.py +++ b/libs/neo4j/tests/unit_tests/llms/fake_llm.py @@ -59,3 +59,6 @@ def _get_next_response_in_sequence(self) -> str: response = queries[list(queries.keys())[self.response_index]] self.response_index = self.response_index + 1 return response + + def bind_tools(self, tools: Any) -> None: + pass