From 89a1579326251deac28fd98546253cc521a73975 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Fri, 25 Oct 2024 14:50:20 +0200 Subject: [PATCH 01/22] feat: fail over --- ee/hogai/assistant.py | 14 +++++++++++++- ee/hogai/trends/nodes.py | 33 +++++++++++++++++++++++++++------ ee/hogai/trends/prompts.py | 10 ++++++++++ ee/hogai/utils.py | 1 - 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index d1aa9656257a9..539baedf8bd1d 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -6,7 +6,12 @@ from langgraph.graph.state import StateGraph from ee import settings -from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode +from ee.hogai.trends.nodes import ( + CreateTrendsPlanNode, + CreateTrendsPlanToolsNode, + GenerateTrendsNode, + GenerateTrendsToolsNode, +) from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation from posthog.models.team.team import Team from posthog.schema import VisualizationMessage @@ -59,6 +64,10 @@ def _compile_graph(self): generate_trends_node = GenerateTrendsNode(self._team) builder.add_node(GenerateTrendsNode.name, generate_trends_node.run) + generate_trends_tools_node = GenerateTrendsToolsNode(self._team) + builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run) + builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name) + builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name) builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router) builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router) @@ -99,3 +108,6 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: yield VisualizationMessage( reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer ).model_dump_json() + # elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS: + # # Reset tool output parser when encountered a validation error + # chunks = AIMessageChunk(content="") diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 4727ff07f4f78..f30d52dad62c1 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -2,7 +2,7 @@ import json import xml.etree.ElementTree as ET from functools import cached_property -from typing import Union, cast +from typing import Optional, Union, cast from langchain.agents.format_scratchpad import format_log_to_str from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser @@ -24,6 +24,7 @@ react_scratchpad_prompt, react_system_prompt, react_user_prompt, + trends_failover_prompt, trends_group_mapping_prompt, trends_new_plan_prompt, trends_plan_prompt, @@ -236,13 +237,15 @@ class GenerateTrendsNode(AssistantNode): def run(self, state: AssistantState, config: RunnableConfig): generated_plan = state.get("plan", "") + intermediate_steps = state.get("intermediate_steps", []) + validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None trends_generation_prompt = ChatPromptTemplate.from_messages( [ ("system", trends_system_prompt), ], template_format="mustache", - ) + self._reconstruct_conversation(state) + ) + self._reconstruct_conversation(state, validation_error_message=validation_error_message) merger = merge_message_runs() chain = ( @@ -257,9 +260,14 @@ def run(self, state: AssistantState, config: RunnableConfig): try: message: GenerateTrendOutputModel = chain.invoke({}, config) - except OutputParserException: + except OutputParserException as e: + if e.send_to_llm: + observation = str(e.observation) + else: + observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question." + return { - "messages": [VisualizationMessage(plan=generated_plan, reasoning_steps=["Schema validation failed"])] + "intermediate_steps": [(AgentAction("handle_incorrect_response", observation, str(e)), None)], } return { @@ -297,7 +305,9 @@ def _group_mapping_prompt(self) -> str: ) return ET.tostring(root, encoding="unicode") - def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: + def _reconstruct_conversation( + self, state: AssistantState, validation_error_message: Optional[str] = None + ) -> list[BaseMessage]: """ Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history. """ @@ -360,6 +370,13 @@ def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "") ) + if validation_error_message: + conversation.append( + HumanMessagePromptTemplate.from_template(trends_failover_prompt, template_format="mustache").format( + exception_message=validation_error_message + ) + ) + return conversation @classmethod @@ -378,4 +395,8 @@ class GenerateTrendsToolsNode(AssistantNode): name = AssistantNodeName.GENERATE_TRENDS_TOOLS def run(self, state: AssistantState, config: RunnableConfig): - return state + intermediate_steps = state.get("intermediate_steps", []) + if not intermediate_steps: + return state + action, _ = intermediate_steps[-1] + return {"intermediate_steps": (action, action.log)} diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index c53ae5d3453a5..bb644922e3045 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -269,3 +269,13 @@ trends_question_prompt = """ Answer to this question: {{question}} """ + +trends_failover_prompt = """ +The result of your previous generatin raised the Pydantic validation exception: + +``` +{{exception_message}} +``` + +Fix the error and return the correct response. +""" diff --git a/ee/hogai/utils.py b/ee/hogai/utils.py index 65de9303b3ffc..3f7712b8ac82c 100644 --- a/ee/hogai/utils.py +++ b/ee/hogai/utils.py @@ -24,7 +24,6 @@ class AssistantState(TypedDict): messages: Annotated[Sequence[AssistantMessageUnion], operator.add] intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]] plan: Optional[str] - tool_argument: Optional[str] class AssistantNodeName(StrEnum): From 680ea43616ad237ffc408cd18d4c2923aec913ec Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Wed, 23 Oct 2024 19:06:15 +0200 Subject: [PATCH 02/22] fix: fallback streaming --- ee/hogai/assistant.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index 539baedf8bd1d..bd54f01ed56a8 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -44,6 +44,13 @@ def is_message_update( return len(update) == 2 and update[0] == "messages" +def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]: + """ + Update of the state. + """ + return len(update) == 2 and update[0] == "values" + + class Assistant: _team: Team _graph: StateGraph @@ -80,23 +87,33 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: callbacks = [langfuse_handler] if langfuse_handler else [] messages = [message.root for message in conversation.messages] + chunks = AIMessageChunk(content="") + state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None} + generator = assistant_graph.stream( - {"messages": messages}, + state, config={"recursion_limit": 24, "callbacks": callbacks}, - stream_mode=["messages", "updates"], + stream_mode=["messages", "values", "updates"], ) - chunks = AIMessageChunk(content="") - for update in generator: - if is_value_update(update): + if is_state_update(update): + _, new_state = update + state = new_state + + elif is_value_update(update): _, state_update = update - if ( - AssistantNodeName.GENERATE_TRENDS in state_update - and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS] - ): - message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]) - yield message.model_dump_json() + + if AssistantNodeName.GENERATE_TRENDS in state_update: + # Reset chunks when schema validation fails. + chunks = AIMessageChunk(content="") + + if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]: + message = cast( + VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0] + ) + yield message.model_dump_json() + elif is_message_update(update): langchain_message, langgraph_state = update[1] if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance( @@ -108,6 +125,3 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: yield VisualizationMessage( reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer ).model_dump_json() - # elif state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS_TOOLS: - # # Reset tool output parser when encountered a validation error - # chunks = AIMessageChunk(content="") From c19028b1fcdc1b0cc2d1ff4d3073788089927c4f Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Fri, 25 Oct 2024 15:18:59 +0200 Subject: [PATCH 03/22] test: assistant test --- ee/hogai/test/__init__.py | 0 ee/hogai/test/test_assistant.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 ee/hogai/test/__init__.py create mode 100644 ee/hogai/test/test_assistant.py diff --git a/ee/hogai/test/__init__.py b/ee/hogai/test/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py new file mode 100644 index 0000000000000..b2eea5a217ec5 --- /dev/null +++ b/ee/hogai/test/test_assistant.py @@ -0,0 +1,45 @@ +import json +from unittest.mock import patch + +from django.test import override_settings +from langchain_core.runnables import RunnableLambda + +from ee.hogai.assistant import Assistant +from ee.hogai.trends.utils import GenerateTrendOutputModel +from ee.hogai.utils import Conversation +from posthog.schema import HumanMessage, VisualizationMessage +from posthog.test.base import ( + NonAtomicBaseTest, +) + + +@override_settings(IN_UNIT_TESTING=True) +class TestAssistant(NonAtomicBaseTest): + def test_assistant(self): + mocked_planner_response = """ + Action: + ``` + {"action": "final_answer", "action_input": "Plan"} + ``` + """ + generator_response = GenerateTrendOutputModel(reasoning_steps=[], answer=None) + with ( + patch( + "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", + return_value=RunnableLambda(lambda _: mocked_planner_response), + ) as planner_model_mock, + patch( + "ee.hogai.trends.nodes.GenerateTrendsNode._model", + return_value=RunnableLambda(lambda _: generator_response.model_dump()), + ) as generator_model_mock, + ): + assistant = Assistant(self.team) + generator = assistant.stream( + Conversation(messages=[HumanMessage(content="Launch the chain.")], session_id="id") + ) + self.assertEqual( + json.loads(next(generator)), + VisualizationMessage(answer=None, reasoning_steps=[], plan="Plan").model_dump(), + ) + self.assertEqual(planner_model_mock.call_count, 1) + self.assertEqual(generator_model_mock.call_count, 1) From 88409eeb18a2e75e182d63bf832000d83a1ba4ef Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:23:58 +0000 Subject: [PATCH 04/22] Update query snapshots --- posthog/api/test/__snapshots__/test_api_docs.ambr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/posthog/api/test/__snapshots__/test_api_docs.ambr b/posthog/api/test/__snapshots__/test_api_docs.ambr index 6ef31c6530176..a5f9b394809ae 100644 --- a/posthog/api/test/__snapshots__/test_api_docs.ambr +++ b/posthog/api/test/__snapshots__/test_api_docs.ambr @@ -97,8 +97,8 @@ '/home/runner/work/posthog/posthog/posthog/api/survey.py: Warning [SurveyViewSet > SurveySerializer]: unable to resolve type hint for function "get_conditions". Consider using a type hint or @extend_schema_field. Defaulting to string.', '/home/runner/work/posthog/posthog/posthog/api/web_experiment.py: Warning [WebExperimentViewSet]: could not derive type of path parameter "project_id" because model "posthog.models.web_experiment.WebExperiment" contained no such field. Consider annotating parameter with @extend_schema. Defaulting to "string".', 'Warning: encountered multiple names for the same choice set (HrefMatchingEnum). This may be unwanted even though the generated schema is technically correct. Add an entry to ENUM_NAME_OVERRIDES to fix the naming.', - 'Warning: enum naming encountered a non-optimally resolvable collision for fields named "kind". The same name has been used for multiple choice sets in multiple components. The collision was resolved with "Kind069Enum". add an entry to ENUM_NAME_OVERRIDES to fix the naming.', 'Warning: enum naming encountered a non-optimally resolvable collision for fields named "kind". The same name has been used for multiple choice sets in multiple components. The collision was resolved with "KindCfaEnum". add an entry to ENUM_NAME_OVERRIDES to fix the naming.', + 'Warning: enum naming encountered a non-optimally resolvable collision for fields named "kind". The same name has been used for multiple choice sets in multiple components. The collision was resolved with "Kind069Enum". add an entry to ENUM_NAME_OVERRIDES to fix the naming.', 'Warning: enum naming encountered a non-optimally resolvable collision for fields named "type". The same name has been used for multiple choice sets in multiple components. The collision was resolved with "TypeF73Enum". add an entry to ENUM_NAME_OVERRIDES to fix the naming.', 'Warning: encountered multiple names for the same choice set (EffectivePrivilegeLevelEnum). This may be unwanted even though the generated schema is technically correct. Add an entry to ENUM_NAME_OVERRIDES to fix the naming.', 'Warning: encountered multiple names for the same choice set (MembershipLevelEnum). This may be unwanted even though the generated schema is technically correct. Add an entry to ENUM_NAME_OVERRIDES to fix the naming.', From 30a5cc630928f3f094bb3395b6f39ab66af3d341 Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:36:40 +0000 Subject: [PATCH 05/22] Update query snapshots --- .../test/__snapshots__/test_trends.ambr | 204 +++++++++++++++--- 1 file changed, 172 insertions(+), 32 deletions(-) diff --git a/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr b/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr index 6027f7ca7bb42..4ae57feb8cb96 100644 --- a/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr +++ b/posthog/hogql_queries/insights/trends/test/__snapshots__/test_trends.ambr @@ -851,14 +851,49 @@ # --- # name: TestTrends.test_dau_with_breakdown_filtering_with_sampling.1 ''' - /* celery:posthog.tasks.tasks.sync_insight_caching_state */ - SELECT team_id, - date_diff('second', max(timestamp), now()) AS age - FROM events - WHERE timestamp > date_sub(DAY, 3, now()) - AND timestamp < now() - GROUP BY team_id - ORDER BY age; + SELECT groupArray(1)(date)[1] AS date, + arrayFold((acc, x) -> arrayMap(i -> plus(acc[i], x[i]), range(1, plus(length(date), 1))), groupArray(ifNull(total, 0)), arrayWithConstant(length(date), reinterpretAsFloat64(0))) AS total, + if(ifNull(ifNull(greaterOrEquals(row_number, 25), 0), 0), '$$_posthog_breakdown_other_$$', breakdown_value) AS breakdown_value + FROM + (SELECT arrayMap(number -> plus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toIntervalDay(number)), range(0, plus(coalesce(dateDiff('day', toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))))), 1))) AS date, + arrayMap(_match_date -> arraySum(arraySlice(groupArray(ifNull(count, 0)), indexOf(groupArray(day_start) AS _days_for_count, _match_date) AS _index, plus(minus(arrayLastIndex(x -> ifNull(equals(x, _match_date), isNull(x) + and isNull(_match_date)), _days_for_count), _index), 1))), date) AS total, + breakdown_value AS breakdown_value, + rowNumberInAllBlocks() AS row_number + FROM + (SELECT sum(total) AS count, + day_start AS day_start, + breakdown_value AS breakdown_value + FROM + (SELECT count(DISTINCT if(not(empty(e__override.distinct_id)), e__override.person_id, e.person_id)) AS total, + toStartOfDay(toTimeZone(e.timestamp, 'UTC')) AS day_start, + ifNull(nullIf(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', '')), ''), '$$_posthog_breakdown_null_$$') AS breakdown_value + FROM events AS e SAMPLE 1.0 + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 2) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS e__override ON equals(e.distinct_id, e__override.distinct_id) + WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), equals(e.event, 'sign up')) + GROUP BY day_start, + breakdown_value) + GROUP BY day_start, + breakdown_value + ORDER BY day_start ASC, breakdown_value ASC) + GROUP BY breakdown_value + ORDER BY if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_other_$$'), 0), 2, if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_null_$$'), 0), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC) + WHERE isNotNull(breakdown_value) + GROUP BY breakdown_value + ORDER BY if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_other_$$'), 0), 2, if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_null_$$'), 0), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC + LIMIT 50000 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 ''' # --- # name: TestTrends.test_dau_with_breakdown_filtering_with_sampling.10 @@ -1075,38 +1110,143 @@ # --- # name: TestTrends.test_dau_with_breakdown_filtering_with_sampling.2 ''' - /* celery:posthog.tasks.tasks.sync_insight_caching_state */ - SELECT team_id, - date_diff('second', max(timestamp), now()) AS age - FROM events - WHERE timestamp > date_sub(DAY, 3, now()) - AND timestamp < now() - GROUP BY team_id - ORDER BY age; + SELECT groupArray(1)(date)[1] AS date, + arrayFold((acc, x) -> arrayMap(i -> plus(acc[i], x[i]), range(1, plus(length(date), 1))), groupArray(ifNull(total, 0)), arrayWithConstant(length(date), reinterpretAsFloat64(0))) AS total, + if(ifNull(ifNull(greaterOrEquals(row_number, 25), 0), 0), '$$_posthog_breakdown_other_$$', breakdown_value) AS breakdown_value + FROM + (SELECT arrayMap(number -> plus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toIntervalDay(number)), range(0, plus(coalesce(dateDiff('day', toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))))), 1))) AS date, + arrayMap(_match_date -> arraySum(arraySlice(groupArray(ifNull(count, 0)), indexOf(groupArray(day_start) AS _days_for_count, _match_date) AS _index, plus(minus(arrayLastIndex(x -> ifNull(equals(x, _match_date), isNull(x) + and isNull(_match_date)), _days_for_count), _index), 1))), date) AS total, + breakdown_value AS breakdown_value, + rowNumberInAllBlocks() AS row_number + FROM + (SELECT sum(total) AS count, + day_start AS day_start, + breakdown_value AS breakdown_value + FROM + (SELECT count(DISTINCT if(not(empty(e__override.distinct_id)), e__override.person_id, e.person_id)) AS total, + toStartOfDay(toTimeZone(e.timestamp, 'UTC')) AS day_start, + ifNull(nullIf(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', '')), ''), '$$_posthog_breakdown_null_$$') AS breakdown_value + FROM events AS e SAMPLE 1.0 + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 2) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS e__override ON equals(e.distinct_id, e__override.distinct_id) + WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), equals(e.event, 'sign up')) + GROUP BY day_start, + breakdown_value) + GROUP BY day_start, + breakdown_value + ORDER BY day_start ASC, breakdown_value ASC) + GROUP BY breakdown_value + ORDER BY if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_other_$$'), 0), 2, if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_null_$$'), 0), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC) + WHERE isNotNull(breakdown_value) + GROUP BY breakdown_value + ORDER BY if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_other_$$'), 0), 2, if(ifNull(equals(breakdown_value, '$$_posthog_breakdown_null_$$'), 0), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC + LIMIT 50000 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 ''' # --- # name: TestTrends.test_dau_with_breakdown_filtering_with_sampling.3 ''' - /* celery:posthog.tasks.tasks.sync_insight_caching_state */ - SELECT team_id, - date_diff('second', max(timestamp), now()) AS age - FROM events - WHERE timestamp > date_sub(DAY, 3, now()) - AND timestamp < now() - GROUP BY team_id - ORDER BY age; + SELECT groupArray(1)(date)[1] AS date, + arrayFold((acc, x) -> arrayMap(i -> plus(acc[i], x[i]), range(1, plus(length(date), 1))), groupArray(ifNull(total, 0)), arrayWithConstant(length(date), reinterpretAsFloat64(0))) AS total, + arrayMap(i -> if(ifNull(ifNull(greaterOrEquals(row_number, 25), 0), 0), '$$_posthog_breakdown_other_$$', i), breakdown_value) AS breakdown_value + FROM + (SELECT arrayMap(number -> plus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toIntervalDay(number)), range(0, plus(coalesce(dateDiff('day', toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))))), 1))) AS date, + arrayMap(_match_date -> arraySum(arraySlice(groupArray(ifNull(count, 0)), indexOf(groupArray(day_start) AS _days_for_count, _match_date) AS _index, plus(minus(arrayLastIndex(x -> ifNull(equals(x, _match_date), isNull(x) + and isNull(_match_date)), _days_for_count), _index), 1))), date) AS total, + breakdown_value AS breakdown_value, + rowNumberInAllBlocks() AS row_number + FROM + (SELECT sum(total) AS count, + day_start AS day_start, + [ifNull(toString(breakdown_value_1), '$$_posthog_breakdown_null_$$')] AS breakdown_value + FROM + (SELECT count(DISTINCT if(not(empty(e__override.distinct_id)), e__override.person_id, e.person_id)) AS total, + toStartOfDay(toTimeZone(e.timestamp, 'UTC')) AS day_start, + ifNull(nullIf(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', '')), ''), '$$_posthog_breakdown_null_$$') AS breakdown_value_1 + FROM events AS e SAMPLE 1.0 + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 2) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS e__override ON equals(e.distinct_id, e__override.distinct_id) + WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), equals(e.event, 'sign up')) + GROUP BY day_start, + breakdown_value_1) + GROUP BY day_start, + breakdown_value_1 + ORDER BY day_start ASC, breakdown_value ASC) + GROUP BY breakdown_value + ORDER BY if(has(breakdown_value, '$$_posthog_breakdown_other_$$'), 2, if(has(breakdown_value, '$$_posthog_breakdown_null_$$'), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC) + WHERE arrayExists(x -> isNotNull(x), breakdown_value) + GROUP BY breakdown_value + ORDER BY if(has(breakdown_value, '$$_posthog_breakdown_other_$$'), 2, if(has(breakdown_value, '$$_posthog_breakdown_null_$$'), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC + LIMIT 50000 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 ''' # --- # name: TestTrends.test_dau_with_breakdown_filtering_with_sampling.4 ''' - /* celery:posthog.tasks.tasks.sync_insight_caching_state */ - SELECT team_id, - date_diff('second', max(timestamp), now()) AS age - FROM events - WHERE timestamp > date_sub(DAY, 3, now()) - AND timestamp < now() - GROUP BY team_id - ORDER BY age; + SELECT groupArray(1)(date)[1] AS date, + arrayFold((acc, x) -> arrayMap(i -> plus(acc[i], x[i]), range(1, plus(length(date), 1))), groupArray(ifNull(total, 0)), arrayWithConstant(length(date), reinterpretAsFloat64(0))) AS total, + arrayMap(i -> if(ifNull(ifNull(greaterOrEquals(row_number, 25), 0), 0), '$$_posthog_breakdown_other_$$', i), breakdown_value) AS breakdown_value + FROM + (SELECT arrayMap(number -> plus(toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toIntervalDay(number)), range(0, plus(coalesce(dateDiff('day', toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC'))), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))))), 1))) AS date, + arrayMap(_match_date -> arraySum(arraySlice(groupArray(ifNull(count, 0)), indexOf(groupArray(day_start) AS _days_for_count, _match_date) AS _index, plus(minus(arrayLastIndex(x -> ifNull(equals(x, _match_date), isNull(x) + and isNull(_match_date)), _days_for_count), _index), 1))), date) AS total, + breakdown_value AS breakdown_value, + rowNumberInAllBlocks() AS row_number + FROM + (SELECT sum(total) AS count, + day_start AS day_start, + [ifNull(toString(breakdown_value_1), '$$_posthog_breakdown_null_$$')] AS breakdown_value + FROM + (SELECT count(DISTINCT if(not(empty(e__override.distinct_id)), e__override.person_id, e.person_id)) AS total, + toStartOfDay(toTimeZone(e.timestamp, 'UTC')) AS day_start, + ifNull(nullIf(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(e.properties, '$some_property'), ''), 'null'), '^"|"$', '')), ''), '$$_posthog_breakdown_null_$$') AS breakdown_value_1 + FROM events AS e SAMPLE 1.0 + LEFT OUTER JOIN + (SELECT argMax(person_distinct_id_overrides.person_id, person_distinct_id_overrides.version) AS person_id, + person_distinct_id_overrides.distinct_id AS distinct_id + FROM person_distinct_id_overrides + WHERE equals(person_distinct_id_overrides.team_id, 2) + GROUP BY person_distinct_id_overrides.distinct_id + HAVING ifNull(equals(argMax(person_distinct_id_overrides.is_deleted, person_distinct_id_overrides.version), 0), 0) SETTINGS optimize_aggregation_in_order=1) AS e__override ON equals(e.distinct_id, e__override.distinct_id) + WHERE and(equals(e.team_id, 2), greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toStartOfDay(assumeNotNull(parseDateTime64BestEffortOrNull('2019-12-28 00:00:00', 6, 'UTC')))), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), assumeNotNull(parseDateTime64BestEffortOrNull('2020-01-04 23:59:59', 6, 'UTC'))), equals(e.event, 'sign up')) + GROUP BY day_start, + breakdown_value_1) + GROUP BY day_start, + breakdown_value_1 + ORDER BY day_start ASC, breakdown_value ASC) + GROUP BY breakdown_value + ORDER BY if(has(breakdown_value, '$$_posthog_breakdown_other_$$'), 2, if(has(breakdown_value, '$$_posthog_breakdown_null_$$'), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC) + WHERE arrayExists(x -> isNotNull(x), breakdown_value) + GROUP BY breakdown_value + ORDER BY if(has(breakdown_value, '$$_posthog_breakdown_other_$$'), 2, if(has(breakdown_value, '$$_posthog_breakdown_null_$$'), 1, 0)) ASC, arraySum(total) DESC, breakdown_value ASC + LIMIT 50000 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=1, + format_csv_allow_double_quotes=0, + max_ast_elements=4000000, + max_expanded_ast_elements=4000000, + max_bytes_before_external_group_by=0 ''' # --- # name: TestTrends.test_dau_with_breakdown_filtering_with_sampling.5 From 067db1f9fd02c6134b81677d5738b2b82d3dcc4a Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Mon, 28 Oct 2024 15:49:56 +0100 Subject: [PATCH 06/22] fix: test --- ee/hogai/test/test_assistant.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py index b2eea5a217ec5..6eda6867d2ead 100644 --- a/ee/hogai/test/test_assistant.py +++ b/ee/hogai/test/test_assistant.py @@ -1,6 +1,7 @@ import json from unittest.mock import patch +import pytest from django.test import override_settings from langchain_core.runnables import RunnableLambda @@ -15,6 +16,7 @@ @override_settings(IN_UNIT_TESTING=True) class TestAssistant(NonAtomicBaseTest): + @pytest.mark.django_db(transaction=True) def test_assistant(self): mocked_planner_response = """ Action: @@ -37,6 +39,7 @@ def test_assistant(self): generator = assistant.stream( Conversation(messages=[HumanMessage(content="Launch the chain.")], session_id="id") ) + self.assertEqual(next(generator), "") self.assertEqual( json.loads(next(generator)), VisualizationMessage(answer=None, reasoning_steps=[], plan="Plan").model_dump(), From b8b9f38518df4a8fb5c2d772641ef11c746d996f Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Mon, 28 Oct 2024 16:29:34 +0100 Subject: [PATCH 07/22] fix: tests --- ee/hogai/test/__init__.py | 0 ee/hogai/test/test_assistant.py | 48 -------------------- ee/hogai/trends/nodes.py | 2 +- ee/hogai/trends/test/test_nodes.py | 73 +++++++++++++++++++++++++++++- 4 files changed, 73 insertions(+), 50 deletions(-) delete mode 100644 ee/hogai/test/__init__.py delete mode 100644 ee/hogai/test/test_assistant.py diff --git a/ee/hogai/test/__init__.py b/ee/hogai/test/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/ee/hogai/test/test_assistant.py b/ee/hogai/test/test_assistant.py deleted file mode 100644 index 6eda6867d2ead..0000000000000 --- a/ee/hogai/test/test_assistant.py +++ /dev/null @@ -1,48 +0,0 @@ -import json -from unittest.mock import patch - -import pytest -from django.test import override_settings -from langchain_core.runnables import RunnableLambda - -from ee.hogai.assistant import Assistant -from ee.hogai.trends.utils import GenerateTrendOutputModel -from ee.hogai.utils import Conversation -from posthog.schema import HumanMessage, VisualizationMessage -from posthog.test.base import ( - NonAtomicBaseTest, -) - - -@override_settings(IN_UNIT_TESTING=True) -class TestAssistant(NonAtomicBaseTest): - @pytest.mark.django_db(transaction=True) - def test_assistant(self): - mocked_planner_response = """ - Action: - ``` - {"action": "final_answer", "action_input": "Plan"} - ``` - """ - generator_response = GenerateTrendOutputModel(reasoning_steps=[], answer=None) - with ( - patch( - "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", - return_value=RunnableLambda(lambda _: mocked_planner_response), - ) as planner_model_mock, - patch( - "ee.hogai.trends.nodes.GenerateTrendsNode._model", - return_value=RunnableLambda(lambda _: generator_response.model_dump()), - ) as generator_model_mock, - ): - assistant = Assistant(self.team) - generator = assistant.stream( - Conversation(messages=[HumanMessage(content="Launch the chain.")], session_id="id") - ) - self.assertEqual(next(generator), "") - self.assertEqual( - json.loads(next(generator)), - VisualizationMessage(answer=None, reasoning_steps=[], plan="Plan").model_dump(), - ) - self.assertEqual(planner_model_mock.call_count, 1) - self.assertEqual(generator_model_mock.call_count, 1) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index f0cb5e0d06bd7..6c817c4a6d83d 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -403,4 +403,4 @@ def run(self, state: AssistantState, config: RunnableConfig): if not intermediate_steps: return state action, _ = intermediate_steps[-1] - return {"intermediate_steps": (action, action.log)} + return {"intermediate_steps": [(action, action.log)]} diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index dc297570c1fd1..dfa627d4ff790 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -1,6 +1,12 @@ +import json +from unittest.mock import patch + from django.test import override_settings +from langchain_core.agents import AgentAction +from langchain_core.runnables import RunnableLambda -from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode +from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode +from ee.hogai.trends.utils import GenerateTrendOutputModel from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @@ -92,6 +98,25 @@ class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest): def setUp(self): self.schema = ExperimentalAITrendsQuery(series=[]) + def test_node_runs(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + generator_model_mock.return_value = RunnableLambda( + lambda _: GenerateTrendOutputModel(reasoning_steps=["step"], answer=self.schema).model_dump() + ) + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "plan": "Plan", + }, + {}, + ) + self.assertNotIn("intermediate_steps", new_state) + self.assertEqual( + new_state["messages"], + [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])], + ) + def test_agent_reconstructs_conversation(self): node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]}) @@ -203,3 +228,49 @@ def test_agent_reconstructs_conversation_and_merges_messages(self): self.assertIn("Answer to this question:", history[5].content) self.assertNotIn("{{question}}", history[5].content) self.assertIn("Follow\nUp", history[5].content) + + def test_failover_with_incorrect_schema(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() + # Emulate an incorrect JSON. It should be an object. + schema["answer"] = [] + generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) + + new_state = node.run({"messages": [HumanMessage(content="Text")]}, {}) + self.assertIn("intermediate_steps", new_state) + self.assertEqual(len(new_state["intermediate_steps"]), 1) + + def test_agent_reconstructs_conversation_with_failover(self): + action = AgentAction(tool="fix", tool_input="validation error", log="exception") + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [HumanMessage(content="Text")], + "plan": "randomplan", + "intermediate_steps": [(action, "uniqexception")], + }, + "uniqexception", + ) + self.assertEqual(len(history), 4) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + self.assertEqual(history[3].type, "human") + self.assertIn("Pydantic", history[3].content) + self.assertIn("uniqexception", history[3].content) + + +class TestGenerateTrendsToolsNode(ClickhouseTestMixin, APIBaseTest): + def test_tools_node(self): + node = GenerateTrendsToolsNode(self.team) + action = AgentAction(tool="fix", tool_input="validation error", log="exception") + state = node.run({"messages": [], "intermediate_steps": [(action, None)]}, {}) + self.assertEqual(state, {"intermediate_steps": [(action, "exception")]}) From 800e64beee23cbdad6072e22ffa9491b75b398e9 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Mon, 28 Oct 2024 16:57:25 +0100 Subject: [PATCH 08/22] test: more tests --- ee/hogai/trends/nodes.py | 5 +++-- ee/hogai/trends/test/test_nodes.py | 35 +++++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 6c817c4a6d83d..0a313ffc1a1a8 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -281,11 +281,12 @@ def run(self, state: AssistantState, config: RunnableConfig): reasoning_steps=message.reasoning_steps, answer=message.answer, ) - ] + ], + "intermediate_steps": None, } def router(self, state: AssistantState): - if state.get("tool_argument") is not None: + if state.get("intermediate_steps") is not None: return AssistantNodeName.GENERATE_TRENDS_TOOLS return AssistantNodeName.END diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index dfa627d4ff790..95d3a7b9371de 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -7,6 +7,7 @@ from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode from ee.hogai.trends.utils import GenerateTrendOutputModel +from ee.hogai.utils import AssistantNodeName from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @@ -111,10 +112,12 @@ def test_node_runs(self): }, {}, ) - self.assertNotIn("intermediate_steps", new_state) self.assertEqual( - new_state["messages"], - [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])], + new_state, + { + "messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])], + "intermediate_steps": None, + }, ) def test_agent_reconstructs_conversation(self): @@ -241,6 +244,23 @@ def test_failover_with_incorrect_schema(self): self.assertIn("intermediate_steps", new_state) self.assertEqual(len(new_state["intermediate_steps"]), 1) + def test_node_leaves_failover(self): + node = GenerateTrendsNode(self.team) + with patch( + "ee.hogai.trends.nodes.GenerateTrendsNode._model", + return_value=RunnableLambda( + lambda _: GenerateTrendOutputModel(reasoning_steps=[], answer=self.schema).model_dump() + ), + ): + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + def test_agent_reconstructs_conversation_with_failover(self): action = AgentAction(tool="fix", tool_input="validation error", log="exception") node = GenerateTrendsNode(self.team) @@ -267,6 +287,15 @@ def test_agent_reconstructs_conversation_with_failover(self): self.assertIn("Pydantic", history[3].content) self.assertIn("uniqexception", history[3].content) + def test_router(self): + node = GenerateTrendsNode(self.team) + state = node.router({"messages": [], "intermediate_steps": None}) + self.assertEqual(state, AssistantNodeName.END) + state = node.router( + {"messages": [], "intermediate_steps": [(AgentAction(tool="", tool_input="", log=""), None)]} + ) + self.assertEqual(state, AssistantNodeName.GENERATE_TRENDS_TOOLS) + class TestGenerateTrendsToolsNode(ClickhouseTestMixin, APIBaseTest): def test_tools_node(self): From 2ceb4845b9123c64ca41b9ebe5d1175b50fe1c0e Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 29 Oct 2024 09:51:53 +0100 Subject: [PATCH 09/22] feat: improved validation message --- ee/hogai/trends/nodes.py | 43 +++++++++++++++--------------- ee/hogai/trends/parsers.py | 24 +++++++++++++++++ ee/hogai/trends/prompts.py | 14 ++++++++-- ee/hogai/trends/test/test_nodes.py | 6 +++-- 4 files changed, 62 insertions(+), 25 deletions(-) create mode 100644 ee/hogai/trends/parsers.py diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index f4eada73a475c..df12f0f8f5e76 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -1,5 +1,4 @@ import itertools -import json import xml.etree.ElementTree as ET from functools import cached_property from typing import Optional, Union, cast @@ -11,19 +10,20 @@ from langchain_core.messages import AIMessage as LangchainAssistantMessage from langchain_core.messages import BaseMessage, merge_message_runs from langchain_core.messages import HumanMessage as LangchainHumanMessage -from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI from pydantic import ValidationError from ee.hogai.hardcoded_definitions import hardcoded_prop_defs +from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output from ee.hogai.trends.prompts import ( react_definitions_prompt, react_follow_up_prompt, react_scratchpad_prompt, react_system_prompt, react_user_prompt, + trends_failover_output_prompt, trends_failover_prompt, trends_group_mapping_prompt, trends_new_plan_prompt, @@ -252,26 +252,15 @@ def run(self, state: AssistantState, config: RunnableConfig): ) + self._reconstruct_conversation(state, validation_error_message=validation_error_message) merger = merge_message_runs() - chain = ( - trends_generation_prompt - | merger - | self._model - # Result from structured output is a parsed dict. Convert to a string since the output parser expects it. - | RunnableLambda(lambda x: json.dumps(x)) - # Validate a string input. - | PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel) - ) + chain = trends_generation_prompt | merger | self._model | parse_generated_trends_output try: message: GenerateTrendOutputModel = chain.invoke({}, config) - except OutputParserException as e: - if e.send_to_llm: - observation = str(e.observation) - else: - observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question." - + except PydanticOutputParserException as e: return { - "intermediate_steps": [(AgentAction("handle_incorrect_response", observation, str(e)), None)], + "intermediate_steps": [ + (AgentAction("handle_incorrect_response", e.llm_output, e.validation_message), None) + ], } return { @@ -378,7 +367,7 @@ def _reconstruct_conversation( if validation_error_message: conversation.append( HumanMessagePromptTemplate.from_template(trends_failover_prompt, template_format="mustache").format( - exception_message=validation_error_message + validation_error_message=validation_error_message ) ) @@ -403,5 +392,17 @@ def run(self, state: AssistantState, config: RunnableConfig): intermediate_steps = state.get("intermediate_steps", []) if not intermediate_steps: return state + action, _ = intermediate_steps[-1] - return {"intermediate_steps": [(action, action.log)]} + prompt = ( + ChatPromptTemplate.from_template(trends_failover_output_prompt, template_format="mustache") + .format_messages(output=action.tool_input, exception_message=action.log)[0] + .content + ) + + return { + "intermediate_steps": [ + *intermediate_steps[:-1], + (action, prompt), + ] + } diff --git a/ee/hogai/trends/parsers.py b/ee/hogai/trends/parsers.py new file mode 100644 index 0000000000000..a461c692d8825 --- /dev/null +++ b/ee/hogai/trends/parsers.py @@ -0,0 +1,24 @@ +import json + +from pydantic import ValidationError + +from ee.hogai.trends.utils import GenerateTrendOutputModel + + +class PydanticOutputParserException(ValueError): + llm_output: str + """Serialized LLM output.""" + validation_message: str + """Pydantic validation error message.""" + + def __init__(self, llm_output: str, validation_message: str): + super().__init__(llm_output) + self.llm_output = llm_output + self.validation_message = validation_message + + +def parse_generated_trends_output(output: dict) -> GenerateTrendOutputModel: + try: + return GenerateTrendOutputModel.model_validate(output) + except ValidationError as e: + raise PydanticOutputParserException(llm_output=json.dumps(output), validation_message=e.json(include_url=False)) diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index bb644922e3045..84c2bcedb544a 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -270,12 +270,22 @@ Answer to this question: {{question}} """ -trends_failover_prompt = """ -The result of your previous generatin raised the Pydantic validation exception: +trends_failover_output_prompt = """ +Generation output: +``` +{{output}} +``` +Exception message: ``` {{exception_message}} ``` +""" + +trends_failover_prompt = """ +The result of the previous generation raised the Pydantic validation exception. + +{{validation_error_message}} Fix the error and return the correct response. """ diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 95d3a7b9371de..9c280bed8394a 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -300,6 +300,8 @@ def test_router(self): class TestGenerateTrendsToolsNode(ClickhouseTestMixin, APIBaseTest): def test_tools_node(self): node = GenerateTrendsToolsNode(self.team) - action = AgentAction(tool="fix", tool_input="validation error", log="exception") + action = AgentAction(tool="fix", tool_input="validationerror", log="pydanticexception") state = node.run({"messages": [], "intermediate_steps": [(action, None)]}, {}) - self.assertEqual(state, {"intermediate_steps": [(action, "exception")]}) + self.assertIsNotNone("validationerror", state["intermediate_steps"][0][1]) + self.assertIn("validationerror", state["intermediate_steps"][0][1]) + self.assertIn("pydanticexception", state["intermediate_steps"][0][1]) From 923768b89174d0c9eab8e745efbe36a7b8dbe587 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 29 Oct 2024 10:16:48 +0100 Subject: [PATCH 10/22] feat: status messages --- ee/hogai/assistant.py | 13 ++++++----- ee/hogai/trends/nodes.py | 23 +++++++++++++++++--- ee/hogai/utils.py | 4 ++-- frontend/src/queries/schema.json | 37 +++++++++++++++++++++++++++++++- frontend/src/queries/schema.ts | 22 ++++++++++++++++++- posthog/api/query.py | 15 +++++++++++-- posthog/schema.py | 30 ++++++++++++++++++++++++-- 7 files changed, 128 insertions(+), 16 deletions(-) diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index 817a6f31a6b59..d01181b937c4e 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -4,6 +4,7 @@ from langchain_core.messages import AIMessageChunk from langfuse.callback import CallbackHandler from langgraph.graph.state import StateGraph +from pydantic import BaseModel from ee import settings from ee.hogai.trends.nodes import ( @@ -14,7 +15,7 @@ ) from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation from posthog.models.team.team import Team -from posthog.schema import VisualizationMessage +from posthog.schema import AssistantGenerationStatusEvent, AssistantGenerationStatusType, VisualizationMessage if settings.LANGFUSE_PUBLIC_KEY: langfuse_handler = CallbackHandler( @@ -82,7 +83,7 @@ def _compile_graph(self): return builder.compile() - def stream(self, conversation: Conversation) -> Generator[str, None, None]: + def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]: assistant_graph = self._compile_graph() callbacks = [langfuse_handler] if langfuse_handler else [] messages = [message.root for message in conversation.messages] @@ -99,7 +100,7 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: chunks = AIMessageChunk(content="") # Send a chunk to establish the connection avoiding the worker's timeout. - yield "" + yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK) for update in generator: if is_state_update(update): @@ -117,7 +118,9 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: message = cast( VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0] ) - yield message.model_dump_json() + yield message + elif state_update[AssistantNodeName.GENERATE_TRENDS].get("intermediate_steps", []): + yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR) elif is_message_update(update): langchain_message, langgraph_state = update[1] @@ -129,4 +132,4 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]: if parsed_message: yield VisualizationMessage( reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer - ).model_dump_json() + ) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index df12f0f8f5e76..d7c521b3baf60 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -46,7 +46,13 @@ from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner from posthog.hogql_queries.query_runner import ExecutionMode from posthog.models.group_type_mapping import GroupTypeMapping -from posthog.schema import CachedTeamTaxonomyQueryResponse, HumanMessage, TeamTaxonomyQuery, VisualizationMessage +from posthog.schema import ( + CachedTeamTaxonomyQueryResponse, + FailureMessage, + HumanMessage, + TeamTaxonomyQuery, + VisualizationMessage, +) class CreateTrendsPlanNode(AssistantNode): @@ -241,7 +247,7 @@ class GenerateTrendsNode(AssistantNode): def run(self, state: AssistantState, config: RunnableConfig): generated_plan = state.get("plan", "") - intermediate_steps = state.get("intermediate_steps", []) + intermediate_steps = state.get("intermediate_steps") or [] validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None trends_generation_prompt = ChatPromptTemplate.from_messages( @@ -257,9 +263,20 @@ def run(self, state: AssistantState, config: RunnableConfig): try: message: GenerateTrendOutputModel = chain.invoke({}, config) except PydanticOutputParserException as e: + # Generation step is expensive. After a second unsuccessful attempt, it's better to send a failure message. + if len(intermediate_steps) > 2: + return { + "messages": [ + FailureMessage( + content="Oops! It looks like I’m having trouble generating this trends insight. Could you please try again?" + ) + ], + } + return { "intermediate_steps": [ - (AgentAction("handle_incorrect_response", e.llm_output, e.validation_message), None) + *intermediate_steps, + (AgentAction("handle_incorrect_response", e.llm_output, e.validation_message), None), ], } diff --git a/ee/hogai/utils.py b/ee/hogai/utils.py index 3f7712b8ac82c..70d1ea969621a 100644 --- a/ee/hogai/utils.py +++ b/ee/hogai/utils.py @@ -10,9 +10,9 @@ from pydantic import BaseModel, Field from posthog.models.team.team import Team -from posthog.schema import AssistantMessage, HumanMessage, RootAssistantMessage, VisualizationMessage +from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, RootAssistantMessage, VisualizationMessage -AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage] +AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage, FailureMessage] class Conversation(BaseModel): diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 75b26c88bc4de..f4fe409f978b7 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -617,6 +617,24 @@ } ] }, + "AssistantEventType": { + "enum": ["status", "message"], + "type": "string" + }, + "AssistantGenerationStatusEvent": { + "additionalProperties": false, + "properties": { + "type": { + "$ref": "#/definitions/AssistantGenerationStatusType" + } + }, + "required": ["type"], + "type": "object" + }, + "AssistantGenerationStatusType": { + "enum": ["ack", "generation_error"], + "type": "string" + }, "AssistantMessage": { "additionalProperties": false, "properties": { @@ -632,7 +650,7 @@ "type": "object" }, "AssistantMessageType": { - "enum": ["human", "ai", "ai/viz"], + "enum": ["human", "ai", "ai/viz", "ai/failure"], "type": "string" }, "AutocompleteCompletionItem": { @@ -5450,6 +5468,20 @@ "required": ["kind", "series"], "type": "object" }, + "FailureMessage": { + "additionalProperties": false, + "properties": { + "content": { + "type": "string" + }, + "type": { + "const": "ai/failure", + "type": "string" + } + }, + "required": ["type"], + "type": "object" + }, "FeaturePropertyFilter": { "additionalProperties": false, "properties": { @@ -10735,6 +10767,9 @@ }, { "$ref": "#/definitions/HumanMessage" + }, + { + "$ref": "#/definitions/FailureMessage" } ] }, diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index 273605a42f6d7..3577201d140c6 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -2096,6 +2096,7 @@ export enum AssistantMessageType { Human = 'human', Assistant = 'ai', Visualization = 'ai/viz', + Failure = 'ai/failure', } export interface HumanMessage { @@ -2115,4 +2116,23 @@ export interface VisualizationMessage { answer?: ExperimentalAITrendsQuery } -export type RootAssistantMessage = VisualizationMessage | AssistantMessage | HumanMessage +export interface FailureMessage { + type: AssistantMessageType.Failure + content?: string +} + +export type RootAssistantMessage = VisualizationMessage | AssistantMessage | HumanMessage | FailureMessage + +export enum AssistantEventType { + Status = 'status', + Message = 'message', +} + +export enum AssistantGenerationStatusType { + Acknowledged = 'ack', + GenerationError = 'generation_error', +} + +export interface AssistantGenerationStatusEvent { + type: AssistantGenerationStatusType +} diff --git a/posthog/api/query.py b/posthog/api/query.py index f2eaccea53ae5..d42c72504e5c4 100644 --- a/posthog/api/query.py +++ b/posthog/api/query.py @@ -41,7 +41,14 @@ ClickHouseSustainedRateThrottle, HogQLQueryThrottle, ) -from posthog.schema import HumanMessage, QueryRequest, QueryResponseAlternative, QueryStatusResponse +from posthog.schema import ( + AssistantEventType, + AssistantGenerationStatusEvent, + HumanMessage, + QueryRequest, + QueryResponseAlternative, + QueryStatusResponse, +) class ServerSentEventRenderer(BaseRenderer): @@ -185,7 +192,11 @@ def generate(): last_message = None for message in assistant.stream(validated_body): last_message = message - yield f"data: {message}\n\n" + if isinstance(message, AssistantGenerationStatusEvent): + yield f"event: {AssistantEventType.STATUS}\n\n" + else: + yield f"event: {AssistantEventType.MESSAGE}\n\n" + yield f"data: {message.model_dump_json()}\n\n" human_message = validated_body.messages[-1].root if isinstance(human_message, HumanMessage): diff --git a/posthog/schema.py b/posthog/schema.py index b386d5d6c8e97..e1d54707a87e8 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -63,6 +63,16 @@ class AlertState(StrEnum): SNOOZED = "Snoozed" +class AssistantEventType(StrEnum): + STATUS = "status" + MESSAGE = "message" + + +class AssistantGenerationStatusType(StrEnum): + ACK = "ack" + GENERATION_ERROR = "generation_error" + + class AssistantMessage(BaseModel): model_config = ConfigDict( extra="forbid", @@ -75,6 +85,7 @@ class AssistantMessageType(StrEnum): HUMAN = "human" AI = "ai" AI_VIZ = "ai/viz" + AI_FAILURE = "ai/failure" class Kind(StrEnum): @@ -558,6 +569,14 @@ class ExperimentVariantTrendsBaseStats(BaseModel): key: str +class FailureMessage(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + content: Optional[str] = None + type: Literal["ai/failure"] = "ai/failure" + + class FilterLogicalOperator(StrEnum): AND_ = "AND" OR_ = "OR" @@ -1745,6 +1764,13 @@ class AlertCondition(BaseModel): type: AlertConditionType +class AssistantGenerationStatusEvent(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: AssistantGenerationStatusType + + class Breakdown(BaseModel): model_config = ConfigDict( extra="forbid", @@ -6032,8 +6058,8 @@ class QueryResponseAlternative( ] -class RootAssistantMessage(RootModel[Union[VisualizationMessage, AssistantMessage, HumanMessage]]): - root: Union[VisualizationMessage, AssistantMessage, HumanMessage] +class RootAssistantMessage(RootModel[Union[VisualizationMessage, AssistantMessage, HumanMessage, FailureMessage]]): + root: Union[VisualizationMessage, AssistantMessage, HumanMessage, FailureMessage] class DatabaseSchemaQueryResponse(BaseModel): From ad7e07c13547fe10e202799c3091e982a18e2ee5 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 29 Oct 2024 10:23:18 +0100 Subject: [PATCH 11/22] test: failover --- ee/hogai/trends/nodes.py | 3 +- ee/hogai/trends/test/test_nodes.py | 52 +++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index d7c521b3baf60..e72af9b86290d 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -264,13 +264,14 @@ def run(self, state: AssistantState, config: RunnableConfig): message: GenerateTrendOutputModel = chain.invoke({}, config) except PydanticOutputParserException as e: # Generation step is expensive. After a second unsuccessful attempt, it's better to send a failure message. - if len(intermediate_steps) > 2: + if len(intermediate_steps) >= 2: return { "messages": [ FailureMessage( content="Oops! It looks like I’m having trouble generating this trends insight. Could you please try again?" ) ], + "intermediate_steps": None, } return { diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 9c280bed8394a..57f2ae8358fe8 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -8,7 +8,13 @@ from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode from ee.hogai.trends.utils import GenerateTrendOutputModel from ee.hogai.utils import AssistantNodeName -from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage +from posthog.schema import ( + AssistantMessage, + ExperimentalAITrendsQuery, + FailureMessage, + HumanMessage, + VisualizationMessage, +) from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @@ -244,6 +250,16 @@ def test_failover_with_incorrect_schema(self): self.assertIn("intermediate_steps", new_state) self.assertEqual(len(new_state["intermediate_steps"]), 1) + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], + }, + {}, + ) + self.assertIn("intermediate_steps", new_state) + self.assertEqual(len(new_state["intermediate_steps"]), 2) + def test_node_leaves_failover(self): node = GenerateTrendsNode(self.team) with patch( @@ -261,6 +277,40 @@ def test_node_leaves_failover(self): ) self.assertIsNone(new_state["intermediate_steps"]) + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [ + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + ], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + + def test_node_leaves_failover_after_second_unsuccessful_attempt(self): + node = GenerateTrendsNode(self.team) + with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: + schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() + # Emulate an incorrect JSON. It should be an object. + schema["answer"] = [] + generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) + + new_state = node.run( + { + "messages": [HumanMessage(content="Text")], + "intermediate_steps": [ + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + (AgentAction(tool="", tool_input="", log="exception"), "exception"), + ], + }, + {}, + ) + self.assertIsNone(new_state["intermediate_steps"]) + self.assertEqual(len(new_state["messages"]), 1) + self.assertIsInstance(new_state["messages"][0], FailureMessage) + def test_agent_reconstructs_conversation_with_failover(self): action = AgentAction(tool="fix", tool_input="validation error", log="exception") node = GenerateTrendsNode(self.team) From 40a1e1cd9179c88d8c542d5e6934a994aa1b36f3 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Tue, 29 Oct 2024 11:13:54 +0100 Subject: [PATCH 12/22] feat: frontend messages for repeated generations --- frontend/src/scenes/max/Max.stories.tsx | 44 ++++++++++++- frontend/src/scenes/max/Thread.tsx | 63 ++++++++++++++----- .../max/__mocks__/chatResponse.mocks.ts | 24 ++++++- .../scenes/max/__mocks__/failureResponse.json | 4 ++ frontend/src/scenes/max/maxLogic.ts | 54 ++++++++++------ frontend/src/scenes/max/utils.ts | 12 +++- posthog/api/query.py | 4 +- 7 files changed, 165 insertions(+), 40 deletions(-) create mode 100644 frontend/src/scenes/max/__mocks__/failureResponse.json diff --git a/frontend/src/scenes/max/Max.stories.tsx b/frontend/src/scenes/max/Max.stories.tsx index 27045963c6e42..88ffe32ea7559 100644 --- a/frontend/src/scenes/max/Max.stories.tsx +++ b/frontend/src/scenes/max/Max.stories.tsx @@ -1,10 +1,10 @@ import { Meta, StoryFn } from '@storybook/react' -import { BindLogic, useActions } from 'kea' +import { BindLogic, useActions, useValues } from 'kea' import { useEffect } from 'react' import { mswDecorator, useStorybookMocks } from '~/mocks/browser' -import { chatResponseChunk } from './__mocks__/chatResponse.mocks' +import { chatResponseChunk, failureChunk, generationFailureChunk } from './__mocks__/chatResponse.mocks' import { MaxInstance } from './Max' import { maxLogic } from './maxLogic' @@ -104,3 +104,43 @@ EmptyThreadLoading.parameters = { waitForLoadersToDisappear: false, }, } + +export const GenerationFailureThread: StoryFn = () => { + useStorybookMocks({ + post: { + '/api/environments/:team_id/query/chat/': (_, res, ctx) => res(ctx.text(generationFailureChunk)), + }, + }) + + const sessionId = 'd210b263-8521-4c5b-b3c4-8e0348df574b' + + const { askMax, setMessageStatus } = useActions(maxLogic({ sessionId })) + const { thread, threadLoading } = useValues(maxLogic({ sessionId })) + useEffect(() => { + askMax('What are my most popular pages?') + }, []) + useEffect(() => { + if (thread.length === 2 && !threadLoading) { + setMessageStatus(1, 'error') + } + }, [thread.length, threadLoading]) + + return