diff --git a/ee/api/feature_flag_role_access.py b/ee/api/feature_flag_role_access.py index 6d03c7a4f361c..01aa98a05b9db 100644 --- a/ee/api/feature_flag_role_access.py +++ b/ee/api/feature_flag_role_access.py @@ -1,10 +1,10 @@ from rest_framework import exceptions, mixins, serializers, viewsets from rest_framework.permissions import SAFE_METHODS, BasePermission -from ee.api.role import RoleSerializer +from ee.api.rbac.role import RoleSerializer from ee.models.feature_flag_role_access import FeatureFlagRoleAccess -from ee.models.organization_resource_access import OrganizationResourceAccess -from ee.models.role import Role +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess +from ee.models.rbac.role import Role from posthog.api.feature_flag import FeatureFlagSerializer from posthog.api.routing import TeamAndOrgViewSetMixin from posthog.models import FeatureFlag diff --git a/ee/api/organization_resource_access.py b/ee/api/rbac/organization_resource_access.py similarity index 92% rename from ee/api/organization_resource_access.py rename to ee/api/rbac/organization_resource_access.py index bf886566605b5..9722fc7b02eac 100644 --- a/ee/api/organization_resource_access.py +++ b/ee/api/rbac/organization_resource_access.py @@ -1,7 +1,7 @@ from rest_framework import mixins, serializers, viewsets -from ee.api.role import RolePermissions -from ee.models.organization_resource_access import OrganizationResourceAccess +from ee.api.rbac.role import RolePermissions +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess from posthog.api.routing import TeamAndOrgViewSetMixin diff --git a/ee/api/role.py b/ee/api/rbac/role.py similarity index 97% rename from ee/api/role.py rename to ee/api/rbac/role.py index 96041cd0109ef..ccf8acef1f1dc 100644 --- a/ee/api/role.py +++ b/ee/api/rbac/role.py @@ -5,8 +5,8 @@ from rest_framework.permissions import SAFE_METHODS, BasePermission from ee.models.feature_flag_role_access import FeatureFlagRoleAccess -from ee.models.organization_resource_access import OrganizationResourceAccess -from ee.models.role import Role, RoleMembership +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess +from ee.models.rbac.role import Role, RoleMembership from posthog.api.organization_member import OrganizationMemberSerializer from posthog.api.routing import TeamAndOrgViewSetMixin from posthog.api.shared import UserBasicSerializer diff --git a/ee/api/test/test_feature_flag.py b/ee/api/test/test_feature_flag.py index e3dd5849d607c..0bc7292f7a875 100644 --- a/ee/api/test/test_feature_flag.py +++ b/ee/api/test/test_feature_flag.py @@ -1,6 +1,6 @@ from ee.api.test.base import APILicensedTest -from ee.models.organization_resource_access import OrganizationResourceAccess -from ee.models.role import Role, RoleMembership +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess +from ee.models.rbac.role import Role, RoleMembership from posthog.models.feature_flag import FeatureFlag from posthog.models.organization import OrganizationMembership diff --git a/ee/api/test/test_feature_flag_role_access.py b/ee/api/test/test_feature_flag_role_access.py index 3cd4e947d90c9..d73c1c7384493 100644 --- a/ee/api/test/test_feature_flag_role_access.py +++ b/ee/api/test/test_feature_flag_role_access.py @@ -2,8 +2,8 @@ from ee.api.test.base import APILicensedTest from ee.models.feature_flag_role_access import FeatureFlagRoleAccess -from ee.models.organization_resource_access import OrganizationResourceAccess -from ee.models.role import Role +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess +from ee.models.rbac.role import Role from posthog.models.feature_flag import FeatureFlag from posthog.models.organization import OrganizationMembership from posthog.models.user import User diff --git a/ee/api/test/test_organization_resource_access.py b/ee/api/test/test_organization_resource_access.py index 9123214a092db..98206fe519f44 100644 --- a/ee/api/test/test_organization_resource_access.py +++ b/ee/api/test/test_organization_resource_access.py @@ -2,7 +2,7 @@ from rest_framework import status from ee.api.test.base import APILicensedTest -from ee.models.organization_resource_access import OrganizationResourceAccess +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess from posthog.models.organization import Organization, OrganizationMembership from posthog.test.base import QueryMatchingTest, snapshot_postgres_queries, FuzzyInt diff --git a/ee/api/test/test_role.py b/ee/api/test/test_role.py index 1a3068ff4cf4f..96503162d5fe9 100644 --- a/ee/api/test/test_role.py +++ b/ee/api/test/test_role.py @@ -2,8 +2,8 @@ from rest_framework import status from ee.api.test.base import APILicensedTest -from ee.models.organization_resource_access import OrganizationResourceAccess -from ee.models.role import Role +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess +from ee.models.rbac.role import Role from posthog.models.organization import Organization, OrganizationMembership diff --git a/ee/api/test/test_role_membership.py b/ee/api/test/test_role_membership.py index f89796d9b7c4f..c3e67cf0514d2 100644 --- a/ee/api/test/test_role_membership.py +++ b/ee/api/test/test_role_membership.py @@ -1,7 +1,7 @@ from rest_framework import status from ee.api.test.base import APILicensedTest -from ee.models.role import Role, RoleMembership +from ee.models.rbac.role import Role, RoleMembership from posthog.models.organization import Organization, OrganizationMembership from posthog.models.user import User diff --git a/ee/clickhouse/views/experiments.py b/ee/clickhouse/views/experiments.py index 6df24dc012cea..44fe8b72b5045 100644 --- a/ee/clickhouse/views/experiments.py +++ b/ee/clickhouse/views/experiments.py @@ -184,6 +184,7 @@ class Meta: "created_by", "created_at", "updated_at", + "metrics", ] read_only_fields = [ "id", diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index e47020fdcdf04..d01181b937c4e 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -4,12 +4,18 @@ from langchain_core.messages import AIMessageChunk from langfuse.callback import CallbackHandler from langgraph.graph.state import StateGraph +from pydantic import BaseModel from ee import settings -from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode +from ee.hogai.trends.nodes import ( + CreateTrendsPlanNode, + CreateTrendsPlanToolsNode, + GenerateTrendsNode, + GenerateTrendsToolsNode, +) from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation from posthog.models.team.team import Team -from posthog.schema import VisualizationMessage +from posthog.schema import AssistantGenerationStatusEvent, AssistantGenerationStatusType, VisualizationMessage if settings.LANGFUSE_PUBLIC_KEY: langfuse_handler = CallbackHandler( @@ -39,6 +45,13 @@ def is_message_update( return len(update) == 2 and update[0] == "messages" +def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]: + """ + Update of the state. + """ + return len(update) == 2 and update[0] == "values" + + class Assistant: _team: Team _graph: StateGraph @@ -59,6 +72,10 @@ def _compile_graph(self): generate_trends_node = GenerateTrendsNode(self._team) builder.add_node(GenerateTrendsNode.name, generate_trends_node.run) + generate_trends_tools_node = GenerateTrendsToolsNode(self._team) + builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run) + builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name) + builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name) builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router) builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router) @@ -66,31 +83,45 @@ def _compile_graph(self): return builder.compile() - def stream(self, conversation: Conversation) -> Generator[str, None, None]: + def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]: assistant_graph = self._compile_graph() callbacks = [langfuse_handler] if langfuse_handler else [] messages = [message.root for message in conversation.messages] + chunks = AIMessageChunk(content="") + state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None} + generator = assistant_graph.stream( - {"messages": messages}, + state, config={"recursion_limit": 24, "callbacks": callbacks}, - stream_mode=["messages", "updates"], + stream_mode=["messages", "values", "updates"], ) chunks = AIMessageChunk(content="") # Send a chunk to establish the connection avoiding the worker's timeout. - yield "" + yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK) for update in generator: - if is_value_update(update): + if is_state_update(update): + _, new_state = update + state = new_state + + elif is_value_update(update): _, state_update = update - if ( - AssistantNodeName.GENERATE_TRENDS in state_update - and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS] - ): - message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]) - yield message.model_dump_json() + + if AssistantNodeName.GENERATE_TRENDS in state_update: + # Reset chunks when schema validation fails. + chunks = AIMessageChunk(content="") + + if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]: + message = cast( + VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0] + ) + yield message + elif state_update[AssistantNodeName.GENERATE_TRENDS].get("intermediate_steps", []): + yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) + elif is_message_update(update): langchain_message, langgraph_state = update[1] if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance( @@ -101,4 +132,4 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: if parsed_message: yield VisualizationMessage( reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer - ).model_dump_json() + ) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 6ca6ed8b50d50..d1819b49b705f 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -1,25 +1,23 @@ import itertools -import json import xml.etree.ElementTree as ET from functools import cached_property -from typing import cast +from typing import Optional, cast from langchain.agents.format_scratchpad import format_log_to_str from langchain_core.agents import AgentAction -from langchain_core.exceptions import OutputParserException from langchain_core.messages import AIMessage as LangchainAssistantMessage from langchain_core.messages import BaseMessage, merge_message_runs -from langchain_core.messages import HumanMessage as LangchainHumanMessage -from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI from pydantic import ValidationError from ee.hogai.hardcoded_definitions import hardcoded_prop_defs from ee.hogai.trends.parsers import ( + PydanticOutputParserException, ReActParserException, ReActParserMissingActionException, + parse_generated_trends_output, parse_react_agent_output, ) from ee.hogai.trends.prompts import ( @@ -32,6 +30,8 @@ react_scratchpad_prompt, react_system_prompt, react_user_prompt, + trends_failover_output_prompt, + trends_failover_prompt, trends_group_mapping_prompt, trends_new_plan_prompt, trends_plan_prompt, @@ -43,7 +43,7 @@ TrendsAgentToolkit, TrendsAgentToolModel, ) -from ee.hogai.trends.utils import GenerateTrendOutputModel +from ee.hogai.trends.utils import GenerateTrendOutputModel, filter_trends_conversation from ee.hogai.utils import ( AssistantNode, AssistantNodeName, @@ -53,7 +53,12 @@ from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner from posthog.hogql_queries.query_runner import ExecutionMode from posthog.models.group_type_mapping import GroupTypeMapping -from posthog.schema import CachedTeamTaxonomyQueryResponse, HumanMessage, TeamTaxonomyQuery, VisualizationMessage +from posthog.schema import ( + CachedTeamTaxonomyQueryResponse, + FailureMessage, + TeamTaxonomyQuery, + VisualizationMessage, +) class CreateTrendsPlanNode(AssistantNode): @@ -180,26 +185,33 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: """ Reconstruct the conversation for the agent. On this step we only care about previously asked questions and generated plans. All other messages are filtered out. """ - messages = state.get("messages", []) - if len(messages) == 0: + human_messages, visualization_messages = filter_trends_conversation(state.get("messages", [])) + + if not human_messages: return [] - conversation = [ - HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format( - question=messages[0].content if isinstance(messages[0], HumanMessage) else "" - ) - ] + conversation = [] - for message in messages[1:]: - if isinstance(message, HumanMessage): - conversation.append( - HumanMessagePromptTemplate.from_template( - react_follow_up_prompt, - template_format="mustache", - ).format(feedback=message.content) - ) - elif isinstance(message, VisualizationMessage): - conversation.append(LangchainAssistantMessage(content=message.plan or "")) + for idx, messages in enumerate(itertools.zip_longest(human_messages, visualization_messages)): + human_message, viz_message = messages + + if human_message: + if idx == 0: + conversation.append( + HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format( + question=human_message.content + ) + ) + else: + conversation.append( + HumanMessagePromptTemplate.from_template( + react_follow_up_prompt, + template_format="mustache", + ).format(feedback=human_message.content) + ) + + if viz_message: + conversation.append(LangchainAssistantMessage(content=viz_message.plan or "")) return conversation @@ -262,30 +274,38 @@ class GenerateTrendsNode(AssistantNode): def run(self, state: AssistantState, config: RunnableConfig): generated_plan = state.get("plan", "") + intermediate_steps = state.get("intermediate_steps") or [] + validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None trends_generation_prompt = ChatPromptTemplate.from_messages( [ ("system", trends_system_prompt), ], template_format="mustache", - ) + self._reconstruct_conversation(state) + ) + self._reconstruct_conversation(state, validation_error_message=validation_error_message) merger = merge_message_runs() - chain = ( - trends_generation_prompt - | merger - | self._model - # Result from structured output is a parsed dict. Convert to a string since the output parser expects it. - | RunnableLambda(lambda x: json.dumps(x)) - # Validate a string input. - | PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel) - ) + chain = trends_generation_prompt | merger | self._model | parse_generated_trends_output try: message: GenerateTrendOutputModel = chain.invoke({}, config) - except OutputParserException: + except PydanticOutputParserException as e: + # Generation step is expensive. After a second unsuccessful attempt, it's better to send a failure message. + if len(intermediate_steps) >= 2: + return { + "messages": [ + FailureMessage( + content="Oops! It looks like I’m having trouble generating this trends insight. Could you please try again?" + ) + ], + "intermediate_steps": None, + } + return { - "messages": [VisualizationMessage(plan=generated_plan, reasoning_steps=["Schema validation failed"])] + "intermediate_steps": [ + *intermediate_steps, + (AgentAction("handle_incorrect_response", e.llm_output, e.validation_message), None), + ], } return { @@ -295,11 +315,12 @@ def run(self, state: AssistantState, config: RunnableConfig): reasoning_steps=message.reasoning_steps, answer=message.answer, ) - ] + ], + "intermediate_steps": None, } def router(self, state: AssistantState): - if state.get("tool_argument") is not None: + if state.get("intermediate_steps") is not None: return AssistantNodeName.GENERATE_TRENDS_TOOLS return AssistantNodeName.END @@ -323,7 +344,9 @@ def _group_mapping_prompt(self) -> str: ) return ET.tostring(root, encoding="unicode") - def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: + def _reconstruct_conversation( + self, state: AssistantState, validation_error_message: Optional[str] = None + ) -> list[BaseMessage]: """ Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history. """ @@ -339,22 +362,7 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: ) ] - stack: list[LangchainHumanMessage] = [] - human_messages: list[LangchainHumanMessage] = [] - visualization_messages: list[VisualizationMessage] = [] - - for message in messages: - if isinstance(message, HumanMessage): - stack.append(LangchainHumanMessage(content=message.content)) - elif isinstance(message, VisualizationMessage) and message.answer: - if stack: - human_messages += merge_message_runs(stack) - stack = [] - visualization_messages.append(message) - - if stack: - human_messages += merge_message_runs(stack) - + human_messages, visualization_messages = filter_trends_conversation(messages) first_ai_message = True for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages): @@ -386,6 +394,13 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "") ) + if validation_error_message: + conversation.append( + HumanMessagePromptTemplate.from_template(trends_failover_prompt, template_format="mustache").format( + validation_error_message=validation_error_message + ) + ) + return conversation @classmethod @@ -404,4 +419,20 @@ class GenerateTrendsToolsNode(AssistantNode): name = AssistantNodeName.GENERATE_TRENDS_TOOLS def run(self, state: AssistantState, config: RunnableConfig): - return state + intermediate_steps = state.get("intermediate_steps", []) + if not intermediate_steps: + return state + + action, _ = intermediate_steps[-1] + prompt = ( + ChatPromptTemplate.from_template(trends_failover_output_prompt, template_format="mustache") + .format_messages(output=action.tool_input, exception_message=action.log)[0] + .content + ) + + return { + "intermediate_steps": [ + *intermediate_steps[:-1], + (action, prompt), + ] + } diff --git a/ee/hogai/trends/parsers.py b/ee/hogai/trends/parsers.py index 2d38cab8b2251..e66f974576939 100644 --- a/ee/hogai/trends/parsers.py +++ b/ee/hogai/trends/parsers.py @@ -3,6 +3,9 @@ from langchain_core.agents import AgentAction from langchain_core.messages import AIMessage as LangchainAIMessage +from pydantic import ValidationError + +from ee.hogai.trends.utils import GenerateTrendOutputModel class ReActParserException(ValueError): @@ -56,3 +59,22 @@ def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction: # JSON does not contain an action. raise ReActParserMalformedJsonException(text) return AgentAction(response["action"], response.get("action_input", {}), text) + + +class PydanticOutputParserException(ValueError): + llm_output: str + """Serialized LLM output.""" + validation_message: str + """Pydantic validation error message.""" + + def __init__(self, llm_output: str, validation_message: str): + super().__init__(llm_output) + self.llm_output = llm_output + self.validation_message = validation_message + + +def parse_generated_trends_output(output: dict) -> GenerateTrendOutputModel: + try: + return GenerateTrendOutputModel.model_validate(output) + except ValidationError as e: + raise PydanticOutputParserException(llm_output=json.dumps(output), validation_message=e.json(include_url=False)) diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index 38a83ba12bb9b..2543b1efc26e0 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -292,3 +292,23 @@ trends_question_prompt = """ Answer to this question: {{question}} """ + +trends_failover_output_prompt = """ +Generation output: +``` +{{output}} +``` + +Exception message: +``` +{{exception_message}} +``` +""" + +trends_failover_prompt = """ +The result of the previous generation raised the Pydantic validation exception. + +{{validation_error_message}} + +Fix the error and return the correct response. +""" diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 1371213cfeab8..1e89c45458a9a 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -1,12 +1,26 @@ +import json from unittest.mock import patch from django.test import override_settings +from langchain_core.agents import AgentAction from langchain_core.messages import AIMessage as LangchainAIMessage from langchain_core.runnables import RunnableLambda -from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode -from ee.hogai.trends.parsers import AgentAction -from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage +from ee.hogai.trends.nodes import ( + CreateTrendsPlanNode, + CreateTrendsPlanToolsNode, + GenerateTrendsNode, + GenerateTrendsToolsNode, +) +from ee.hogai.trends.utils import GenerateTrendOutputModel +from ee.hogai.utils import AssistantNodeName +from posthog.schema import ( + AssistantMessage, + ExperimentalAITrendsQuery, + FailureMessage, + HumanMessage, + VisualizationMessage, +) from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @@ -73,6 +87,22 @@ def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) + def test_agent_reconstructs_conversation_with_failures(self): + node = CreateTrendsPlanNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + ] + } + ) + self.assertEqual(len(history), 1) + self.assertEqual(history[0].type, "human") + self.assertIn("Text", history[0].content) + self.assertNotIn("{{question}}", history[0].content) + def test_agent_filters_out_low_count_events(self): _create_person(distinct_ids=["test"], team=self.team) for i in range(26): @@ -165,6 +195,27 @@ class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest): def setUp(self): self.schema = ExperimentalAITrendsQuery(series=[]) + def test_node_runs(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + generator_model_mock.return_value = RunnableLambda( + lambda _: GenerateTrendOutputModel(reasoning_steps=["step"], answer=self.schema).model_dump() + ) + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "plan": "Plan", + }, + {}, + ) + self.assertEqual( + new_state, + { + "messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])], + "intermediate_steps": None, + }, + ) + def test_agent_reconstructs_conversation(self): node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]}) @@ -276,3 +327,145 @@ def test_agent_reconstructs_conversation_and_merges_messages(self): self.assertIn("Answer to this question:", history[5].content) self.assertNotIn("{{question}}", history[5].content) self.assertIn("Follow\nUp", history[5].content) + + def test_failover_with_incorrect_schema(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() + # Emulate an incorrect JSON. It should be an object. + schema["answer"] = [] + generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) + + new_state = node.run({"messages": [HumanMessage(content="Text")]}, {}) + self.assertIn("intermediate_steps", new_state) + self.assertEqual(len(new_state["intermediate_steps"]), 1) + + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], + }, + {}, + ) + self.assertIn("intermediate_steps", new_state) + self.assertEqual(len(new_state["intermediate_steps"]), 2) + + def test_node_leaves_failover(self): + node = GenerateTrendsNode(self.team) + with patch( + "ee.hogai.trends.nodes.GenerateTrendsNode._model", + return_value=RunnableLambda( + lambda _: GenerateTrendOutputModel(reasoning_steps=[], answer=self.schema).model_dump() + ), + ): + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [ + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + ], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + + def test_node_leaves_failover_after_second_unsuccessful_attempt(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() + # Emulate an incorrect JSON. It should be an object. + schema["answer"] = [] + generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) + + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [ + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + ], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + self.assertEqual(len(new_state["messages"]), 1) + self.assertIsInstance(new_state["messages"][0], FailureMessage) + + def test_agent_reconstructs_conversation_with_failover(self): + action = AgentAction(tool="fix", tool_input="validation error", log="exception") + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [HumanMessage(content="Text")], + "plan": "randomplan", + "intermediate_steps": [(action, "uniqexception")], + }, + "uniqexception", + ) + self.assertEqual(len(history), 4) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + self.assertEqual(history[3].type, "human") + self.assertIn("Pydantic", history[3].content) + self.assertIn("uniqexception", history[3].content) + + def test_agent_reconstructs_conversation_with_failed_messages(self): + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + ], + "plan": "randomplan", + }, + ) + self.assertEqual(len(history), 3) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + + def test_router(self): + node = GenerateTrendsNode(self.team) + state = node.router({"messages": [], "intermediate_steps": None}) + self.assertEqual(state, AssistantNodeName.END) + state = node.router( + {"messages": [], "intermediate_steps": [(AgentAction(tool="", tool_input="", log=""), None)]} + ) + self.assertEqual(state, AssistantNodeName.GENERATE_TRENDS_TOOLS) + + +class TestGenerateTrendsToolsNode(ClickhouseTestMixin, APIBaseTest): + def test_tools_node(self): + node = GenerateTrendsToolsNode(self.team) + action = AgentAction(tool="fix", tool_input="validationerror", log="pydanticexception") + state = node.run({"messages": [], "intermediate_steps": [(action, None)]}, {}) + self.assertIsNotNone("validationerror", state["intermediate_steps"][0][1]) + self.assertIn("validationerror", state["intermediate_steps"][0][1]) + self.assertIn("pydanticexception", state["intermediate_steps"][0][1]) diff --git a/ee/hogai/trends/test/test_utils.py b/ee/hogai/trends/test/test_utils.py new file mode 100644 index 0000000000000..de9b8733129ec --- /dev/null +++ b/ee/hogai/trends/test/test_utils.py @@ -0,0 +1,39 @@ +from langchain_core.messages import HumanMessage as LangchainHumanMessage + +from ee.hogai.trends import utils +from posthog.schema import ExperimentalAITrendsQuery, FailureMessage, HumanMessage, VisualizationMessage +from posthog.test.base import BaseTest + + +class TestTrendsUtils(BaseTest): + def test_merge_human_messages(self): + res = utils.merge_human_messages( + [ + LangchainHumanMessage(content="Text"), + LangchainHumanMessage(content="Text"), + LangchainHumanMessage(content="Te"), + LangchainHumanMessage(content="xt"), + ] + ) + self.assertEqual(len(res), 1) + self.assertEqual(res, [LangchainHumanMessage(content="Text\nTe\nxt")]) + + def test_filter_trends_conversation(self): + human_messages, visualization_messages = utils.filter_trends_conversation( + [ + HumanMessage(content="Text"), + FailureMessage(content="Error"), + HumanMessage(content="Text"), + VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan"), + HumanMessage(content="Text2"), + VisualizationMessage(answer=None, plan="plan"), + ] + ) + self.assertEqual(len(human_messages), 2) + self.assertEqual(len(visualization_messages), 1) + self.assertEqual( + human_messages, [LangchainHumanMessage(content="Text"), LangchainHumanMessage(content="Text2")] + ) + self.assertEqual( + visualization_messages, [VisualizationMessage(answer=ExperimentalAITrendsQuery(series=[]), plan="plan")] + ) diff --git a/ee/hogai/trends/utils.py b/ee/hogai/trends/utils.py index 080f85f0256d0..5e1a8052707c8 100644 --- a/ee/hogai/trends/utils.py +++ b/ee/hogai/trends/utils.py @@ -1,10 +1,53 @@ +from collections.abc import Sequence from typing import Optional +from langchain_core.messages import HumanMessage as LangchainHumanMessage +from langchain_core.messages import merge_message_runs from pydantic import BaseModel -from posthog.schema import ExperimentalAITrendsQuery +from ee.hogai.utils import AssistantMessageUnion +from posthog.schema import ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage class GenerateTrendOutputModel(BaseModel): reasoning_steps: Optional[list[str]] = None answer: Optional[ExperimentalAITrendsQuery] = None + + +def merge_human_messages(messages: list[LangchainHumanMessage]) -> list[LangchainHumanMessage]: + """ + Filters out duplicated human messages and merges them into one message. + """ + contents = set() + filtered_messages = [] + for message in messages: + if message.content in contents: + continue + contents.add(message.content) + filtered_messages.append(message) + return merge_message_runs(filtered_messages) + + +def filter_trends_conversation( + messages: Sequence[AssistantMessageUnion], +) -> tuple[list[LangchainHumanMessage], list[VisualizationMessage]]: + """ + Splits, filters and merges the message history to be consumable by agents. + """ + stack: list[LangchainHumanMessage] = [] + human_messages: list[LangchainHumanMessage] = [] + visualization_messages: list[VisualizationMessage] = [] + + for message in messages: + if isinstance(message, HumanMessage): + stack.append(LangchainHumanMessage(content=message.content)) + elif isinstance(message, VisualizationMessage) and message.answer: + if stack: + human_messages += merge_human_messages(stack) + stack = [] + visualization_messages.append(message) + + if stack: + human_messages += merge_human_messages(stack) + + return human_messages, visualization_messages diff --git a/ee/hogai/utils.py b/ee/hogai/utils.py index 65de9303b3ffc..70d1ea969621a 100644 --- a/ee/hogai/utils.py +++ b/ee/hogai/utils.py @@ -10,9 +10,9 @@ from pydantic import BaseModel, Field from posthog.models.team.team import Team -from posthog.schema import AssistantMessage, HumanMessage, RootAssistantMessage, VisualizationMessage +from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, RootAssistantMessage, VisualizationMessage -AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage] +AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage, FailureMessage] class Conversation(BaseModel): @@ -24,7 +24,6 @@ class AssistantState(TypedDict): messages: Annotated[Sequence[AssistantMessageUnion], operator.add] intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]] plan: Optional[str] - tool_argument: Optional[str] class AssistantNodeName(StrEnum): diff --git a/ee/models/__init__.py b/ee/models/__init__.py index fd87f76bd54eb..ff5a8fe2dff40 100644 --- a/ee/models/__init__.py +++ b/ee/models/__init__.py @@ -5,7 +5,7 @@ from .hook import Hook from .license import License from .property_definition import EnterprisePropertyDefinition -from .role import Role, RoleMembership +from .rbac.role import Role, RoleMembership __all__ = [ "EnterpriseEventDefinition", diff --git a/ee/models/organization_resource_access.py b/ee/models/rbac/organization_resource_access.py similarity index 95% rename from ee/models/organization_resource_access.py rename to ee/models/rbac/organization_resource_access.py index 924b3e9db2855..de4c86d95a8bc 100644 --- a/ee/models/organization_resource_access.py +++ b/ee/models/rbac/organization_resource_access.py @@ -2,6 +2,8 @@ from posthog.models.organization import Organization +# NOTE: This will be deprecated in favour of the AccessControl model + class OrganizationResourceAccess(models.Model): class AccessLevel(models.IntegerChoices): diff --git a/ee/models/role.py b/ee/models/rbac/role.py similarity index 95% rename from ee/models/role.py rename to ee/models/rbac/role.py index f37170818dbc3..97201835adb1a 100644 --- a/ee/models/role.py +++ b/ee/models/rbac/role.py @@ -1,6 +1,6 @@ from django.db import models -from ee.models.organization_resource_access import OrganizationResourceAccess +from ee.models.rbac.organization_resource_access import OrganizationResourceAccess from posthog.models.utils import UUIDModel diff --git a/ee/urls.py b/ee/urls.py index f0cf168acffb0..7c722bc31852f 100644 --- a/ee/urls.py +++ b/ee/urls.py @@ -6,6 +6,7 @@ from django.urls.conf import path from ee.api import integration +from .api.rbac import organization_resource_access, role from .api import ( authentication, @@ -15,8 +16,6 @@ feature_flag_role_access, hooks, license, - organization_resource_access, - role, sentry_stats, subscription, ) @@ -49,6 +48,7 @@ def extend_api_router() -> None: "organization_role_memberships", ["organization_id", "role_id"], ) + # Start: routes to be deprecated project_feature_flags_router.register( r"role_access", feature_flag_role_access.FeatureFlagRoleAccessViewSet, @@ -61,6 +61,7 @@ def extend_api_router() -> None: "organization_resource_access", ["organization_id"], ) + # End: routes to be deprecated register_grandfathered_environment_nested_viewset(r"hooks", hooks.HookViewSet, "environment_hooks", ["team_id"]) register_grandfathered_environment_nested_viewset( r"explicit_members", diff --git a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png index d6fe88cab16e5..eedf3e01916a0 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png index 50e2dea81e948..c19e174167b7c 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png and b/frontend/__snapshots__/scenes-app-max-ai--empty-thread-loading--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png new file mode 100644 index 0000000000000..84a2a7828cdd5 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png new file mode 100644 index 0000000000000..43ad8593f41e8 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--generation-failure-thread--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png b/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png index d0f525ffb382c..80aded4c79433 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png and b/frontend/__snapshots__/scenes-app-max-ai--thread--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread--light.png b/frontend/__snapshots__/scenes-app-max-ai--thread--light.png index f8d64397cb918..22c88333171fd 100644 Binary files a/frontend/__snapshots__/scenes-app-max-ai--thread--light.png and b/frontend/__snapshots__/scenes-app-max-ai--thread--light.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--dark.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--dark.png new file mode 100644 index 0000000000000..eccd3471a1f51 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--dark.png differ diff --git a/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--light.png b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--light.png new file mode 100644 index 0000000000000..da0efa6c70d73 Binary files /dev/null and b/frontend/__snapshots__/scenes-app-max-ai--thread-with-failed-generation--light.png differ diff --git a/frontend/__snapshots__/scenes-other-onboarding--onboarding-billing--dark.png b/frontend/__snapshots__/scenes-other-onboarding--onboarding-billing--dark.png index a11674466fae0..032e85e82d703 100644 Binary files a/frontend/__snapshots__/scenes-other-onboarding--onboarding-billing--dark.png and b/frontend/__snapshots__/scenes-other-onboarding--onboarding-billing--dark.png differ diff --git a/frontend/src/lib/constants.tsx b/frontend/src/lib/constants.tsx index 3d563ed32d7da..a13f3277ec41f 100644 --- a/frontend/src/lib/constants.tsx +++ b/frontend/src/lib/constants.tsx @@ -220,6 +220,8 @@ export const FEATURE_FLAGS = { LEGACY_ACTION_WEBHOOKS: 'legacy-action-webhooks', // owner: @mariusandra #team-cdp SESSION_REPLAY_URL_TRIGGER: 'session-replay-url-trigger', // owner: @richard-better #team-replay REPLAY_TEMPLATES: 'replay-templates', // owner: @raquelmsmith #team-replay + EXPERIMENTS_HOGQL: 'experiments-hogql', // owner: @jurajmajerik #team-experiments + ROLE_BASED_ACCESS_CONTROL: 'role-based-access-control', // owner: @zach EXPERIMENTS_HOLDOUTS: 'experiments-holdouts', // owner: @jurajmajerik #team-experiments MESSAGING: 'messaging', // owner @mariusandra #team-cdp SESSION_REPLAY_URL_BLOCKLIST: 'session-replay-url-blocklist', // owner: @richard-better #team-replay diff --git a/frontend/src/loadPostHogJS.tsx b/frontend/src/loadPostHogJS.tsx index 2c5a3285ef509..badabf1105246 100644 --- a/frontend/src/loadPostHogJS.tsx +++ b/frontend/src/loadPostHogJS.tsx @@ -29,16 +29,38 @@ export function loadPostHogJS(): void { bootstrap: window.POSTHOG_USER_IDENTITY_WITH_FLAGS ? window.POSTHOG_USER_IDENTITY_WITH_FLAGS : {}, opt_in_site_apps: true, api_transport: 'fetch', - loaded: (posthog) => { - if (posthog.sessionRecording) { - posthog.sessionRecording._forceAllowLocalhostNetworkCapture = true + loaded: (loadedInstance) => { + if (loadedInstance.sessionRecording) { + loadedInstance.sessionRecording._forceAllowLocalhostNetworkCapture = true } if (window.IMPERSONATED_SESSION) { - posthog.opt_out_capturing() + loadedInstance.sessionManager?.resetSessionId() + loadedInstance.opt_out_capturing() } else { - posthog.opt_in_capturing() + loadedInstance.opt_in_capturing() } + + const Cypress = (window as any).Cypress + + if (Cypress) { + Object.entries(Cypress.env()).forEach(([key, value]) => { + if (key.startsWith('POSTHOG_PROPERTY_')) { + loadedInstance.register_for_session({ + [key.replace('POSTHOG_PROPERTY_', 'E2E_TESTING_').toLowerCase()]: value, + }) + } + }) + } + + // This is a helpful flag to set to automatically reset the recording session on load for testing multiple recordings + const shouldResetSessionOnLoad = loadedInstance.getFeatureFlag(FEATURE_FLAGS.SESSION_RESET_ON_LOAD) + if (shouldResetSessionOnLoad) { + loadedInstance.sessionManager?.resetSessionId() + } + + // Make sure we have access to the object in window for debugging + window.posthog = loadedInstance }, scroll_root_selector: ['main', 'html'], autocapture: { @@ -52,26 +74,6 @@ export function loadPostHogJS(): void { : undefined, }) ) - - const Cypress = (window as any).Cypress - - if (Cypress) { - Object.entries(Cypress.env()).forEach(([key, value]) => { - if (key.startsWith('POSTHOG_PROPERTY_')) { - posthog.register_for_session({ - [key.replace('POSTHOG_PROPERTY_', 'E2E_TESTING_').toLowerCase()]: value, - }) - } - }) - } - - // This is a helpful flag to set to automatically reset the recording session on load for testing multiple recordings - const shouldResetSessionOnLoad = posthog.getFeatureFlag(FEATURE_FLAGS.SESSION_RESET_ON_LOAD) - if (shouldResetSessionOnLoad) { - posthog.sessionManager?.resetSessionId() - } - // Make sure we have access to the object in window for debugging - window.posthog = posthog } else { posthog.init('fake token', { autocapture: false, diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 75b26c88bc4de..5c28b52cf14e1 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -617,6 +617,24 @@ } ] }, + "AssistantEventType": { + "enum": ["status", "message"], + "type": "string" + }, + "AssistantGenerationStatusEvent": { + "additionalProperties": false, + "properties": { + "type": { + "$ref": "#/definitions/AssistantGenerationStatusType" + } + }, + "required": ["type"], + "type": "object" + }, + "AssistantGenerationStatusType": { + "enum": ["ack", "generation_error"], + "type": "string" + }, "AssistantMessage": { "additionalProperties": false, "properties": { @@ -632,7 +650,7 @@ "type": "object" }, "AssistantMessageType": { - "enum": ["human", "ai", "ai/viz"], + "enum": ["human", "ai", "ai/viz", "ai/failure"], "type": "string" }, "AutocompleteCompletionItem": { @@ -1315,12 +1333,25 @@ "expected_loss": { "type": "number" }, + "funnels_query": { + "$ref": "#/definitions/FunnelsQuery" + }, "insight": { - "$ref": "#/definitions/FunnelsQueryResponse" + "items": { + "items": { + "type": "object" + }, + "type": "array" + }, + "type": "array" }, "is_cached": { "type": "boolean" }, + "kind": { + "const": "ExperimentFunnelsQuery", + "type": "string" + }, "last_refresh": { "format": "date-time", "type": "string" @@ -1361,6 +1392,7 @@ "expected_loss", "insight", "is_cached", + "kind", "last_refresh", "next_allowed_client_refresh", "probability", @@ -1385,6 +1417,9 @@ "description": "What triggered the calculation of the query, leave empty if user/immediate", "type": "string" }, + "count_query": { + "$ref": "#/definitions/TrendsQuery" + }, "credible_intervals": { "additionalProperties": { "items": { @@ -1396,12 +1431,22 @@ }, "type": "object" }, + "exposure_query": { + "$ref": "#/definitions/TrendsQuery" + }, "insight": { - "$ref": "#/definitions/TrendsQueryResponse" + "items": { + "type": "object" + }, + "type": "array" }, "is_cached": { "type": "boolean" }, + "kind": { + "const": "ExperimentTrendsQuery", + "type": "string" + }, "last_refresh": { "format": "date-time", "type": "string" @@ -1444,6 +1489,7 @@ "credible_intervals", "insight", "is_cached", + "kind", "last_refresh", "next_allowed_client_refresh", "p_value", @@ -3732,8 +3778,21 @@ "expected_loss": { "type": "number" }, + "funnels_query": { + "$ref": "#/definitions/FunnelsQuery" + }, "insight": { - "$ref": "#/definitions/FunnelsQueryResponse" + "items": { + "items": { + "type": "object" + }, + "type": "array" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentFunnelsQuery", + "type": "string" }, "probability": { "additionalProperties": { @@ -3758,6 +3817,7 @@ "credible_intervals", "expected_loss", "insight", + "kind", "probability", "significance_code", "significant", @@ -3768,6 +3828,9 @@ { "additionalProperties": false, "properties": { + "count_query": { + "$ref": "#/definitions/TrendsQuery" + }, "credible_intervals": { "additionalProperties": { "items": { @@ -3779,8 +3842,18 @@ }, "type": "object" }, + "exposure_query": { + "$ref": "#/definitions/TrendsQuery" + }, "insight": { - "$ref": "#/definitions/TrendsQueryResponse" + "items": { + "type": "object" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentTrendsQuery", + "type": "string" }, "p_value": { "type": "number" @@ -3807,6 +3880,7 @@ "required": [ "credible_intervals", "insight", + "kind", "p_value", "probability", "significance_code", @@ -5215,8 +5289,21 @@ "expected_loss": { "type": "number" }, + "funnels_query": { + "$ref": "#/definitions/FunnelsQuery" + }, "insight": { - "$ref": "#/definitions/FunnelsQueryResponse" + "items": { + "items": { + "type": "object" + }, + "type": "array" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentFunnelsQuery", + "type": "string" }, "probability": { "additionalProperties": { @@ -5238,6 +5325,7 @@ } }, "required": [ + "kind", "insight", "variants", "probability", @@ -5282,6 +5370,9 @@ "ExperimentTrendsQueryResponse": { "additionalProperties": false, "properties": { + "count_query": { + "$ref": "#/definitions/TrendsQuery" + }, "credible_intervals": { "additionalProperties": { "items": { @@ -5293,8 +5384,18 @@ }, "type": "object" }, + "exposure_query": { + "$ref": "#/definitions/TrendsQuery" + }, "insight": { - "$ref": "#/definitions/TrendsQueryResponse" + "items": { + "type": "object" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentTrendsQuery", + "type": "string" }, "p_value": { "type": "number" @@ -5319,6 +5420,7 @@ } }, "required": [ + "kind", "insight", "variants", "probability", @@ -5450,6 +5552,20 @@ "required": ["kind", "series"], "type": "object" }, + "FailureMessage": { + "additionalProperties": false, + "properties": { + "content": { + "type": "string" + }, + "type": { + "const": "ai/failure", + "type": "string" + } + }, + "required": ["type"], + "type": "object" + }, "FeaturePropertyFilter": { "additionalProperties": false, "properties": { @@ -8962,8 +9078,21 @@ "expected_loss": { "type": "number" }, + "funnels_query": { + "$ref": "#/definitions/FunnelsQuery" + }, "insight": { - "$ref": "#/definitions/FunnelsQueryResponse" + "items": { + "items": { + "type": "object" + }, + "type": "array" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentFunnelsQuery", + "type": "string" }, "probability": { "additionalProperties": { @@ -8985,6 +9114,7 @@ } }, "required": [ + "kind", "insight", "variants", "probability", @@ -8998,6 +9128,9 @@ { "additionalProperties": false, "properties": { + "count_query": { + "$ref": "#/definitions/TrendsQuery" + }, "credible_intervals": { "additionalProperties": { "items": { @@ -9009,8 +9142,18 @@ }, "type": "object" }, + "exposure_query": { + "$ref": "#/definitions/TrendsQuery" + }, "insight": { - "$ref": "#/definitions/TrendsQueryResponse" + "items": { + "type": "object" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentTrendsQuery", + "type": "string" }, "p_value": { "type": "number" @@ -9035,6 +9178,7 @@ } }, "required": [ + "kind", "insight", "variants", "probability", @@ -9607,8 +9751,21 @@ "expected_loss": { "type": "number" }, + "funnels_query": { + "$ref": "#/definitions/FunnelsQuery" + }, "insight": { - "$ref": "#/definitions/FunnelsQueryResponse" + "items": { + "items": { + "type": "object" + }, + "type": "array" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentFunnelsQuery", + "type": "string" }, "probability": { "additionalProperties": { @@ -9633,6 +9790,7 @@ "credible_intervals", "expected_loss", "insight", + "kind", "probability", "significance_code", "significant", @@ -9643,6 +9801,9 @@ { "additionalProperties": false, "properties": { + "count_query": { + "$ref": "#/definitions/TrendsQuery" + }, "credible_intervals": { "additionalProperties": { "items": { @@ -9654,8 +9815,18 @@ }, "type": "object" }, + "exposure_query": { + "$ref": "#/definitions/TrendsQuery" + }, "insight": { - "$ref": "#/definitions/TrendsQueryResponse" + "items": { + "type": "object" + }, + "type": "array" + }, + "kind": { + "const": "ExperimentTrendsQuery", + "type": "string" }, "p_value": { "type": "number" @@ -9682,6 +9853,7 @@ "required": [ "credible_intervals", "insight", + "kind", "p_value", "probability", "significance_code", @@ -10735,6 +10907,9 @@ }, { "$ref": "#/definitions/HumanMessage" + }, + { + "$ref": "#/definitions/FailureMessage" } ] }, diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index 273605a42f6d7..993b393eb7fd3 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -1621,7 +1621,10 @@ export enum ExperimentSignificanceCode { } export interface ExperimentTrendsQueryResponse { - insight: TrendsQueryResponse + kind: NodeKind.ExperimentTrendsQuery + insight: Record[] + count_query?: TrendsQuery + exposure_query?: TrendsQuery variants: ExperimentVariantTrendsBaseStats[] probability: Record significant: boolean @@ -1633,7 +1636,9 @@ export interface ExperimentTrendsQueryResponse { export type CachedExperimentTrendsQueryResponse = CachedQueryResponse export interface ExperimentFunnelsQueryResponse { - insight: FunnelsQueryResponse + kind: NodeKind.ExperimentFunnelsQuery + insight: Record[][] + funnels_query?: FunnelsQuery variants: ExperimentVariantFunnelsBaseStats[] probability: Record significant: boolean @@ -2096,6 +2101,7 @@ export enum AssistantMessageType { Human = 'human', Assistant = 'ai', Visualization = 'ai/viz', + Failure = 'ai/failure', } export interface HumanMessage { @@ -2115,4 +2121,23 @@ export interface VisualizationMessage { answer?: ExperimentalAITrendsQuery } -export type RootAssistantMessage = VisualizationMessage | AssistantMessage | HumanMessage +export interface FailureMessage { + type: AssistantMessageType.Failure + content?: string +} + +export type RootAssistantMessage = VisualizationMessage | AssistantMessage | HumanMessage | FailureMessage + +export enum AssistantEventType { + Status = 'status', + Message = 'message', +} + +export enum AssistantGenerationStatusType { + Acknowledged = 'ack', + GenerationError = 'generation_error', +} + +export interface AssistantGenerationStatusEvent { + type: AssistantGenerationStatusType +} diff --git a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx index 2d8bba2f256bc..95a525987b8a2 100644 --- a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx +++ b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx @@ -6,6 +6,7 @@ import api from 'lib/api' import posthog from 'posthog-js' import { preflightLogic } from 'scenes/PreflightCheck/preflightLogic' import { Scene } from 'scenes/sceneTypes' +import { teamLogic } from 'scenes/teamLogic' import { urls } from 'scenes/urls' import { @@ -16,6 +17,7 @@ import { manualLinkSources, ManualLinkSourceType, PipelineTab, + ProductKey, SourceConfig, SourceFieldConfig, } from '~/types' @@ -731,6 +733,8 @@ export const sourceWizardLogic = kea([ ['resetTable', 'createTableSuccess'], dataWarehouseSettingsLogic, ['loadSources'], + teamLogic, + ['addProductIntent'], ], }), reducers({ @@ -1129,6 +1133,9 @@ export const sourceWizardLogic = kea([ setManualLinkingProvider: () => { actions.onNext() }, + selectConnector: () => { + actions.addProductIntent({ product_type: ProductKey.DATA_WAREHOUSE, intent_context: 'selected connector' }) + }, })), urlToAction(({ actions }) => ({ '/data-warehouse/:kind/redirect': ({ kind = '' }, searchParams) => { diff --git a/frontend/src/scenes/experiments/Experiment.stories.tsx b/frontend/src/scenes/experiments/Experiment.stories.tsx index daab995ff3aa1..8d2aecd75ab1e 100644 --- a/frontend/src/scenes/experiments/Experiment.stories.tsx +++ b/frontend/src/scenes/experiments/Experiment.stories.tsx @@ -116,6 +116,7 @@ const MOCK_FUNNEL_EXPERIMENT: Experiment = { interval: 'day', filter_test_accounts: true, }, + metrics: [], archived: false, created_by: { id: 1, @@ -172,6 +173,7 @@ const MOCK_TREND_EXPERIMENT: Experiment = { }, }, }, + metrics: [], parameters: { feature_flag_variants: [ { @@ -277,6 +279,7 @@ const MOCK_TREND_EXPERIMENT_MANY_VARIANTS: Experiment = { }, }, }, + metrics: [], parameters: { feature_flag_variants: [ { diff --git a/frontend/src/scenes/experiments/ExperimentView/Goal.tsx b/frontend/src/scenes/experiments/ExperimentView/Goal.tsx index b0ef5701ec5c7..c68acb47f9df4 100644 --- a/frontend/src/scenes/experiments/ExperimentView/Goal.tsx +++ b/frontend/src/scenes/experiments/ExperimentView/Goal.tsx @@ -238,17 +238,16 @@ export function Goal(): JSX.Element { Change goal - {experimentInsightType === InsightType.TRENDS && - !experimentMathAggregationForTrends(experiment.filters) && ( - <> - -
-
- -
+ {experimentInsightType === InsightType.TRENDS && !experimentMathAggregationForTrends() && ( + <> + +
+
+
- - )} +
+ + )}
) diff --git a/frontend/src/scenes/experiments/ExperimentView/SecondaryMetricsTable.tsx b/frontend/src/scenes/experiments/ExperimentView/SecondaryMetricsTable.tsx index aadaebee40729..e9c19a72eb589 100644 --- a/frontend/src/scenes/experiments/ExperimentView/SecondaryMetricsTable.tsx +++ b/frontend/src/scenes/experiments/ExperimentView/SecondaryMetricsTable.tsx @@ -157,7 +157,6 @@ export function SecondaryMetricsTable({ experiment.secondary_metrics?.forEach((metric, idx) => { const targetResults = secondaryMetricResults?.[idx] - const targetResultFilters = targetResults?.filters const winningVariant = getHighestProbabilityVariant(targetResults || null) const Header = (): JSX.Element => ( @@ -206,7 +205,7 @@ export function SecondaryMetricsTable({ )} ] - {experimentMathAggregationForTrends(targetResultFilters) ? 'metric' : 'count'} + {experimentMathAggregationForTrends() ? 'metric' : 'count'} ), diff --git a/frontend/src/scenes/experiments/ExperimentView/SummaryTable.tsx b/frontend/src/scenes/experiments/ExperimentView/SummaryTable.tsx index e046d0f3a52fe..1814cb8717795 100644 --- a/frontend/src/scenes/experiments/ExperimentView/SummaryTable.tsx +++ b/frontend/src/scenes/experiments/ExperimentView/SummaryTable.tsx @@ -59,9 +59,7 @@ export function SummaryTable(): JSX.Element { {experimentResults.insight?.[0] && 'action' in experimentResults.insight[0] && ( )} - - {experimentMathAggregationForTrends(experimentResults?.filters) ? 'metric' : 'count'} - + {experimentMathAggregationForTrends() ? 'metric' : 'count'} ), render: function Key(_, variant): JSX.Element { diff --git a/frontend/src/scenes/experiments/ExperimentView/components.tsx b/frontend/src/scenes/experiments/ExperimentView/components.tsx index e61c32505c857..43a7b46e58f74 100644 --- a/frontend/src/scenes/experiments/ExperimentView/components.tsx +++ b/frontend/src/scenes/experiments/ExperimentView/components.tsx @@ -22,6 +22,7 @@ import { FEATURE_FLAGS } from 'lib/constants' import { dayjs } from 'lib/dayjs' import { IconAreaChart } from 'lib/lemon-ui/icons' import { More } from 'lib/lemon-ui/LemonButton/More' +import { featureFlagLogic } from 'lib/logic/featureFlagLogic' import { useEffect, useState } from 'react' import { urls } from 'scenes/urls' @@ -29,7 +30,15 @@ import { groupsModel } from '~/models/groupsModel' import { filtersToQueryNode } from '~/queries/nodes/InsightQuery/utils/filtersToQueryNode' import { queryFromFilters } from '~/queries/nodes/InsightViz/utils' import { Query } from '~/queries/Query/Query' -import { InsightVizNode, NodeKind } from '~/queries/schema' +import { + CachedExperimentFunnelsQueryResponse, + CachedExperimentTrendsQueryResponse, + ExperimentFunnelsQueryResponse, + ExperimentTrendsQueryResponse, + InsightQueryNode, + InsightVizNode, + NodeKind, +} from '~/queries/schema' import { Experiment, Experiment as ExperimentType, @@ -108,10 +117,54 @@ export function ResultsQuery({ targetResults, showTable, }: { - targetResults: ExperimentResults['result'] | null + targetResults: ExperimentResults['result'] | ExperimentTrendsQueryResponse | ExperimentFunnelsQueryResponse | null showTable: boolean }): JSX.Element { - if (!targetResults?.filters) { + const { featureFlags } = useValues(featureFlagLogic) + if (featureFlags[FEATURE_FLAGS.EXPERIMENTS_HOGQL]) { + const newQueryResults = targetResults as unknown as + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + + const query = + newQueryResults.kind === NodeKind.ExperimentTrendsQuery + ? newQueryResults.count_query + : newQueryResults.funnels_query + const fakeInsightId = Math.random().toString(36).substring(2, 15) + + return ( + + ) + } + + const oldQueryResults = targetResults as ExperimentResults['result'] + + if (!oldQueryResults?.filters) { return <> } @@ -119,22 +172,22 @@ export function ResultsQuery({ }: { icon?: JSX.Element }): JSX.Element { - const { experimentResults, experiment } = useValues(experimentLogic) + const { experimentResults, experiment, featureFlags } = useValues(experimentLogic) // keep in sync with https://github.com/PostHog/posthog/blob/master/ee/clickhouse/queries/experiments/funnel_experiment_result.py#L71 // :TRICKY: In the case of no results, we still want users to explore the query, so they can debug further. @@ -160,18 +213,41 @@ export function ExploreButton({ icon = }: { icon?: JSX.Element properties: [], } - const query: InsightVizNode = { - kind: NodeKind.InsightVizNode, - source: filtersToQueryNode( - transformResultFilters( - experimentResults?.filters - ? { ...experimentResults.filters, explicit_date: true } - : filtersFromExperiment - ) - ), - showTable: true, - showLastComputation: true, - showLastComputationRefresh: false, + let query: InsightVizNode + if (featureFlags[FEATURE_FLAGS.EXPERIMENTS_HOGQL]) { + const newQueryResults = experimentResults as unknown as + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + + const source = + newQueryResults.kind === NodeKind.ExperimentTrendsQuery + ? newQueryResults.count_query + : newQueryResults.funnels_query + + query = { + kind: NodeKind.InsightVizNode, + source: source as InsightQueryNode, + } + } else { + const oldQueryResults = experimentResults as ExperimentResults['result'] + + if (!oldQueryResults?.filters) { + return <> + } + + query = { + kind: NodeKind.InsightVizNode, + source: filtersToQueryNode( + transformResultFilters( + oldQueryResults?.filters + ? { ...oldQueryResults.filters, explicit_date: true } + : filtersFromExperiment + ) + ), + showTable: true, + showLastComputation: true, + showLastComputationRefresh: false, + } } return ( diff --git a/frontend/src/scenes/experiments/experimentLogic.tsx b/frontend/src/scenes/experiments/experimentLogic.tsx index 19e33aca83831..4db270269a634 100644 --- a/frontend/src/scenes/experiments/experimentLogic.tsx +++ b/frontend/src/scenes/experiments/experimentLogic.tsx @@ -5,6 +5,7 @@ import { loaders } from 'kea-loaders' import { router, urlToAction } from 'kea-router' import api from 'lib/api' import { EXPERIMENT_DEFAULT_DURATION, FunnelLayout } from 'lib/constants' +import { FEATURE_FLAGS } from 'lib/constants' import { dayjs } from 'lib/dayjs' import { lemonToast } from 'lib/lemon-ui/LemonToast/LemonToast' import { Tooltip } from 'lib/lemon-ui/Tooltip' @@ -27,7 +28,15 @@ import { cohortsModel } from '~/models/cohortsModel' import { groupsModel } from '~/models/groupsModel' import { filtersToQueryNode } from '~/queries/nodes/InsightQuery/utils/filtersToQueryNode' import { queryNodeToFilter } from '~/queries/nodes/InsightQuery/utils/queryNodeToFilter' -import { FunnelsQuery, InsightVizNode, TrendsQuery } from '~/queries/schema' +import { + CachedExperimentFunnelsQueryResponse, + CachedExperimentTrendsQueryResponse, + ExperimentTrendsQuery, + FunnelsQuery, + InsightVizNode, + NodeKind, + TrendsQuery, +} from '~/queries/schema' import { isFunnelsQuery } from '~/queries/utils' import { ActionFilter as ActionFilterType, @@ -62,6 +71,7 @@ const NEW_EXPERIMENT: Experiment = { name: '', feature_flag_key: '', filters: {}, + metrics: [], parameters: { feature_flag_variants: [ { key: 'control', rollout_percentage: 50 }, @@ -767,10 +777,36 @@ export const experimentLogic = kea([ }, }, experimentResults: [ - null as ExperimentResults['result'] | null, + null as + | ExperimentResults['result'] + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + | null, { - loadExperimentResults: async (refresh?: boolean) => { + loadExperimentResults: async ( + refresh?: boolean + ): Promise< + | ExperimentResults['result'] + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + | null + > => { try { + if (values.featureFlags[FEATURE_FLAGS.EXPERIMENTS_HOGQL]) { + const query = values.experiment.metrics[0].query + + const response: ExperimentResults = await api.create( + `api/projects/${values.currentTeamId}/query`, + { query } + ) + + return { + ...response, + fakeInsightId: Math.random().toString(36).substring(2, 15), + last_refresh: response.last_refresh || '', + } as unknown as CachedExperimentTrendsQueryResponse | CachedExperimentFunnelsQueryResponse + } + const refreshParam = refresh ? '?refresh=true' : '' const response: ExperimentResults = await api.get( `api/projects/${values.currentTeamId}/experiments/${values.experimentId}/results${refreshParam}` @@ -862,8 +898,13 @@ export const experimentLogic = kea([ (experimentId): Experiment['id'] => experimentId, ], experimentInsightType: [ - (s) => [s.experiment], - (experiment): InsightType => { + (s) => [s.experiment, s.featureFlags], + (experiment, featureFlags): InsightType => { + if (featureFlags[FEATURE_FLAGS.EXPERIMENTS_HOGQL]) { + const query = experiment?.metrics?.[0]?.query + return query?.kind === NodeKind.ExperimentTrendsQuery ? InsightType.TRENDS : InsightType.FUNNELS + } + return experiment?.filters?.insight || InsightType.FUNNELS }, ], @@ -909,31 +950,40 @@ export const experimentLogic = kea([ }, ], experimentMathAggregationForTrends: [ - () => [], - () => - (filters?: FilterType): PropertyMathType | CountPerActorMathType | undefined => { - // Find out if we're using count per actor math aggregates averages per user - const userMathValue = ( - [...(filters?.events || []), ...(filters?.actions || [])] as ActionFilterType[] - ).filter((entity) => - Object.values(CountPerActorMathType).includes(entity?.math as CountPerActorMathType) - )[0]?.math - - // alternatively, if we're using property math - // remove 'sum' property math from the list of math types - // since we can handle that as a regular case - const targetValues = Object.values(PropertyMathType).filter( - (value) => value !== PropertyMathType.Sum - ) - // sync with the backend at https://github.com/PostHog/posthog/blob/master/ee/clickhouse/queries/experiments/trend_experiment_result.py#L44 - // the function uses_math_aggregation_by_user_or_property_value + (s) => [s.experiment, s.featureFlags], + (experiment, featureFlags) => (): PropertyMathType | CountPerActorMathType | undefined => { + let entities: { math?: string }[] = [] + + if (featureFlags[FEATURE_FLAGS.EXPERIMENTS_HOGQL]) { + const query = experiment?.metrics?.[0]?.query as ExperimentTrendsQuery + if (!query) { + return undefined + } + entities = query.count_query?.series || [] + } else { + const filters = experiment?.filters + if (!filters) { + return undefined + } + entities = [...(filters?.events || []), ...(filters?.actions || [])] as ActionFilterType[] + } - const propertyMathValue = ( - [...(filters?.events || []), ...(filters?.actions || [])] as ActionFilterType[] - ).filter((entity) => targetValues.includes(entity?.math as PropertyMathType))[0]?.math + // Find out if we're using count per actor math aggregates averages per user + const userMathValue = entities.filter((entity) => + Object.values(CountPerActorMathType).includes(entity?.math as CountPerActorMathType) + )[0]?.math - return (userMathValue ?? propertyMathValue) as PropertyMathType | CountPerActorMathType | undefined - }, + // alternatively, if we're using property math + // remove 'sum' property math from the list of math types + // since we can handle that as a regular case + const targetValues = Object.values(PropertyMathType).filter((value) => value !== PropertyMathType.Sum) + + const propertyMathValue = entities.filter((entity) => + targetValues.includes(entity?.math as PropertyMathType) + )[0]?.math + + return (userMathValue ?? propertyMathValue) as PropertyMathType | CountPerActorMathType | undefined + }, ], minimumDetectableEffect: [ (s) => [s.experiment, s.experimentInsightType, s.conversionMetrics, s.trendResults], @@ -1126,7 +1176,14 @@ export const experimentLogic = kea([ conversionRateForVariant: [ () => [], () => - (experimentResults: Partial | null, variantKey: string): number | null => { + ( + experimentResults: + | Partial + | CachedExperimentFunnelsQueryResponse + | CachedExperimentTrendsQueryResponse + | null, + variantKey: string + ): number | null => { if (!experimentResults || !experimentResults.insight) { return null } @@ -1144,34 +1201,47 @@ export const experimentLogic = kea([ }, ], getIndexForVariant: [ - () => [], - () => - (experimentResults: Partial | null, variant: string): number | null => { - // TODO: Would be nice for every secondary metric to have the same colour for variants - const insightType = experimentResults?.filters?.insight - let result: number | null = null + (s) => [s.experimentInsightType], + (experimentInsightType) => + ( + experimentResults: + | Partial + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + | null, + variant: string + ): number | null => { // Ensures we get the right index from results, so the UI can // display the right colour for the variant if (!experimentResults || !experimentResults.insight) { return null } + let index = -1 - if (insightType === InsightType.FUNNELS) { + if (experimentInsightType === InsightType.FUNNELS) { // Funnel Insight is displayed in order of decreasing count - index = ([...experimentResults.insight] as FunnelStep[][]) - .sort((a, b) => b[0]?.count - a[0]?.count) - .findIndex( - (variantFunnel: FunnelStep[]) => variantFunnel[0]?.breakdown_value?.[0] === variant - ) + index = (Array.isArray(experimentResults.insight) ? [...experimentResults.insight] : []) + .sort((a, b) => { + const aCount = (a && Array.isArray(a) && a[0]?.count) || 0 + const bCount = (b && Array.isArray(b) && b[0]?.count) || 0 + return bCount - aCount + }) + .findIndex((variantFunnel) => { + if (!Array.isArray(variantFunnel) || !variantFunnel[0]?.breakdown_value) { + return false + } + const breakdownValue = variantFunnel[0].breakdown_value + return Array.isArray(breakdownValue) && breakdownValue[0] === variant + }) } else { index = (experimentResults.insight as TrendResult[]).findIndex( (variantTrend: TrendResult) => variantTrend.breakdown_value === variant ) } - result = index === -1 ? null : index + const result = index === -1 ? null : index - if (result !== null && insightType === InsightType.FUNNELS) { - result++ + if (result !== null && experimentInsightType === InsightType.FUNNELS) { + return result + 1 } return result }, @@ -1179,10 +1249,15 @@ export const experimentLogic = kea([ countDataForVariant: [ (s) => [s.experimentMathAggregationForTrends], (experimentMathAggregationForTrends) => - (experimentResults: Partial | null, variant: string): number | null => { - const usingMathAggregationType = experimentMathAggregationForTrends( - experimentResults?.filters || {} - ) + ( + experimentResults: + | Partial + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + | null, + variant: string + ): number | null => { + const usingMathAggregationType = experimentMathAggregationForTrends() if (!experimentResults || !experimentResults.insight) { return null } @@ -1223,7 +1298,14 @@ export const experimentLogic = kea([ exposureCountDataForVariant: [ () => [], () => - (experimentResults: Partial | null, variant: string): number | null => { + ( + experimentResults: + | Partial + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + | null, + variant: string + ): number | null => { if (!experimentResults || !experimentResults.variants) { return null } @@ -1241,14 +1323,21 @@ export const experimentLogic = kea([ ], getHighestProbabilityVariant: [ () => [], - () => (results: ExperimentResults['result'] | null) => { - if (results && results.probability) { - const maxValue = Math.max(...Object.values(results.probability)) - return Object.keys(results.probability).find( - (key) => Math.abs(results.probability[key] - maxValue) < Number.EPSILON - ) - } - }, + () => + ( + results: + | ExperimentResults['result'] + | CachedExperimentTrendsQueryResponse + | CachedExperimentFunnelsQueryResponse + | null + ) => { + if (results && results.probability) { + const maxValue = Math.max(...Object.values(results.probability)) + return Object.keys(results.probability).find( + (key) => Math.abs(results.probability[key] - maxValue) < Number.EPSILON + ) + } + }, ], sortedExperimentResultVariants: [ (s) => [s.experimentResults, s.experiment], diff --git a/frontend/src/scenes/max/Max.stories.tsx b/frontend/src/scenes/max/Max.stories.tsx index 27045963c6e42..88ffe32ea7559 100644 --- a/frontend/src/scenes/max/Max.stories.tsx +++ b/frontend/src/scenes/max/Max.stories.tsx @@ -1,10 +1,10 @@ import { Meta, StoryFn } from '@storybook/react' -import { BindLogic, useActions } from 'kea' +import { BindLogic, useActions, useValues } from 'kea' import { useEffect } from 'react' import { mswDecorator, useStorybookMocks } from '~/mocks/browser' -import { chatResponseChunk } from './__mocks__/chatResponse.mocks' +import { chatResponseChunk, failureChunk, generationFailureChunk } from './__mocks__/chatResponse.mocks' import { MaxInstance } from './Max' import { maxLogic } from './maxLogic' @@ -104,3 +104,43 @@ EmptyThreadLoading.parameters = { waitForLoadersToDisappear: false, }, } + +export const GenerationFailureThread: StoryFn = () => { + useStorybookMocks({ + post: { + '/api/environments/:team_id/query/chat/': (_, res, ctx) => res(ctx.text(generationFailureChunk)), + }, + }) + + const sessionId = 'd210b263-8521-4c5b-b3c4-8e0348df574b' + + const { askMax, setMessageStatus } = useActions(maxLogic({ sessionId })) + const { thread, threadLoading } = useValues(maxLogic({ sessionId })) + useEffect(() => { + askMax('What are my most popular pages?') + }, []) + useEffect(() => { + if (thread.length === 2 && !threadLoading) { + setMessageStatus(1, 'error') + } + }, [thread.length, threadLoading]) + + return