Skip to content

Commit

Permalink
Removes LLMChain from GraphCypherQAChain
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Dec 3, 2024
1 parent 3489bf5 commit af58bdc
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 55 deletions.
61 changes: 34 additions & 27 deletions libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Dict, List, Optional, Union

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import (
Expand Down Expand Up @@ -165,8 +164,8 @@ class GraphCypherQAChain(Chain):
"""

graph: GraphStore = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: Union[LLMChain, Runnable]
cypher_generation_chain: Runnable
qa_chain: Runnable
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
Expand Down Expand Up @@ -261,28 +260,34 @@ def from_llm(
"You can specify up to two of 'cypher_llm', 'qa_llm'"
", and 'llm', but not all three simultaneously."
)
if cypher_prompt and cypher_llm_kwargs:
raise ValueError(
"Specifying cypher_prompt and cypher_llm_kwargs together is"
" not allowed. Please pass prompt via cypher_llm_kwargs."
)
if qa_prompt and qa_llm_kwargs:
raise ValueError(
"Specifying qa_prompt and qa_llm_kwargs together is"
" not allowed. Please pass prompt via qa_llm_kwargs."
)
if cypher_prompt:
if cypher_llm_kwargs:
raise ValueError(
"Specifying cypher_prompt and cypher_llm_kwargs together is"
" not allowed. Please pass prompt via cypher_llm_kwargs."
)
else:
if cypher_llm_kwargs:
cypher_prompt = cypher_llm_kwargs.pop(
"prompt", CYPHER_GENERATION_PROMPT
)
else:
cypher_prompt = CYPHER_GENERATION_PROMPT
if qa_prompt:
if qa_llm_kwargs:
raise ValueError(
"Specifying qa_prompt and qa_llm_kwargs together is"
" not allowed. Please pass prompt via qa_llm_kwargs."
)
else:
if qa_llm_kwargs:
qa_prompt = qa_llm_kwargs.pop("prompt", CYPHER_QA_PROMPT)
else:
qa_prompt = CYPHER_QA_PROMPT
use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
use_cypher_llm_kwargs = (
cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
)
if "prompt" not in use_qa_llm_kwargs:
use_qa_llm_kwargs["prompt"] = (
qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
)
if "prompt" not in use_cypher_llm_kwargs:
use_cypher_llm_kwargs["prompt"] = (
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
)

qa_llm = qa_llm or llm
if use_function_response:
Expand All @@ -299,11 +304,11 @@ def from_llm(
except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions")
else:
qa_chain = LLMChain(llm=qa_llm, **use_qa_llm_kwargs) # type: ignore[arg-type]
qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() # type: ignore

cypher_generation_chain = LLMChain(
llm=cypher_llm or llm, # type: ignore[arg-type]
**use_cypher_llm_kwargs, # type: ignore[arg-type]
cypher_llm = cypher_llm or llm
cypher_generation_chain = (
cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() # type: ignore
)

if exclude_types and include_types:
Expand Down Expand Up @@ -349,7 +354,9 @@ def _call(

intermediate_steps: List = []

generated_cypher = self.cypher_generation_chain.run(args, callbacks=callbacks)
generated_cypher = self.cypher_generation_chain.invoke(
args, callbacks=callbacks
)

# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
Expand Down Expand Up @@ -391,7 +398,7 @@ def _call(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key] # type: ignore
final_result = result # type: ignore

chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
Expand Down
24 changes: 5 additions & 19 deletions libs/neo4j/tests/integration_tests/chains/test_graph_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from langchain.chains.loading import load_chain
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import Generation, LLMResult
from langchain_core.language_models.fake import FakeListLLM

from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph
Expand Down Expand Up @@ -73,11 +73,7 @@ def test_cypher_generating_run() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = MagicMock(spec=BaseLanguageModel)
llm.generate_prompt.side_effect = [
LLMResult(generations=[[Generation(text=query)]]),
LLMResult(generations=[[Generation(text="Bruce Willis")]]),
]
llm = FakeListLLM(responses=[query, "Bruce Willis"])
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down Expand Up @@ -117,10 +113,7 @@ def test_cypher_top_k() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = MagicMock(spec=BaseLanguageModel)
llm.generate_prompt.side_effect = [
LLMResult(generations=[[Generation(text=query)]])
]
llm = FakeListLLM(responses=[query])
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down Expand Up @@ -158,11 +151,7 @@ def test_cypher_intermediate_steps() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = MagicMock(spec=BaseLanguageModel)
llm.generate_prompt.side_effect = [
LLMResult(generations=[[Generation(text=query)]]),
LLMResult(generations=[[Generation(text="Bruce Willis")]]),
]
llm = FakeListLLM(responses=[query, "Bruce Willis"])
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down Expand Up @@ -207,10 +196,7 @@ def test_cypher_return_direct() -> None:
"WHERE m.title = 'Pulp Fiction' "
"RETURN a.name"
)
llm = MagicMock(spec=BaseLanguageModel)
llm.generate_prompt.side_effect = [
LLMResult(generations=[[Generation(text=query)]])
]
llm = FakeListLLM(responses=[query])
chain = GraphCypherQAChain.from_llm(
llm=llm,
graph=graph,
Expand Down
46 changes: 37 additions & 9 deletions libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from typing import Any, Dict, List

from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import SystemMessage
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
)

from langchain_neo4j.chains.graph_qa.cypher import (
GraphCypherQAChain,
Expand Down Expand Up @@ -64,8 +70,8 @@ def test_graph_cypher_qa_chain_prompt_selection_1() -> None:
cypher_prompt=cypher_prompt,
allow_dangerous_requests=True,
)
assert chain.qa_chain.prompt == qa_prompt # type: ignore[union-attr]
assert chain.cypher_generation_chain.prompt == cypher_prompt
assert chain.qa_chain.first == qa_prompt # type: ignore[attr-defined]
assert chain.cypher_generation_chain.first == cypher_prompt # type: ignore[attr-defined]


def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
Expand All @@ -77,8 +83,8 @@ def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
return_intermediate_steps=False,
allow_dangerous_requests=True,
)
assert chain.qa_chain.prompt == CYPHER_QA_PROMPT # type: ignore[union-attr]
assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
assert chain.qa_chain.first == CYPHER_QA_PROMPT # type: ignore[attr-defined]
assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT # type: ignore[attr-defined]


def test_graph_cypher_qa_chain_prompt_selection_3() -> None:
Expand All @@ -94,8 +100,8 @@ def test_graph_cypher_qa_chain_prompt_selection_3() -> None:
qa_llm_kwargs={"memory": readonlymemory},
allow_dangerous_requests=True,
)
assert chain.qa_chain.prompt == CYPHER_QA_PROMPT # type: ignore[union-attr]
assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
assert chain.qa_chain.first == CYPHER_QA_PROMPT # type: ignore[attr-defined]
assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT # type: ignore[attr-defined]


def test_graph_cypher_qa_chain_prompt_selection_4() -> None:
Expand All @@ -115,8 +121,8 @@ def test_graph_cypher_qa_chain_prompt_selection_4() -> None:
qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory},
allow_dangerous_requests=True,
)
assert chain.qa_chain.prompt == qa_prompt # type: ignore[union-attr]
assert chain.cypher_generation_chain.prompt == cypher_prompt
assert chain.qa_chain.first == qa_prompt # type: ignore[attr-defined]
assert chain.cypher_generation_chain.first == cypher_prompt # type: ignore[attr-defined]


def test_graph_cypher_qa_chain_prompt_selection_5() -> None:
Expand Down Expand Up @@ -144,6 +150,28 @@ def test_graph_cypher_qa_chain_prompt_selection_5() -> None:
assert True


def test_graph_cypher_qa_chain_prompt_selection_6() -> None:
# Test function response prompt
function_response_system = "Respond as a pirate!"
response_prompt = ChatPromptTemplate.from_messages(
[
SystemMessage(content=function_response_system),
HumanMessagePromptTemplate.from_template("{question}"),
MessagesPlaceholder(variable_name="function_response"),
]
)
chain = GraphCypherQAChain.from_llm(
llm=FakeLLM(),
graph=FakeGraphStore(),
verbose=True,
use_function_response=True,
function_response_system=function_response_system,
allow_dangerous_requests=True,
)
assert chain.qa_chain.first == response_prompt # type: ignore[attr-defined]
assert chain.cypher_generation_chain.first == CYPHER_GENERATION_PROMPT # type: ignore[attr-defined]


def test_graph_cypher_qa_chain() -> None:
template = """You are a nice chatbot having a conversation with a human.
Expand Down
3 changes: 3 additions & 0 deletions libs/neo4j/tests/unit_tests/llms/fake_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ def _get_next_response_in_sequence(self) -> str:
response = queries[list(queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response

def bind_tools(self, tools: Any) -> None:
pass

0 comments on commit af58bdc

Please sign in to comment.