Skip to content

Commit

Permalink
feat: streaming (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
skoob13 committed Oct 17, 2024
1 parent 7b8e03e commit 9227519
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
21 changes: 18 additions & 3 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from langchain_core.messages import BaseMessage
from typing import cast

from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import Generation
from langgraph.graph.state import StateGraph

from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
from ee.hogai.utils import AssistantNodeName, AssistantState
from ee.hogai.utils import AssistantMessage, AssistantNodeName, AssistantState
from posthog.models.team.team import Team


Expand Down Expand Up @@ -35,8 +38,20 @@ def _compile_graph(self):

def stream(self, messages: list[BaseMessage]):
    """Stream assistant output for the given conversation.

    Compiles the LangGraph state graph and runs it in ``stream_mode="messages"``,
    yielding results incrementally:

    - ``AssistantMessage`` items from the trend-generation node are yielded
      as-is (fully formed messages).
    - Raw LLM chunks are accumulated and re-parsed on every token with the
      trends output parser in partial mode, so callers receive progressively
      more complete structured output.

    Args:
        messages: The conversation history to feed into the graph.

    Yields:
        Either ``AssistantMessage`` objects or (possibly ``None``) partial
        parse results of the accumulated generation.
    """
    assistant_graph = self._compile_graph()
    generator = assistant_graph.stream(
        {"messages": messages},
        # Cap graph recursion so a misbehaving agent loop terminates.
        config={"recursion_limit": 24},
        stream_mode="messages",
    )

    # Accumulator for raw LLM token chunks; re-parsed as partial JSON on
    # each new chunk.
    chunks = AIMessageChunk(content="")
    parser = GenerateTrendsNode.output_parser

    for message, state in generator:
        # Only surface output produced by the trend-generation node;
        # planner/tool nodes stay internal.
        if state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS:
            if isinstance(message, AssistantMessage):
                yield message
            else:
                chunks += cast(AIMessageChunk, message)
                # NOTE(review): parse_result(..., partial=True) can return
                # None while the JSON is still incomplete — consumers must
                # tolerate None items. Assumes chunks.content is a str;
                # confirm the model never emits multi-part content.
                yield parser.parse_result([Generation(text=chunks.content)], partial=True)
22 changes: 14 additions & 8 deletions ee/hogai/trends/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from functools import cached_property
from typing import Union, cast

from django.utils.functional import classproperty
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, BaseMessage, merge_message_runs
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnablePassthrough
from pydantic import ValidationError

from ee.hogai.hardcoded_definitions import hardcoded_prop_defs
Expand Down Expand Up @@ -133,7 +134,7 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:

return conversation

def run(self, state: AssistantState):
def run(self, state: AssistantState, config: RunnableConfig):
intermediate_steps = state.get("intermediate_steps") or []

prompt = (
Expand Down Expand Up @@ -178,7 +179,8 @@ def run(self, state: AssistantState):
"tools": toolkit.render_text_description(),
"tool_names": ", ".join([t["name"] for t in toolkit.tools]),
"intermediate_steps": intermediate_steps,
}
},
config,
),
)
except OutputParserException as e:
Expand Down Expand Up @@ -210,7 +212,7 @@ def router(self, state: AssistantState):
return AssistantNodeName.GENERATE_TRENDS
return AssistantNodeName.CREATE_TRENDS_PLAN

def run(self, state: AssistantState):
def run(self, state: AssistantState, config: RunnableConfig):
toolkit = TrendsAgentToolkit(self._team)
intermediate_steps = state.get("intermediate_steps") or []
action, _ = intermediate_steps[-1]
Expand Down Expand Up @@ -317,7 +319,11 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]:

return conversation

def run(self, state: AssistantState):
@classproperty
def output_parser(cls):
    """Pydantic parser that validates LLM output against ``GenerateTrendOutputModel``.

    Exposed as a classproperty so external streaming code can reuse the exact
    same parser (e.g. for partial parsing of accumulated chunks) without
    instantiating the node. A fresh parser instance is returned on each access.
    """
    return PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel)

def run(self, state: AssistantState, config: RunnableConfig):
generated_plan = state.get("plan", "")

llm = llm_gpt_4o.with_structured_output(
Expand All @@ -341,11 +347,11 @@ def run(self, state: AssistantState):
# Result from structured output is a parsed dict. Convert to a string since the output parser expects it.
| RunnableLambda(lambda x: json.dumps(x))
# Validate a string input.
| PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel)
| self.output_parser
)

try:
message = chain.invoke({})
message = chain.invoke({}, config)
except OutputParserException as e:
if e.send_to_llm:
observation = str(e.observation)
Expand All @@ -371,5 +377,5 @@ class GenerateTrendsToolsNode(AssistantNode):

name = AssistantNodeName.GENERATE_TRENDS_TOOLS

def run(self, state: AssistantState, config: RunnableConfig):
    """No-op node: returns the state unchanged.

    Accepts ``config`` to match the run signature of the other assistant
    nodes (required for LangGraph to propagate the runnable config through
    the graph). Tool-handling logic is presumably still to come — this
    commit is marked "wip".
    """
    return state
4 changes: 2 additions & 2 deletions posthog/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,11 @@ def draft_sql(self, request: Request, *args, **kwargs) -> Response:
def chat(self, request: Request, *args, **kwargs):
assert request.user is not None
validated_body = Conversation.model_validate(request.data)
chain = Assistant(self.team)
assistant = Assistant(self.team)

def generate():
last_message = None
for message in chain.stream({"question": validated_body.messages[0].content}):
for message in assistant.stream({"messages": validated_body.messages}):

Check warning

Code scanning / CodeQL

Information exposure through an exception Medium

Stack trace information
flows to this location and may be exposed to an external user.
if message:
last_message = message[0].model_dump_json()
yield last_message
Expand Down

0 comments on commit 9227519

Please sign in to comment.