From 922751913822dda9475f1503b77c8e29c90a5465 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Thu, 17 Oct 2024 13:14:25 +0200 Subject: [PATCH] feat: streaming (wip) --- ee/hogai/assistant.py | 21 ++++++++++++++++++--- ee/hogai/trends/nodes.py | 22 ++++++++++++++-------- posthog/api/query.py | 4 ++-- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index a23a3d5064142..9aa305241dd06 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -1,8 +1,11 @@ -from langchain_core.messages import BaseMessage +from typing import cast + +from langchain_core.messages import AIMessageChunk, BaseMessage +from langchain_core.outputs import Generation from langgraph.graph.state import StateGraph from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode -from ee.hogai.utils import AssistantNodeName, AssistantState +from ee.hogai.utils import AssistantMessage, AssistantNodeName, AssistantState from posthog.models.team.team import Team @@ -35,8 +38,20 @@ def _compile_graph(self): def stream(self, messages: list[BaseMessage]): assistant_graph = self._compile_graph() - return assistant_graph.stream( + generator = assistant_graph.stream( {"messages": messages}, config={"recursion_limit": 24}, stream_mode="messages", ) + + chunks = AIMessageChunk("") + parser = GenerateTrendsNode.output_parser + + for message, state in generator: + if state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS: + if isinstance(message, AssistantMessage): + yield message + else: + message = cast(AIMessageChunk, message) + chunks += message + yield parser.parse_result([Generation(text=chunks.content)], partial=True) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 3a616306f3ee3..0188b7d4bb0d0 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -4,6 +4,7 @@ from functools import cached_property from typing import Union, cast +from django.utils.functional import classproperty from langchain.agents.format_scratchpad import format_log_to_str from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser from langchain_core.agents import AgentAction, AgentFinish @@ -11,7 +12,7 @@ from langchain_core.messages import AIMessage, BaseMessage, merge_message_runs from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.runnables import RunnableLambda, RunnablePassthrough +from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnablePassthrough from pydantic import ValidationError from ee.hogai.hardcoded_definitions import hardcoded_prop_defs @@ -133,7 +134,7 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: return conversation - def run(self, state: AssistantState): + def run(self, state: AssistantState, config: RunnableConfig): intermediate_steps = state.get("intermediate_steps") or [] prompt = ( @@ -178,7 +179,8 @@ def run(self, state: AssistantState): "tools": toolkit.render_text_description(), "tool_names": ", ".join([t["name"] for t in toolkit.tools]), "intermediate_steps": intermediate_steps, - } + }, + config, ), ) except OutputParserException as e: @@ -210,7 +212,7 @@ def router(self, state: AssistantState): return AssistantNodeName.GENERATE_TRENDS return AssistantNodeName.CREATE_TRENDS_PLAN - def run(self, state: AssistantState): + def run(self, state: AssistantState, config: RunnableConfig): toolkit = TrendsAgentToolkit(self._team) intermediate_steps = state.get("intermediate_steps") or [] action, _ = intermediate_steps[-1] @@ -317,7 +319,11 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: return conversation - def run(self, state: AssistantState): + @classproperty + def output_parser(cls): + return PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel) + + def run(self, state: AssistantState, config: RunnableConfig): generated_plan = state.get("plan", "") llm = llm_gpt_4o.with_structured_output( @@ -341,11 +347,11 @@ def run(self, state: AssistantState): # Result from structured output is a parsed dict. Convert to a string since the output parser expects it. | RunnableLambda(lambda x: json.dumps(x)) # Validate a string input. - | PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel) + | self.output_parser ) try: - message = chain.invoke({}) + message = chain.invoke({}, config) except OutputParserException as e: if e.send_to_llm: observation = str(e.observation) @@ -371,5 +377,5 @@ class GenerateTrendsToolsNode(AssistantNode): name = AssistantNodeName.GENERATE_TRENDS_TOOLS - def run(self, state: AssistantState): + def run(self, state: AssistantState, config: RunnableConfig): return state diff --git a/posthog/api/query.py b/posthog/api/query.py index 6ef669f35ebf0..869adfcf22e46 100644 --- a/posthog/api/query.py +++ b/posthog/api/query.py @@ -169,11 +169,11 @@ def draft_sql(self, request: Request, *args, **kwargs) -> Response: def chat(self, request: Request, *args, **kwargs): assert request.user is not None validated_body = Conversation.model_validate(request.data) - chain = Assistant(self.team) + assistant = Assistant(self.team) def generate(): last_message = None - for message in chain.stream({"question": validated_body.messages[0].content}): + for message in assistant.stream({"messages": validated_body.messages}): if message: last_message = message[0].model_dump_json() yield last_message