diff --git a/docker-compose.base.yml b/docker-compose.base.yml index 137416e40422c..f582d84197a9a 100644 --- a/docker-compose.base.yml +++ b/docker-compose.base.yml @@ -285,7 +285,8 @@ services: environment: - TEMPORAL_ADDRESS=temporal:7233 - TEMPORAL_CORS_ORIGINS=http://localhost:3000 - image: temporalio/ui:2.10.3 + - TEMPORAL_CSRF_COOKIE_INSECURE=true + image: temporalio/ui:2.31.2 ports: - 8081:8080 temporal-django-worker: diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index e47020fdcdf04..d01181b937c4e 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -4,12 +4,18 @@ from langchain_core.messages import AIMessageChunk from langfuse.callback import CallbackHandler from langgraph.graph.state import StateGraph +from pydantic import BaseModel from ee import settings -from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode +from ee.hogai.trends.nodes import ( + CreateTrendsPlanNode, + CreateTrendsPlanToolsNode, + GenerateTrendsNode, + GenerateTrendsToolsNode, +) from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation from posthog.models.team.team import Team -from posthog.schema import VisualizationMessage +from posthog.schema import AssistantGenerationStatusEvent, AssistantGenerationStatusType, VisualizationMessage if settings.LANGFUSE_PUBLIC_KEY: langfuse_handler = CallbackHandler( @@ -39,6 +45,13 @@ def is_message_update( return len(update) == 2 and update[0] == "messages" +def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["values"], AssistantState]]: + """ + Update of the graph state. + """ + return len(update) == 2 and update[0] == "values" + + class Assistant: _team: Team _graph: StateGraph @@ -59,6 +72,10 @@ def _compile_graph(self): generate_trends_node = GenerateTrendsNode(self._team) builder.add_node(GenerateTrendsNode.name, generate_trends_node.run) + generate_trends_tools_node = GenerateTrendsToolsNode(self._team) + builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run) + builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name) + builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name) builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router) builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router) @@ -66,31 +83,45 @@ def _compile_graph(self): return builder.compile() - def stream(self, conversation: Conversation) -> Generator[str, None, None]: + def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]: assistant_graph = self._compile_graph() callbacks = [langfuse_handler] if langfuse_handler else [] messages = [message.root for message in conversation.messages] + chunks = AIMessageChunk(content="") + state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None} + generator = assistant_graph.stream( - {"messages": messages}, + state, config={"recursion_limit": 24, "callbacks": callbacks}, - stream_mode=["messages", "updates"], + stream_mode=["messages", "values", "updates"], ) chunks = AIMessageChunk(content="") # Send a chunk to establish the connection and avoid the worker's timeout.
- yield "" + yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK) for update in generator: - if is_value_update(update): + if is_state_update(update): + _, new_state = update + state = new_state + + elif is_value_update(update): _, state_update = update - if ( - AssistantNodeName.GENERATE_TRENDS in state_update - and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS] - ): - message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]) - yield message.model_dump_json() + + if AssistantNodeName.GENERATE_TRENDS in state_update: + # Reset chunks when schema validation fails. + chunks = AIMessageChunk(content="") + + if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]: + message = cast( + VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0] + ) + yield message + elif state_update[AssistantNodeName.GENERATE_TRENDS].get("intermediate_steps", []): + yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) + elif is_message_update(update): langchain_message, langgraph_state = update[1] if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance( @@ -101,4 +132,4 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: if parsed_message: yield VisualizationMessage( reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer - ).model_dump_json() + ) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 845c71fe4ee5e..d1819b49b705f 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -1,29 +1,37 @@ import itertools -import json import xml.etree.ElementTree as ET from functools import cached_property -from typing import Union, cast +from typing import Optional, cast from langchain.agents.format_scratchpad import format_log_to_str -from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser -from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.exceptions import OutputParserException +from langchain_core.agents import AgentAction from langchain_core.messages import AIMessage as LangchainAssistantMessage from langchain_core.messages import BaseMessage, merge_message_runs -from langchain_core.messages import HumanMessage as LangchainHumanMessage -from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI from pydantic import ValidationError from ee.hogai.hardcoded_definitions import hardcoded_prop_defs +from ee.hogai.trends.parsers import ( + PydanticOutputParserException, + ReActParserException, + ReActParserMissingActionException, + parse_generated_trends_output, + parse_react_agent_output, +) from ee.hogai.trends.prompts import ( react_definitions_prompt, react_follow_up_prompt, + react_malformed_json_prompt, + react_missing_action_correction_prompt, + react_missing_action_prompt, + react_pydantic_validation_exception_prompt, react_scratchpad_prompt, react_system_prompt, react_user_prompt, + trends_failover_output_prompt, + trends_failover_prompt, trends_group_mapping_prompt, trends_new_plan_prompt, trends_plan_prompt, @@ -35,7 +43,7 @@ TrendsAgentToolkit, TrendsAgentToolModel, ) -from ee.hogai.trends.utils import GenerateTrendOutputModel +from 
ee.hogai.trends.utils import GenerateTrendOutputModel, filter_trends_conversation from ee.hogai.utils import ( AssistantNode, AssistantNodeName, @@ -45,7 +53,12 @@ from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner from posthog.hogql_queries.query_runner import ExecutionMode from posthog.models.group_type_mapping import GroupTypeMapping -from posthog.schema import CachedTeamTaxonomyQueryResponse, HumanMessage, TeamTaxonomyQuery, VisualizationMessage +from posthog.schema import ( + CachedTeamTaxonomyQueryResponse, + FailureMessage, + TeamTaxonomyQuery, + VisualizationMessage, +) class CreateTrendsPlanNode(AssistantNode): @@ -75,40 +88,42 @@ def run(self, state: AssistantState, config: RunnableConfig): ) toolkit = TrendsAgentToolkit(self._team) - output_parser = ReActJsonSingleInputOutputParser() merger = merge_message_runs() - agent = prompt | merger | self._model | output_parser + agent = prompt | merger | self._model | parse_react_agent_output try: result = cast( - Union[AgentAction, AgentFinish], + AgentAction, agent.invoke( { "tools": toolkit.render_text_description(), "tool_names": ", ".join([t["name"] for t in toolkit.tools]), - "agent_scratchpad": format_log_to_str( - [(action, output) for action, output in intermediate_steps if output is not None] - ), + "agent_scratchpad": self._get_agent_scratchpad(intermediate_steps), }, config, ), ) - except OutputParserException as e: - text = str(e) - if e.send_to_llm: - observation = str(e.observation) - text = str(e.llm_output) + except ReActParserException as e: + if isinstance(e, ReActParserMissingActionException): + # When the agent doesn't output the "Action:" block, we need to correct the log and append the action block, + # so that it has a higher chance of recovering. + corrected_log = str( + ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache") + .format_messages(output=e.llm_output)[0] + .content + ) + result = AgentAction( + "handle_incorrect_response", + react_missing_action_prompt, + corrected_log, + ) else: - observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question." - result = AgentAction("handle_incorrect_response", observation, text) - - if isinstance(result, AgentFinish): - # Exceptional case - return { - "plan": result.log, - "intermediate_steps": None, - } + result = AgentAction( + "handle_incorrect_response", + react_malformed_json_prompt, + e.llm_output, + ) return { "intermediate_steps": [*intermediate_steps, (result, None)], @@ -170,29 +185,44 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: """ Reconstruct the conversation for the agent. At this step, we only care about previously asked questions and generated plans. All other messages are filtered out.
""" - messages = state.get("messages", []) - if len(messages) == 0: + human_messages, visualization_messages = filter_trends_conversation(state.get("messages", [])) + + if not human_messages: return [] - conversation = [ - HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format( - question=messages[0].content if isinstance(messages[0], HumanMessage) else "" - ) - ] + conversation = [] - for message in messages[1:]: - if isinstance(message, HumanMessage): - conversation.append( - HumanMessagePromptTemplate.from_template( - react_follow_up_prompt, - template_format="mustache", - ).format(feedback=message.content) - ) - elif isinstance(message, VisualizationMessage): - conversation.append(LangchainAssistantMessage(content=message.plan or "")) + for idx, messages in enumerate(itertools.zip_longest(human_messages, visualization_messages)): + human_message, viz_message = messages + + if human_message: + if idx == 0: + conversation.append( + HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format( + question=human_message.content + ) + ) + else: + conversation.append( + HumanMessagePromptTemplate.from_template( + react_follow_up_prompt, + template_format="mustache", + ).format(feedback=human_message.content) + ) + + if viz_message: + conversation.append(LangchainAssistantMessage(content=viz_message.plan or "")) return conversation + def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]]) -> str: + actions = [] + for action, observation in scratchpad: + if observation is None: + continue + actions.append((action, observation)) + return format_log_to_str(actions) + class CreateTrendsPlanToolsNode(AssistantNode): name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS @@ -205,8 +235,12 @@ def run(self, state: AssistantState, config: RunnableConfig): try: input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root except ValidationError as e: - feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}" - return {"intermediate_steps": [*intermediate_steps, (action, feedback)]} + observation = ( + ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache") + .format_messages(exception=e.errors(include_url=False))[0] + .content + ) + return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]} # The plan has been found. Move to the generation. if input.name == "final_answer": @@ -240,30 +274,38 @@ class GenerateTrendsNode(AssistantNode): def run(self, state: AssistantState, config: RunnableConfig): generated_plan = state.get("plan", "") + intermediate_steps = state.get("intermediate_steps") or [] + validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None trends_generation_prompt = ChatPromptTemplate.from_messages( [ ("system", trends_system_prompt), ], template_format="mustache", - ) + self._reconstruct_conversation(state) + ) + self._reconstruct_conversation(state, validation_error_message=validation_error_message) merger = merge_message_runs() - chain = ( - trends_generation_prompt - | merger - | self._model - # Result from structured output is a parsed dict. Convert to a string since the output parser expects it. - | RunnableLambda(lambda x: json.dumps(x)) - # Validate a string input. 
- | PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel) - ) + chain = trends_generation_prompt | merger | self._model | parse_generated_trends_output try: message: GenerateTrendOutputModel = chain.invoke({}, config) - except OutputParserException: + except PydanticOutputParserException as e: + # Generation step is expensive. After a second unsuccessful attempt, it's better to send a failure message. + if len(intermediate_steps) >= 2: + return { + "messages": [ + FailureMessage( + content="Oops! It looks like I’m having trouble generating this trends insight. Could you please try again?" + ) + ], + "intermediate_steps": None, + } + return { - "messages": [VisualizationMessage(plan=generated_plan, reasoning_steps=["Schema validation failed"])] + "intermediate_steps": [ + *intermediate_steps, + (AgentAction("handle_incorrect_response", e.llm_output, e.validation_message), None), + ], } return { @@ -273,11 +315,12 @@ def run(self, state: AssistantState, config: RunnableConfig): reasoning_steps=message.reasoning_steps, answer=message.answer, ) - ] + ], + "intermediate_steps": None, } def router(self, state: AssistantState): - if state.get("tool_argument") is not None: + if state.get("intermediate_steps") is not None: return AssistantNodeName.GENERATE_TRENDS_TOOLS return AssistantNodeName.END @@ -301,7 +344,9 @@ def _group_mapping_prompt(self) -> str: ) return ET.tostring(root, encoding="unicode") - def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: + def _reconstruct_conversation( + self, state: AssistantState, validation_error_message: Optional[str] = None + ) -> list[BaseMessage]: """ Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history. 
""" @@ -317,22 +362,7 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: ) ] - stack: list[LangchainHumanMessage] = [] - human_messages: list[LangchainHumanMessage] = [] - visualization_messages: list[VisualizationMessage] = [] - - for message in messages: - if isinstance(message, HumanMessage): - stack.append(LangchainHumanMessage(content=message.content)) - elif isinstance(message, VisualizationMessage) and message.answer: - if stack: - human_messages += merge_message_runs(stack) - stack = [] - visualization_messages.append(message) - - if stack: - human_messages += merge_message_runs(stack) - + human_messages, visualization_messages = filter_trends_conversation(messages) first_ai_message = True for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages): @@ -364,6 +394,13 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "") ) + if validation_error_message: + conversation.append( + HumanMessagePromptTemplate.from_template(trends_failover_prompt, template_format="mustache").format( + validation_error_message=validation_error_message + ) + ) + return conversation @classmethod @@ -382,4 +419,20 @@ class GenerateTrendsToolsNode(AssistantNode): name = AssistantNodeName.GENERATE_TRENDS_TOOLS def run(self, state: AssistantState, config: RunnableConfig): - return state + intermediate_steps = state.get("intermediate_steps", []) + if not intermediate_steps: + return state + + action, _ = intermediate_steps[-1] + prompt = ( + ChatPromptTemplate.from_template(trends_failover_output_prompt, template_format="mustache") + .format_messages(output=action.tool_input, exception_message=action.log)[0] + .content + ) + + return { + "intermediate_steps": [ + *intermediate_steps[:-1], + (action, prompt), + ] + } diff --git a/ee/hogai/trends/parsers.py b/ee/hogai/trends/parsers.py new file mode 100644 index 0000000000000..e66f974576939 --- /dev/null +++ b/ee/hogai/trends/parsers.py @@ -0,0 +1,80 @@ +import json +import re + +from langchain_core.agents import AgentAction +from langchain_core.messages import AIMessage as LangchainAIMessage +from pydantic import ValidationError + +from ee.hogai.trends.utils import GenerateTrendOutputModel + + +class ReActParserException(ValueError): + llm_output: str + + def __init__(self, llm_output: str): + super().__init__(llm_output) + self.llm_output = llm_output + + +class ReActParserMalformedJsonException(ReActParserException): + pass + + +class ReActParserMissingActionException(ReActParserException): + """ + The ReAct agent didn't output the "Action:" block. + """ + + pass + + +ACTION_LOG_PREFIX = "Action:" + + +def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction: + """ + A ReAct agent must output in this format: + + Some thoughts... + Action: + ```json + {"action": "action_name", "action_input": "action_input"} + ``` + """ + text = str(message.content) + if ACTION_LOG_PREFIX not in text: + raise ReActParserMissingActionException(text) + found = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL).search(text) + if not found: + # JSON not found. + raise ReActParserMalformedJsonException(text) + try: + action = found.group(1).strip() + response = json.loads(action) + is_complete = "action" in response and "action_input" in response + except Exception: + # JSON is malformed or has a wrong type. 
+ raise ReActParserMalformedJsonException(text) if not is_complete: # JSON does not contain an action. raise ReActParserMalformedJsonException(text) return AgentAction(response["action"], response.get("action_input", {}), text) + + +class PydanticOutputParserException(ValueError): + llm_output: str + """Serialized LLM output.""" + validation_message: str + """Pydantic validation error message.""" + + def __init__(self, llm_output: str, validation_message: str): + super().__init__(llm_output) + self.llm_output = llm_output + self.validation_message = validation_message + + +def parse_generated_trends_output(output: dict) -> GenerateTrendOutputModel: + try: + return GenerateTrendOutputModel.model_validate(output) + except ValidationError as e: + raise PydanticOutputParserException(llm_output=json.dumps(output), validation_message=e.json(include_url=False)) diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index c53ae5d3453a5..2543b1efc26e0 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -178,6 +178,29 @@ Improve the previously generated plan based on the feedback: {{feedback}} """ +react_missing_action_prompt = """ +Your previous answer didn't output the `Action:` block. You must always follow the format described in the system prompt. +""" + +react_missing_action_correction_prompt = """ +{{output}} +Action: I didn't output the `Action:` block. +""" + +react_malformed_json_prompt = """ +Your previous answer contained malformed JSON. You must return a correct JSON response containing the `action` and `action_input` fields. +""" + +react_pydantic_validation_exception_prompt = """ +The action input you previously provided didn't pass validation and raised a Pydantic validation exception. + +<pydantic_exception> +{{exception}} +</pydantic_exception> + +You must fix the exception and try again. +""" + trends_system_prompt = """ You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can. @@ -269,3 +292,23 @@ trends_question_prompt = """ Answer to this question: {{question}} """ + +trends_failover_output_prompt = """ +Generation output: +``` +{{output}} +``` + +Exception message: +``` +{{exception_message}} +``` +""" + +trends_failover_prompt = """ +The result of the previous generation raised a Pydantic validation exception. + +{{validation_error_message}} + +Fix the error and return the correct response.
+""" diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index dc297570c1fd1..1e89c45458a9a 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -1,7 +1,26 @@ +import json +from unittest.mock import patch + from django.test import override_settings +from langchain_core.agents import AgentAction +from langchain_core.messages import AIMessage as LangchainAIMessage +from langchain_core.runnables import RunnableLambda -from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode -from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage +from ee.hogai.trends.nodes import ( + CreateTrendsPlanNode, + CreateTrendsPlanToolsNode, + GenerateTrendsNode, + GenerateTrendsToolsNode, +) +from ee.hogai.trends.utils import GenerateTrendOutputModel +from ee.hogai.utils import AssistantNodeName +from posthog.schema import ( + AssistantMessage, + ExperimentalAITrendsQuery, + FailureMessage, + HumanMessage, + VisualizationMessage, +) from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @@ -68,6 +87,22 @@ def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) + def test_agent_reconstructs_conversation_with_failures(self): + node = CreateTrendsPlanNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + ] + } + ) + self.assertEqual(len(history), 1) + self.assertEqual(history[0].type, "human") + self.assertIn("Text", history[0].content) + self.assertNotIn("{{question}}", history[0].content) + def test_agent_filters_out_low_count_events(self): _create_person(distinct_ids=["test"], team=self.team) for i in range(26): @@ -86,12 +121,101 @@ def test_agent_preserves_low_count_events_for_smaller_teams(self): self.assertIn("distinctevent", node._events_prompt) self.assertIn("all events", node._events_prompt) + def test_agent_scratchpad(self): + node = CreateTrendsPlanNode(self.team) + scratchpad = [ + (AgentAction(tool="test1", tool_input="input1", log="log1"), "test"), + (AgentAction(tool="test2", tool_input="input2", log="log2"), None), + (AgentAction(tool="test3", tool_input="input3", log="log3"), ""), + ] + prompt = node._get_agent_scratchpad(scratchpad) + self.assertIn("log1", prompt) + self.assertIn("log3", prompt) + + def test_agent_handles_output_without_action_block(self): + with patch( + "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", + return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")), + ): + node = CreateTrendsPlanNode(self.team) + state_update = node.run({"messages": [HumanMessage(content="Question")]}, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, obs = state_update["intermediate_steps"][0] + self.assertIsNone(obs) + self.assertIn("I don't want to output an action.", action.log) + self.assertIn("Action:", action.log) + self.assertIn("Action:", action.tool_input) + + def test_agent_handles_output_with_malformed_json(self): + with patch( + "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", + return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")), + ): + node = CreateTrendsPlanNode(self.team) + state_update = node.run({"messages": [HumanMessage(content="Question")]}, {}) + 
self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, obs = state_update["intermediate_steps"][0] + self.assertIsNone(obs) + self.assertIn("Thought.\nAction: abc", action.log) + self.assertIn("action", action.tool_input) + self.assertIn("action_input", action.tool_input) + + +@override_settings(IN_UNIT_TESTING=True) +class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest): + def test_node_handles_action_name_validation_error(self): + state = { + "intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")], + "messages": [], + } + node = CreateTrendsPlanToolsNode(self.team) + state_update = node.run(state, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, observation = state_update["intermediate_steps"][0] + self.assertIsNotNone(observation) + self.assertIn("<pydantic_exception>", observation) + + def test_node_handles_action_input_validation_error(self): + state = { + "intermediate_steps": [ + (AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test") + ], + "messages": [], + } + node = CreateTrendsPlanToolsNode(self.team) + state_update = node.run(state, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, observation = state_update["intermediate_steps"][0] + self.assertIsNotNone(observation) + self.assertIn("<pydantic_exception>", observation) + @override_settings(IN_UNIT_TESTING=True) class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest): def setUp(self): self.schema = ExperimentalAITrendsQuery(series=[]) + def test_node_runs(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + generator_model_mock.return_value = RunnableLambda( + lambda _: GenerateTrendOutputModel(reasoning_steps=["step"], answer=self.schema).model_dump() + ) + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "plan": "Plan", + }, + {}, + ) + self.assertEqual( + new_state, + { + "messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])], + "intermediate_steps": None, + }, + ) + def test_agent_reconstructs_conversation(self): node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]}) @@ -203,3 +327,145 @@ def test_agent_reconstructs_conversation_and_merges_messages(self): self.assertIn("Answer to this question:", history[5].content) self.assertNotIn("{{question}}", history[5].content) self.assertIn("Follow\nUp", history[5].content) + + def test_failover_with_incorrect_schema(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() + # Emulate an incorrect JSON. It should be an object.
+ schema["answer"] = [] + generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) + + new_state = node.run({"messages": [HumanMessage(content="Text")]}, {}) + self.assertIn("intermediate_steps", new_state) + self.assertEqual(len(new_state["intermediate_steps"]), 1) + + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], + }, + {}, + ) + self.assertIn("intermediate_steps", new_state) + self.assertEqual(len(new_state["intermediate_steps"]), 2) + + def test_node_leaves_failover(self): + node = GenerateTrendsNode(self.team) + with patch( + "ee.hogai.trends.nodes.GenerateTrendsNode._model", + return_value=RunnableLambda( + lambda _: GenerateTrendOutputModel(reasoning_steps=[], answer=self.schema).model_dump() + ), + ): + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [ + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + ], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + + def test_node_leaves_failover_after_second_unsuccessful_attempt(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() + # Emulate an incorrect JSON. It should be an object. + schema["answer"] = [] + generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) + + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [ + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + ], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + self.assertEqual(len(new_state["messages"]), 1) + self.assertIsInstance(new_state["messages"][0], FailureMessage) + + def test_agent_reconstructs_conversation_with_failover(self): + action = AgentAction(tool="fix", tool_input="validation error", log="exception") + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [HumanMessage(content="Text")], + "plan": "randomplan", + "intermediate_steps": [(action, "uniqexception")], + }, + "uniqexception", + ) + self.assertEqual(len(history), 4) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + self.assertEqual(history[3].type, "human") + self.assertIn("Pydantic", history[3].content) + self.assertIn("uniqexception", history[3].content) + + def test_agent_reconstructs_conversation_with_failed_messages(self): + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + 
HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + ], + "plan": "randomplan", + }, + ) + self.assertEqual(len(history), 3) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + + def test_router(self): + node = GenerateTrendsNode(self.team) + state = node.router({"messages": [], "intermediate_steps": None}) + self.assertEqual(state, AssistantNodeName.END) + state = node.router( + {"messages": [], "intermediate_steps": [(AgentAction(tool="", tool_input="", log=""), None)]} + ) + self.assertEqual(state, AssistantNodeName.GENERATE_TRENDS_TOOLS) + + +class TestGenerateTrendsToolsNode(ClickhouseTestMixin, APIBaseTest): + def test_tools_node(self): + node = GenerateTrendsToolsNode(self.team) + action = AgentAction(tool="fix", tool_input="validationerror", log="pydanticexception") + state = node.run({"messages": [], "intermediate_steps": [(action, None)]}, {}) + self.assertIsNotNone(state["intermediate_steps"][0][1]) + self.assertIn("validationerror", state["intermediate_steps"][0][1]) + self.assertIn("pydanticexception", state["intermediate_steps"][0][1]) diff --git a/ee/hogai/trends/test/test_parsers.py b/ee/hogai/trends/test/test_parsers.py new file mode 100644 index 0000000000000..c32ff7f146b4d --- /dev/null +++ b/ee/hogai/trends/test/test_parsers.py @@ -0,0 +1,78 @@ +from langchain_core.messages import AIMessage as LangchainAIMessage + +from ee.hogai.trends.parsers import ( + ReActParserMalformedJsonException, + ReActParserMissingActionException, + parse_react_agent_output, +) +from posthog.test.base import BaseTest + + +class TestParsers(BaseTest): + def test_parse_react_agent_output(self): + res = parse_react_agent_output( + LangchainAIMessage( + content=""" + Some thoughts... + Action: + ```json + {"action": "action_name", "action_input": "action_input"} + ``` + """ + ) + ) + self.assertEqual(res.tool, "action_name") + self.assertEqual(res.tool_input, "action_input") + + res = parse_react_agent_output( + LangchainAIMessage( + content=""" + Some thoughts...
+ Action: + ``` + {"action": "tool", "action_input": {"key": "value"}} + ``` + """ + ) + ) + self.assertEqual(res.tool, "tool") + self.assertEqual(res.tool_input, {"key": "value"}) + + self.assertRaises( + ReActParserMissingActionException, parse_react_agent_output, LangchainAIMessage(content="Some thoughts...") + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction: abc"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction:"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction: {}"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{}\n```"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{not a json}\n```"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action":"tool"}\n```'), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action_input":"input"}\n```'), + ) diff --git a/ee/hogai/trends/test/test_utils.py b/ee/hogai/trends/test/test_utils.py new file mode 100644 index 0000000000000..de9b8733129ec --- /dev/null +++ b/ee/hogai/trends/test/test_utils.py @@ -0,0 +1,39 @@ +from langchain_core.messages import HumanMessage as LangchainHumanMessage + +from ee.hogai.trends import utils +from posthog.schema import ExperimentalAITrendsQuery, FailureMessage, HumanMessage, VisualizationMessage +from posthog.test.base import BaseTest + + +class TestTrendsUtils(BaseTest): + def test_merge_human_messages(self): + res = utils.merge_human_messages( + [ + LangchainHumanMessage(content="Text"), + LangchainHumanMessage(content="Text"), + LangchainHumanMessage(content="Te"), + LangchainHumanMessage(content="xt"), + ] + ) + self.assertEqual(len(res), 1) + self.assertEqual(res, [LangchainHumanMessage(content="Text\nTe\nxt")]) + + def test_filter_trends_conversation(self): + human_messages, visualization_messages = utils.filter_trends_conversation( + [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan"), + HumanMessage(content="Text2"), + VisualizationMessage(answer=None, plan="plan"), + ] + ) + self.assertEqual(len(human_messages), 2) + self.assertEqual(len(visualization_messages), 1) + self.assertEqual( + human_messages, [LangchainHumanMessage(content="Text"), LangchainHumanMessage(content="Text2")] + ) + self.assertEqual( + visualization_messages, [VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan")] + ) diff --git a/ee/hogai/trends/utils.py b/ee/hogai/trends/utils.py index 080f85f0256d0..5e1a8052707c8 100644 --- a/ee/hogai/trends/utils.py +++ b/ee/hogai/trends/utils.py @@ -1,10 +1,53 @@ +from collections.abc import Sequence from typing import Optional +from langchain_core.messages import HumanMessage as LangchainHumanMessage +from langchain_core.messages import merge_message_runs from pydantic import BaseModel -from posthog.schema import ExperimentalAITrendsQuery +from 
ee.hogai.utils import AssistantMessageUnion +from posthog.schema import ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage class GenerateTrendOutputModel(BaseModel): reasoning_steps: Optional[list[str]] = None answer: Optional[ExperimentalAITrendsQuery] = None + + +def merge_human_messages(messages: list[LangchainHumanMessage]) -> list[LangchainHumanMessage]: + """ + Filters out duplicated human messages and merges them into one message. + """ + contents = set() + filtered_messages = [] + for message in messages: + if message.content in contents: + continue + contents.add(message.content) + filtered_messages.append(message) + return merge_message_runs(filtered_messages) + + +def filter_trends_conversation( + messages: Sequence[AssistantMessageUnion], +) -> tuple[list[LangchainHumanMessage], list[VisualizationMessage]]: + """ + Splits, filters and merges the message history to be consumable by agents. + """ + stack: list[LangchainHumanMessage] = [] + human_messages: list[LangchainHumanMessage] = [] + visualization_messages: list[VisualizationMessage] = [] + + for message in messages: + if isinstance(message, HumanMessage): + stack.append(LangchainHumanMessage(content=message.content)) + elif isinstance(message, VisualizationMessage) and message.answer: + if stack: + human_messages += merge_human_messages(stack) + stack = [] + visualization_messages.append(message) + + if stack: + human_messages += merge_human_messages(stack) + + return human_messages, visualization_messages diff --git a/ee/hogai/utils.py b/ee/hogai/utils.py index 65de9303b3ffc..70d1ea969621a 100644 --- a/ee/hogai/utils.py +++ b/ee/hogai/utils.py @@ -10,9 +10,9 @@ from pydantic import BaseModel, Field from posthog.models.team.team import Team -from posthog.schema import AssistantMessage, HumanMessage, RootAssistantMessage, VisualizationMessage +from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, RootAssistantMessage, VisualizationMessage -AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage] +AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage, FailureMessage] class Conversation(BaseModel): @@ -24,7 +24,6 @@ class AssistantState(TypedDict): messages: Annotated[Sequence[AssistantMessageUnion], operator.add] intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]] plan: Optional[str] - tool_argument: Optional[str] class AssistantNodeName(StrEnum): diff --git a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png index d6fe88cab16e5..eedf3e01916a0 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png index 50e2dea81e948..c19e174167b7c 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png and b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png new file mode 100644 index 0000000000000..84a2a7828cdd5 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png 
differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png new file mode 100644 index 0000000000000..43ad8593f41e8 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png b/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png index d0f525ffb382c..80aded4c79433 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread--light.png b/frontend/__snapshots__/scenes-app-max-ai--thread--light.png index f8d64397cb918..22c88333171fd 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread--light.png and b/frontend/__snapshots__/scenes-app-max-ai--thread--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--dark.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--dark.png new file mode 100644 index 0000000000000..eccd3471a1f51 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--light.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--light.png new file mode 100644 index 0000000000000..da0efa6c70d73 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--light.png differ diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 6c76fa1d919f4..e8baa13682fd7 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -617,6 +617,24 @@ } ] }, + "AssistantEventType": { + "enum": ["status", "message"], + "type": "string" + }, + "AssistantGenerationStatusEvent": { + "additionalProperties": false, + "properties": { + "type": { + "$ref": "#/definitions/AssistantGenerationStatusType" + } + }, + "required": ["type"], + "type": "object" + }, + "AssistantGenerationStatusType": { + "enum": ["ack", "generation_error"], + "type": "string" + }, "AssistantMessage": { "additionalProperties": false, "properties": { @@ -632,7 +650,7 @@ "type": "object" }, "AssistantMessageType": { - "enum": ["human", "ai", "ai/viz"], + "enum": ["human", "ai", "ai/viz", "ai/failure"], "type": "string" }, "AutocompleteCompletionItem": { @@ -5239,6 +5257,9 @@ "experiment_id": { "type": "integer" }, + "funnels_query": { + "$ref": "#/definitions/FunnelsQuery" + }, "kind": { "const": "ExperimentFunnelsQuery", "type": "string" @@ -5249,12 +5270,9 @@ }, "response": { "$ref": "#/definitions/ExperimentFunnelsQueryResponse" - }, - "source": { - "$ref": "#/definitions/FunnelsQuery" } }, - "required": ["experiment_id", "kind", "source"], + "required": ["experiment_id", "funnels_query", "kind"], "type": "object" }, "ExperimentFunnelsQueryResponse": { @@ -5537,6 +5555,20 @@ "required": ["kind", "series"], "type": "object" }, + "FailureMessage": { + "additionalProperties": false, + "properties": { + "content": { + "type": "string" + }, + "type": { + "const": "ai/failure", + "type": "string" + } + }, + "required": ["type"], + "type": "object" + }, "FeaturePropertyFilter": { "additionalProperties": false, "properties": { @@ -10881,6 +10913,9 @@ }, { "$ref": 
"#/definitions/HumanMessage" + }, + { + "$ref": "#/definitions/FailureMessage" } ] }, diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index cf4208b331062..a437aa9cea660 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -1654,7 +1654,7 @@ export type CachedExperimentFunnelsQueryResponse = CachedQueryResponse { kind: NodeKind.ExperimentFunnelsQuery - source: FunnelsQuery + funnels_query: FunnelsQuery experiment_id: integer } @@ -2104,6 +2104,7 @@ export enum AssistantMessageType { Human = 'human', Assistant = 'ai', Visualization = 'ai/viz', + Failure = 'ai/failure', } export interface HumanMessage { @@ -2123,4 +2124,23 @@ export interface VisualizationMessage { answer?: ExperimentalAITrendsQuery } -export type RootAssistantMessage = VisualizationMessage | AssistantMessage | HumanMessage +export interface FailureMessage { + type: AssistantMessageType.Failure + content?: string +} + +export type RootAssistantMessage = VisualizationMessage | AssistantMessage | HumanMessage | FailureMessage + +export enum AssistantEventType { + Status = 'status', + Message = 'message', +} + +export enum AssistantGenerationStatusType { + Acknowledged = 'ack', + GenerationError = 'generation_error', +} + +export interface AssistantGenerationStatusEvent { + type: AssistantGenerationStatusType +} diff --git a/frontend/src/scenes/experiments/MetricSelector.tsx b/frontend/src/scenes/experiments/MetricSelector.tsx index 4f4e7b6e1e262..e9af066af176f 100644 --- a/frontend/src/scenes/experiments/MetricSelector.tsx +++ b/frontend/src/scenes/experiments/MetricSelector.tsx @@ -133,6 +133,11 @@ export function ExperimentInsightCreator({ insightProps }: { insightProps: Insig sortable={isTrends ? undefined : true} showNestedArrow={isTrends ? 
undefined : true} showNumericalPropsOnly={isTrends} + actionsTaxonomicGroupTypes={[ + TaxonomicFilterGroupType.Events, + TaxonomicFilterGroupType.Actions, + TaxonomicFilterGroupType.DataWarehouse, + ]} propertiesTaxonomicGroupTypes={[ TaxonomicFilterGroupType.EventProperties, TaxonomicFilterGroupType.PersonProperties, diff --git a/frontend/src/scenes/experiments/experimentLogic.tsx b/frontend/src/scenes/experiments/experimentLogic.tsx index 4db270269a634..3e365140b96d5 100644 --- a/frontend/src/scenes/experiments/experimentLogic.tsx +++ b/frontend/src/scenes/experiments/experimentLogic.tsx @@ -827,9 +827,42 @@ export const experimentLogic = kea([ }, ], secondaryMetricResults: [ - null as SecondaryMetricResults[] | null, + null as + | SecondaryMetricResults[] + | (CachedExperimentTrendsQueryResponse | CachedExperimentFunnelsQueryResponse)[] + | null, { - loadSecondaryMetricResults: async (refresh?: boolean) => { + loadSecondaryMetricResults: async ( + refresh?: boolean + ): Promise< + | SecondaryMetricResults[] + | (CachedExperimentTrendsQueryResponse | CachedExperimentFunnelsQueryResponse)[] + | null + > => { + if (values.featureFlags[FEATURE_FLAGS.EXPERIMENTS_HOGQL]) { + const secondaryMetrics = + values.experiment?.metrics?.filter((metric) => metric.type === 'secondary') || [] + + return (await Promise.all( + secondaryMetrics.map(async (metric) => { + try { + const response: ExperimentResults = await api.create( + `api/projects/${values.currentTeamId}/query`, + { query: metric.query } + ) + + return { + ...response, + fakeInsightId: Math.random().toString(36).substring(2, 15), + last_refresh: response.last_refresh || '', + } + } catch (error) { + return {} + } + }) + )) as unknown as (CachedExperimentTrendsQueryResponse | CachedExperimentFunnelsQueryResponse)[] + } + const refreshParam = refresh ? '&refresh=true' : '' return await Promise.all( @@ -846,6 +879,7 @@ export const experimentLogic = kea([ last_refresh: secResults.last_refresh, } } + return { ...secResults.result, fakeInsightId: Math.random().toString(36).substring(2, 15), @@ -1255,9 +1289,10 @@ export const experimentLogic = kea([ | CachedExperimentTrendsQueryResponse | CachedExperimentFunnelsQueryResponse | null, - variant: string + variant: string, + type: 'primary' | 'secondary' = 'primary' ): number | null => { - const usingMathAggregationType = experimentMathAggregationForTrends() + const usingMathAggregationType = type === 'primary' ? 
experimentMathAggregationForTrends() : false if (!experimentResults || !experimentResults.insight) { return null } @@ -1392,15 +1427,31 @@ export const experimentLogic = kea([ }, ], tabularSecondaryMetricResults: [ - (s) => [s.experiment, s.secondaryMetricResults], - (experiment, secondaryMetricResults): TabularSecondaryMetricResults[] => { + (s) => [s.experiment, s.secondaryMetricResults, s.conversionRateForVariant, s.countDataForVariant], + ( + experiment, + secondaryMetricResults, + conversionRateForVariant, + countDataForVariant + ): TabularSecondaryMetricResults[] => { + if (!secondaryMetricResults) { + return [] + } + const variantsWithResults: TabularSecondaryMetricResults[] = [] experiment?.parameters?.feature_flag_variants?.forEach((variant) => { const metricResults: SecondaryMetricResult[] = [] experiment?.secondary_metrics?.forEach((metric, idx) => { + let result + if (metric.filters.insight === InsightType.FUNNELS) { + result = conversionRateForVariant(secondaryMetricResults?.[idx], variant.key) + } else { + result = countDataForVariant(secondaryMetricResults?.[idx], variant.key, 'secondary') + } + metricResults.push({ insightType: metric.filters.insight || InsightType.TRENDS, - result: secondaryMetricResults?.[idx]?.result?.[variant.key], + result: result || undefined, }) }) diff --git a/frontend/src/scenes/max/Max.stories.tsx b/frontend/src/scenes/max/Max.stories.tsx index 27045963c6e42..88ffe32ea7559 100644 --- a/frontend/src/scenes/max/Max.stories.tsx +++ b/frontend/src/scenes/max/Max.stories.tsx @@ -1,10 +1,10 @@ import { Meta, StoryFn } from '@storybook/react' -import { BindLogic, useActions } from 'kea' +import { BindLogic, useActions, useValues } from 'kea' import { useEffect } from 'react' import { mswDecorator, useStorybookMocks } from '~/mocks/browser' -import { chatResponseChunk } from './__mocks__/chatResponse.mocks' +import { chatResponseChunk, failureChunk, generationFailureChunk } from './__mocks__/chatResponse.mocks' import { MaxInstance } from './Max' import { maxLogic } from './maxLogic' @@ -104,3 +104,43 @@ EmptyThreadLoading.parameters = { waitForLoadersToDisappear: false, }, } + +export const GenerationFailureThread: StoryFn = () => { + useStorybookMocks({ + post: { + '/api/environments/:team_id/query/chat/': (_, res, ctx) => res(ctx.text(generationFailureChunk)), + }, + }) + + const sessionId = 'd210b263-8521-4c5b-b3c4-8e0348df574b' + + const { askMax, setMessageStatus } = useActions(maxLogic({ sessionId })) + const { thread, threadLoading } = useValues(maxLogic({ sessionId })) + useEffect(() => { + askMax('What are my most popular pages?') + }, []) + useEffect(() => { + if (thread.length === 2 && !threadLoading) { + setMessageStatus(1, 'error') + } + }, [thread.length, threadLoading]) + + return