From d836bc860a9045ddedee4439270914a848a9139f Mon Sep 17 00:00:00 2001
From: Georgiy Tarasov <gtarasov.work@gmail.com>
Date: Wed, 20 Nov 2024 15:44:47 +0100
Subject: [PATCH] feat(product-assistant): evaluation pipeline (#26179)

Co-authored-by: Michael Matloka <michael@posthog.com>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
---
 .gitignore                                    |   8 +-
 ee/hogai/assistant.py                         | 112 +++++--
 ee/hogai/eval/__init__.py                     |   0
 ee/hogai/eval/test_eval_funnel_planner.py     | 179 +++++++++++
 ee/hogai/eval/test_eval_router.py             |  59 ++++
 ee/hogai/eval/test_eval_trends_planner.py     | 163 ++++++++++
 ee/hogai/eval/utils.py                        |  28 ++
 ee/hogai/schema_generator/nodes.py            |   2 +-
 .../api/test/__snapshots__/test_decide.ambr   |  32 ++
 posthog/api/test/test_decide.py               |   4 +-
 requirements-dev.in                           |   1 +
 requirements-dev.txt                          | 304 +++++++++++++++++-
 requirements.in                               |   2 +
 requirements.txt                              |   6 +-
 14 files changed, 867 insertions(+), 33 deletions(-)
 create mode 100644 ee/hogai/eval/__init__.py
 create mode 100644 ee/hogai/eval/test_eval_funnel_planner.py
 create mode 100644 ee/hogai/eval/test_eval_router.py
 create mode 100644 ee/hogai/eval/test_eval_trends_planner.py
 create mode 100644 ee/hogai/eval/utils.py

diff --git a/.gitignore b/.gitignore
index 6f0c1be90cbae..a41dd0980a217 100644
--- a/.gitignore
+++ b/.gitignore
@@ -69,4 +69,10 @@ plugin-transpiler/dist
 *.log
 # pyright config (keep this until we have a standardized one)
 pyrightconfig.json
-.temporal-worker-settings
\ No newline at end of file
+# Assistant Evaluation with Deepeval
+.deepeval
+.deepeval-cache.json
+.deepeval_telemtry.txt
+.temporal-worker-settings
+temp_test_run_data.json
+.temp-deepeval-cache.json
diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py
index df200ae869c3b..57e6a6f13cb86 100644
--- a/ee/hogai/assistant.py
+++ b/ee/hogai/assistant.py
@@ -1,9 +1,9 @@
-from collections.abc import Generator
-from typing import Any, Literal, TypedDict, TypeGuard, Union
+from collections.abc import Generator, Hashable, Iterator
+from typing import Any, Literal, Optional, TypedDict, TypeGuard, Union, cast
 
 from langchain_core.messages import AIMessageChunk
 from langfuse.callback import CallbackHandler
-from langgraph.graph.state import StateGraph
+from langgraph.graph.state import CompiledStateGraph, StateGraph
 from pydantic import BaseModel
 from sentry_sdk import capture_exception
 
@@ -74,25 +74,49 @@ def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], As
 }
 
 
-class Assistant:
+class AssistantGraph:
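+    """Fluent builder for the assistant's LangGraph state graph.
+
+    Each add_* method registers a node (plus its edges) and returns self,
+    so callers can compose partial graphs for evaluation, e.g.
+    AssistantGraph(team).add_start().add_router(...).compile(), while
+    compile_full_graph() wires up the complete pipeline.
+    """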
     _team: Team
     _graph: StateGraph
 
     def __init__(self, team: Team):
         self._team = team
         self._graph = StateGraph(AssistantState)
-
-    def _compile_graph(self):
+        self._has_start_node = False
+
+    def add_edge(self, from_node: AssistantNodeName, to_node: AssistantNodeName):
+        if from_node == AssistantNodeName.START:
+            self._has_start_node = True
+        self._graph.add_edge(from_node, to_node)
+        return self
+
+    def compile(self):
+        if not self._has_start_node:
+            raise ValueError("Start node not added to the graph")
+        return self._graph.compile()
+
+    def add_start(self):
+        return self.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
+
+    def add_router(
+        self,
+        path_map: Optional[dict[Hashable, AssistantNodeName]] = None,
+    ):
         builder = self._graph
-
+        path_map = path_map or {
+            "trends": AssistantNodeName.TRENDS_PLANNER,
+            "funnel": AssistantNodeName.FUNNEL_PLANNER,
+        }
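+        # The router node classifies the question as "trends" or "funnel";
+        # evals can remap either target (e.g. both to END) via path_map.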
         router_node = RouterNode(self._team)
         builder.add_node(AssistantNodeName.ROUTER, router_node.run)
-        builder.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
         builder.add_conditional_edges(
             AssistantNodeName.ROUTER,
             router_node.router,
-            path_map={"trends": AssistantNodeName.TRENDS_PLANNER, "funnel": AssistantNodeName.FUNNEL_PLANNER},
+            path_map=cast(dict[Hashable, str], path_map),
         )
+        return self
+
+    def add_trends_planner(self, next_node: AssistantNodeName = AssistantNodeName.TRENDS_GENERATOR):
+        builder = self._graph
 
         create_trends_plan_node = TrendsPlannerNode(self._team)
         builder.add_node(AssistantNodeName.TRENDS_PLANNER, create_trends_plan_node.run)
@@ -111,26 +135,36 @@ def _compile_graph(self):
             create_trends_plan_tools_node.router,
             path_map={
                 "continue": AssistantNodeName.TRENDS_PLANNER,
-                "plan_found": AssistantNodeName.TRENDS_GENERATOR,
+                "plan_found": next_node,
             },
         )
 
-        generate_trends_node = TrendsGeneratorNode(self._team)
-        builder.add_node(AssistantNodeName.TRENDS_GENERATOR, generate_trends_node.run)
+        return self
+
+    def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.SUMMARIZER):
+        builder = self._graph
+
+        trends_generator = TrendsGeneratorNode(self._team)
+        builder.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator.run)
 
-        generate_trends_tools_node = TrendsGeneratorToolsNode(self._team)
-        builder.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, generate_trends_tools_node.run)
+        trends_generator_tools = TrendsGeneratorToolsNode(self._team)
+        builder.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools.run)
 
         builder.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR)
         builder.add_conditional_edges(
             AssistantNodeName.TRENDS_GENERATOR,
-            generate_trends_node.router,
+            trends_generator.router,
             path_map={
                 "tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
-                "next": AssistantNodeName.SUMMARIZER,
+                "next": next_node,
             },
         )
 
+        return self
+
+    def add_funnel_planner(self, next_node: AssistantNodeName = AssistantNodeName.FUNNEL_GENERATOR):
+        builder = self._graph
+
         funnel_planner = FunnelPlannerNode(self._team)
         builder.add_node(AssistantNodeName.FUNNEL_PLANNER, funnel_planner.run)
         builder.add_conditional_edges(
@@ -148,41 +182,69 @@ def _compile_graph(self):
             funnel_planner_tools.router,
             path_map={
                 "continue": AssistantNodeName.FUNNEL_PLANNER,
-                "plan_found": AssistantNodeName.FUNNEL_GENERATOR,
+                "plan_found": next_node,
             },
         )
 
+        return self
+
+    def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.SUMMARIZER):
+        builder = self._graph
+
         funnel_generator = FunnelGeneratorNode(self._team)
         builder.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator.run)
 
-        funnel_generator_tools_node = FunnelGeneratorToolsNode(self._team)
-        builder.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools_node.run)
+        funnel_generator_tools = FunnelGeneratorToolsNode(self._team)
+        builder.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools.run)
 
         builder.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR)
         builder.add_conditional_edges(
             AssistantNodeName.FUNNEL_GENERATOR,
-            generate_trends_node.router,
+            funnel_generator.router,
             path_map={
                 "tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
-                "next": AssistantNodeName.SUMMARIZER,
+                "next": next_node,
             },
         )
 
+        return self
+
+    def add_summarizer(self, next_node: AssistantNodeName = AssistantNodeName.END):
+        builder = self._graph
         summarizer_node = SummarizerNode(self._team)
         builder.add_node(AssistantNodeName.SUMMARIZER, summarizer_node.run)
-        builder.add_edge(AssistantNodeName.SUMMARIZER, AssistantNodeName.END)
+        builder.add_edge(AssistantNodeName.SUMMARIZER, next_node)
+        return self
+
+    def compile_full_graph(self):
+        return (
+            self.add_start()
+            .add_router()
+            .add_trends_planner()
+            .add_trends_generator()
+            .add_funnel_planner()
+            .add_funnel_generator()
+            .add_summarizer()
+            .compile()
+        )
+
 
-        return builder.compile()
+class Assistant:
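+    """Streams responses from the full assistant graph, compiled once at
+    construction time instead of on every stream() call."""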
+    _team: Team
+    _graph: CompiledStateGraph
+
+    def __init__(self, team: Team):
+        self._team = team
+        self._graph = AssistantGraph(team).compile_full_graph()
 
     def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
-        assistant_graph = self._compile_graph()
         callbacks = [langfuse_handler] if langfuse_handler else []
         messages = [message.root for message in conversation.messages]
 
         chunks = AIMessageChunk(content="")
         state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}
 
-        generator = assistant_graph.stream(
+        generator: Iterator[Any] = self._graph.stream(
             state,
             config={"recursion_limit": 24, "callbacks": callbacks},
             stream_mode=["messages", "values", "updates"],
diff --git a/ee/hogai/eval/__init__.py b/ee/hogai/eval/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/ee/hogai/eval/test_eval_funnel_planner.py b/ee/hogai/eval/test_eval_funnel_planner.py
new file mode 100644
index 0000000000000..e7370e6915f21
--- /dev/null
+++ b/ee/hogai/eval/test_eval_funnel_planner.py
@@ -0,0 +1,179 @@
+from deepeval import assert_test
+from deepeval.metrics import GEval
+from deepeval.test_case import LLMTestCase, LLMTestCaseParams
+from langgraph.graph.state import CompiledStateGraph
+
+from ee.hogai.assistant import AssistantGraph
+from ee.hogai.eval.utils import EvalBaseTest
+from ee.hogai.utils import AssistantNodeName
+from posthog.schema import HumanMessage
+
+
+class TestEvalFunnelPlanner(EvalBaseTest):
+    def _get_plan_correctness_metric(self):
+        return GEval(
+            name="Funnel Plan Correctness",
+            criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a funnel insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about funnel insights.",
+            evaluation_steps=[
+                "A plan must define at least two series in the sequence, but it is not required to define any filters, exclusion steps, or a breakdown.",
+                "Compare events, properties, math types, and property values of 'expected output' and 'actual output'.",
+                "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.",
+                # The criteria for aggregations must be more specific because there isn't a way to bypass them.
+                "Check if the math types in 'actual output' match those in 'expected output.' If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.",
+                "If 'expected output' contains exclusion steps, check if 'actual output' contains those, and heavily penalize if the exclusion steps are not present or different.",
+                "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different. Plans may only have one breakdown.",
+                # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time.
+                "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.",
+            ],
+            evaluation_params=[
+                LLMTestCaseParams.INPUT,
+                LLMTestCaseParams.EXPECTED_OUTPUT,
+                LLMTestCaseParams.ACTUAL_OUTPUT,
+            ],
+            threshold=0.7,
+        )
+
+    def _call_node(self, query):
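+        # Build only the funnel planner subgraph (START -> FUNNEL_PLANNER,
+        # with its tools loop routed to END) so the eval scores the plan alone.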
+        graph: CompiledStateGraph = (
+            AssistantGraph(self.team)
+            .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER)
+            .add_funnel_planner(AssistantNodeName.END)
+            .compile()
+        )
+        state = graph.invoke({"messages": [HumanMessage(content=query)]})
+        return state["plan"]
+
+    def test_basic_funnel(self):
+        query = "what was the conversion from a page view to sign up?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Sequence:
+            1. $pageview
+            2. signed_up
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_outputs_at_least_two_events(self):
+        """
+        Ambiguous query. The funnel must return at least two events.
+        """
+        query = "how many users paid a bill?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Sequence:
+            1. any event
+            2. upgrade_plan
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_no_excessive_property_filters(self):
+        query = "Show the user conversion from a sign up to a file download"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Sequence:
+            1. signed_up
+            2. downloaded_file
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_basic_filtering(self):
+        query = (
+            "What was the conversion from uploading a file to downloading it from Chrome and Safari in the last 30d?"
+        )
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Sequence:
+            1. uploaded_file
+                - property filter 1:
+                    - entity: event
+                    - property name: $browser
+                    - property type: String
+                    - operator: equals
+                    - property value: Chrome
+                - property filter 2:
+                    - entity: event
+                    - property name: $browser
+                    - property type: String
+                    - operator: equals
+                    - property value: Safari
+            2. downloaded_file
+                - property filter 1:
+                    - entity: event
+                    - property name: $browser
+                    - property type: String
+                    - operator: equals
+                    - property value: Chrome
+                - property filter 2:
+                    - entity: event
+                    - property name: $browser
+                    - property type: String
+                    - operator: equals
+                    - property value: Safari
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_exclusion_steps(self):
+        query = "What was the conversion from uploading a file to downloading it in the last 30d excluding users that deleted a file?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Sequence:
+            1. uploaded_file
+            2. downloaded_file
+
+            Exclusions:
+            - deleted_file
+                - start index: 0
+                - end index: 1
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_breakdown(self):
+        query = "Show a conversion from uploading a file to downloading it segmented by a user's email"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Sequence:
+            1. uploaded_file
+            2. downloaded_file
+
+            Breakdown by:
+            - entity: person
+            - property name: email
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_needle_in_a_haystack(self):
+        query = "What was the conversion from a sign up to a paying customer on the personal-pro plan?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Sequence:
+            1. signed_up
+            2. paid_bill
+                - property filter 1:
+                    - entity: event
+                    - property name: plan
+                    - property type: String
+                    - operator: equals
+                    - property value: personal/pro
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
diff --git a/ee/hogai/eval/test_eval_router.py b/ee/hogai/eval/test_eval_router.py
new file mode 100644
index 0000000000000..73f916245ae7a
--- /dev/null
+++ b/ee/hogai/eval/test_eval_router.py
@@ -0,0 +1,59 @@
+from langgraph.graph.state import CompiledStateGraph
+
+from ee.hogai.assistant import AssistantGraph
+from ee.hogai.eval.utils import EvalBaseTest
+from ee.hogai.utils import AssistantNodeName
+from posthog.schema import HumanMessage, RouterMessage
+
+
+class TestEvalRouter(EvalBaseTest):
+    def _call_node(self, query: str | list):
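+        # Both router branches are remapped to END, so the graph stops right
+        # after classification and the last message is "trends" or "funnel".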
+        graph: CompiledStateGraph = (
+            AssistantGraph(self.team)
+            .add_start()
+            .add_router(path_map={"trends": AssistantNodeName.END, "funnel": AssistantNodeName.END})
+            .compile()
+        )
+        messages = [HumanMessage(content=query)] if isinstance(query, str) else query
+        state = graph.invoke({"messages": messages})
+        return state["messages"][-1].content
+
+    def test_outputs_basic_trends_insight(self):
+        query = "Show the $pageview trend"
+        res = self._call_node(query)
+        self.assertEqual(res, "trends")
+
+    def test_outputs_basic_funnel_insight(self):
+        query = "What is the conversion rate of users who uploaded a file to users who paid for a plan?"
+        res = self._call_node(query)
+        self.assertEqual(res, "funnel")
+
+    def test_converts_trends_to_funnel(self):
+        conversation = [
+            HumanMessage(content="Show trends of $pageview and $identify"),
+            RouterMessage(content="trends"),
+            HumanMessage(content="Convert this insight to a funnel"),
+        ]
+        res = self._call_node(conversation[:1])
+        self.assertEqual(res, "trends")
+        res = self._call_node(conversation)
+        self.assertEqual(res, "funnel")
+
+    def test_converts_funnel_to_trends(self):
+        conversation = [
+            HumanMessage(content="What is the conversion from a page view to a sign up?"),
+            RouterMessage(content="funnel"),
+            HumanMessage(content="Convert this insight to a trends"),
+        ]
+        res = self._call_node(conversation[:1])
+        self.assertEqual(res, "funnel")
+        res = self._call_node(conversation)
+        self.assertEqual(res, "trends")
+
+    def test_outputs_single_trends_insight(self):
+        """
+        Must display a trends insight because it's not possible to build a funnel with a single series.
+        """
+        query = "how many users upgraded their plan to personal pro?"
+        res = self._call_node(query)
+        self.assertEqual(res, "trends")
diff --git a/ee/hogai/eval/test_eval_trends_planner.py b/ee/hogai/eval/test_eval_trends_planner.py
new file mode 100644
index 0000000000000..fa12df10ae9d2
--- /dev/null
+++ b/ee/hogai/eval/test_eval_trends_planner.py
@@ -0,0 +1,163 @@
+from deepeval import assert_test
+from deepeval.metrics import GEval
+from deepeval.test_case import LLMTestCase, LLMTestCaseParams
+from langgraph.graph.state import CompiledStateGraph
+
+from ee.hogai.assistant import AssistantGraph
+from ee.hogai.eval.utils import EvalBaseTest
+from ee.hogai.utils import AssistantNodeName
+from posthog.schema import HumanMessage
+
+
+class TestEvalTrendsPlanner(EvalBaseTest):
+    def _get_plan_correctness_metric(self):
+        return GEval(
+            name="Trends Plan Correctness",
+            criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a trends insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about trends insights.",
+            evaluation_steps=[
+                "A plan must define at least one event and a math type, but it is not required to define any filters, breakdowns, or formulas.",
+                "Compare events, properties, math types, and property values of 'expected output' and 'actual output'.",
+                "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.",
+                # The criteria for aggregations must be more specific because there isn't a way to bypass them.
+                "Check if the math types in 'actual output' match those in 'expected output'. Math types sometimes are interchangeable, so use your judgement. If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.",
+                "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different.",
+                "If 'expected output' contains a formula, check if 'actual output' contains a similar formula, and heavily penalize if the formula is not present or different.",
+                # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time.
+                "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.",
+            ],
+            evaluation_params=[
+                LLMTestCaseParams.INPUT,
+                LLMTestCaseParams.EXPECTED_OUTPUT,
+                LLMTestCaseParams.ACTUAL_OUTPUT,
+            ],
+            threshold=0.7,
+        )
+
+    def _call_node(self, query):
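+        # Mirror of the funnel eval: run only the trends planner subgraph,
+        # routing "plan_found" straight to END.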
+        graph: CompiledStateGraph = (
+            AssistantGraph(self.team)
+            .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER)
+            .add_trends_planner(AssistantNodeName.END)
+            .compile()
+        )
+        state = graph.invoke({"messages": [HumanMessage(content=query)]})
+        return state["plan"]
+
+    def test_no_excessive_property_filters(self):
+        query = "Show the $pageview trend"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Events:
+            - $pageview
+                - math operation: total count
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_no_excessive_property_filters_for_a_defined_math_type(self):
+        query = "What is the MAU?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Events:
+            - $pageview
+                - math operation: unique users
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_basic_filtering(self):
+        query = "can you compare how many Chrome vs Safari users uploaded a file in the last 30d?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Events:
+            - uploaded_file
+                - math operation: total count
+                - property filter 1:
+                    - entity: event
+                    - property name: $browser
+                    - property type: String
+                    - operator: equals
+                    - property value: Chrome
+                - property filter 2:
+                    - entity: event
+                    - property name: $browser
+                    - property type: String
+                    - operator: equals
+                    - property value: Safari
+
+            Breakdown by:
+            - breakdown 1:
+                - entity: event
+                - property name: $browser
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_formula_mode(self):
+        query = "i want to see a ratio of identify divided by page views"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Events:
+            - $identify
+                - math operation: total count
+            - $pageview
+                - math operation: total count
+
+            Formula:
+            `A/B`, where `A` is the total count of `$identify` and `B` is the total count of `$pageview`
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_math_type_by_a_property(self):
+        query = "what is the average session duration?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Events:
+            - All Events
+                - math operation: average by `$session_duration`
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_math_type_by_a_user(self):
+        query = "What is the median page view count for a user?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Events:
+            - $pageview
+                - math operation: median by users
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
+
+    def test_needle_in_a_haystack(self):
+        query = "How frequently do people pay for a personal-pro plan?"
+        test_case = LLMTestCase(
+            input=query,
+            expected_output="""
+            Events:
+            - paid_bill
+                - math operation: total count
+                - property filter 1:
+                    - entity: event
+                    - property name: plan
+                    - property type: String
+                    - operator: contains
+                    - property value: personal/pro
+            """,
+            actual_output=self._call_node(query),
+        )
+        assert_test(test_case, [self._get_plan_correctness_metric()])
diff --git a/ee/hogai/eval/utils.py b/ee/hogai/eval/utils.py
new file mode 100644
index 0000000000000..473b47fe17a84
--- /dev/null
+++ b/ee/hogai/eval/utils.py
@@ -0,0 +1,28 @@
+import datetime as dt
+import os
+
+import pytest
+from flaky import flaky
+
+from posthog.demo.matrix.manager import MatrixManager
+from posthog.tasks.demo_create_data import HedgeboxMatrix
+from posthog.test.base import BaseTest
+
+
+@pytest.mark.skipif(os.environ.get("DEEPEVAL") != "YES", reason="Only runs for the assistant evaluation")
+@flaky(max_runs=3, min_passes=1)
+class EvalBaseTest(BaseTest):
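+    """Base class for assistant evals: runs only when DEEPEVAL=YES, retries
+    flaky LLM calls up to three times, and seeds the team with deterministic
+    Hedgebox demo data shared across test methods."""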
+    @classmethod
+    def setUpTestData(cls):
+        super().setUpTestData()
+        matrix = HedgeboxMatrix(
+            seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef",  # this seed generates all events
+            now=dt.datetime.now(dt.UTC) - dt.timedelta(days=25),
+            days_past=60,
+            days_future=30,
+            n_clusters=60,
+            group_type_index_offset=0,
+        )
+        matrix_manager = MatrixManager(matrix, print_steps=True)
+        existing_user = cls.team.organization.members.first()
+        matrix_manager.run_on_team(cls.team, existing_user)
diff --git a/ee/hogai/schema_generator/nodes.py b/ee/hogai/schema_generator/nodes.py
index 6470c52c4fe08..8845fb6a14a6a 100644
--- a/ee/hogai/schema_generator/nodes.py
+++ b/ee/hogai/schema_generator/nodes.py
@@ -23,7 +23,7 @@
     QUESTION_PROMPT,
 )
 from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
-from ee.hogai.utils import AssistantState, AssistantNode, filter_visualization_conversation
+from ee.hogai.utils import AssistantNode, AssistantState, filter_visualization_conversation
 from posthog.models.group_type_mapping import GroupTypeMapping
 from posthog.schema import (
     FailureMessage,
diff --git a/posthog/api/test/__snapshots__/test_decide.ambr b/posthog/api/test/__snapshots__/test_decide.ambr
index 39c5cf3557402..30fb4248b0b53 100644
--- a/posthog/api/test/__snapshots__/test_decide.ambr
+++ b/posthog/api/test/__snapshots__/test_decide.ambr
@@ -712,6 +712,22 @@
   '''
 # ---
 # name: TestDecide.test_flag_with_behavioural_cohorts.5
+  '''
+  SELECT "posthog_group"."id",
+         "posthog_group"."team_id",
+         "posthog_group"."group_key",
+         "posthog_group"."group_type_index",
+         "posthog_group"."group_properties",
+         "posthog_group"."created_at",
+         "posthog_group"."properties_last_updated_at",
+         "posthog_group"."properties_last_operation",
+         "posthog_group"."version"
+  FROM "posthog_group"
+  WHERE "posthog_group"."team_id" = 99999
+  LIMIT 21
+  '''
+# ---
+# name: TestDecide.test_flag_with_behavioural_cohorts.6
   '''
   SELECT "posthog_cohort"."id",
          "posthog_cohort"."name",
@@ -736,6 +752,22 @@
          AND "posthog_cohort"."team_id" = 99999)
   '''
 # ---
+# name: TestDecide.test_flag_with_behavioural_cohorts.7
+  '''
+  SELECT "posthog_group"."id",
+         "posthog_group"."team_id",
+         "posthog_group"."group_key",
+         "posthog_group"."group_type_index",
+         "posthog_group"."group_properties",
+         "posthog_group"."created_at",
+         "posthog_group"."properties_last_updated_at",
+         "posthog_group"."properties_last_operation",
+         "posthog_group"."version"
+  FROM "posthog_group"
+  WHERE "posthog_group"."team_id" = 99999
+  LIMIT 21
+  '''
+# ---
 # name: TestDecide.test_flag_with_regular_cohorts
   '''
   SELECT "posthog_hogfunction"."id",
diff --git a/posthog/api/test/test_decide.py b/posthog/api/test/test_decide.py
index 99a1a031c087c..ecc40e634a432 100644
--- a/posthog/api/test/test_decide.py
+++ b/posthog/api/test/test_decide.py
@@ -2624,12 +2624,12 @@ def test_flag_with_behavioural_cohorts(self, *args):
             created_by=self.user,
         )
 
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self._post_decide(api_version=3, distinct_id="example_id_1")
             self.assertEqual(response.json()["featureFlags"], {})
             self.assertEqual(response.json()["errorsWhileComputingFlags"], True)
 
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self._post_decide(api_version=3, distinct_id="another_id")
             self.assertEqual(response.json()["featureFlags"], {})
             self.assertEqual(response.json()["errorsWhileComputingFlags"], True)
diff --git a/requirements-dev.in b/requirements-dev.in
index 81484c1a04b9f..4bb51e1f7bfb4 100644
--- a/requirements-dev.in
+++ b/requirements-dev.in
@@ -56,3 +56,4 @@ flaky==3.7.0
 aioresponses==0.7.6
 prance==23.06.21.0
 openapi-spec-validator==0.7.1 # Needed for prance as a validation backend
+deepeval==1.5.5
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 029a1ebf01e7a..7586c83bfdf8d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,6 +4,10 @@ aiohttp==3.9.3
     # via
     #   -c requirements.txt
     #   aioresponses
+    #   datasets
+    #   fsspec
+    #   langchain
+    #   langchain-community
 aioresponses==0.7.6
     # via -r requirements-dev.in
 aiosignal==1.2.0
@@ -14,6 +18,13 @@ annotated-types==0.7.0
     # via
     #   -c requirements.txt
     #   pydantic
+anyio==4.6.2.post1
+    # via
+    #   -c requirements.txt
+    #   httpx
+    #   openai
+appdirs==1.4.4
+    # via ragas
 argcomplete==2.0.0
     # via datamodel-code-generator
 asgiref==3.7.2
@@ -45,7 +56,10 @@ botocore-stubs==1.34.84
 certifi==2019.11.28
     # via
     #   -c requirements.txt
+    #   httpcore
+    #   httpx
     #   requests
+    #   sentry-sdk
 cffi==1.16.0
     # via
     #   -c requirements.txt
@@ -61,6 +75,7 @@ click==8.1.7
     #   -c requirements.txt
     #   black
     #   inline-snapshot
+    #   typer
 colorama==0.4.4
     # via pytest-watch
 coverage==5.5
@@ -69,8 +84,26 @@ cryptography==39.0.2
     # via
     #   -c requirements.txt
     #   types-paramiko
+dataclasses-json==0.6.7
+    # via langchain-community
 datamodel-code-generator==0.26.1
     # via -r requirements-dev.in
+datasets==2.19.1
+    # via ragas
+deepeval==1.5.5
+    # via -r requirements-dev.in
+deprecated==1.2.15
+    # via
+    #   opentelemetry-api
+    #   opentelemetry-exporter-otlp-proto-grpc
+dill==0.3.8
+    # via
+    #   datasets
+    #   multiprocess
+distro==1.9.0
+    # via
+    #   -c requirements.txt
+    #   openai
 django==4.2.15
     # via
     #   -c requirements.txt
@@ -93,6 +126,8 @@ dnspython==2.2.1
     #   email-validator
 docopt==0.6.2
     # via pytest-watch
+docx2txt==0.8
+    # via deepeval
 email-validator==2.0.0.post2
     # via pydantic
 execnet==2.1.1
@@ -103,6 +138,11 @@ faker==17.5.0
     # via -r requirements-dev.in
 fakeredis==2.23.3
     # via -r requirements-dev.in
+filelock==3.12.0
+    # via
+    #   -c requirements.txt
+    #   datasets
+    #   huggingface-hub
 flaky==3.7.0
     # via -r requirements-dev.in
 freezegun==1.2.2
@@ -112,16 +152,51 @@ frozenlist==1.4.1
     #   -c requirements.txt
     #   aiohttp
     #   aiosignal
+fsspec==2023.10.0
+    # via
+    #   -c requirements.txt
+    #   datasets
+    #   huggingface-hub
 genson==1.2.2
     # via datamodel-code-generator
+googleapis-common-protos==1.60.0
+    # via
+    #   -c requirements.txt
+    #   opentelemetry-exporter-otlp-proto-grpc
+grpcio==1.63.2
+    # via
+    #   -c requirements.txt
+    #   deepeval
+    #   opentelemetry-exporter-otlp-proto-grpc
+h11==0.13.0
+    # via
+    #   -c requirements.txt
+    #   httpcore
+httpcore==1.0.2
+    # via
+    #   -c requirements.txt
+    #   httpx
+httpx==0.26.0
+    # via
+    #   -c requirements.txt
+    #   langsmith
+    #   openai
+huggingface-hub==0.26.2
+    # via datasets
 icdiff==2.0.5
     # via pytest-icdiff
 idna==3.10
     # via
     #   -c requirements.txt
+    #   anyio
     #   email-validator
+    #   httpx
     #   requests
     #   yarl
+importlib-metadata==7.0.0
+    # via
+    #   deepeval
+    #   opentelemetry-api
 inflect==5.6.2
     # via datamodel-code-generator
 iniconfig==1.1.1
@@ -132,6 +207,18 @@ isort==5.2.2
     # via datamodel-code-generator
 jinja2==3.1.4
     # via datamodel-code-generator
+jiter==0.5.0
+    # via
+    #   -c requirements.txt
+    #   openai
+jsonpatch==1.33
+    # via
+    #   -c requirements.txt
+    #   langchain-core
+jsonpointer==3.0.0
+    # via
+    #   -c requirements.txt
+    #   jsonpatch
 jsonschema==4.20.0
     # via
     #   -c requirements.txt
@@ -144,6 +231,38 @@ jsonschema-specifications==2023.12.1
     #   -c requirements.txt
     #   jsonschema
     #   openapi-schema-validator
+langchain==0.3.3
+    # via
+    #   -c requirements.txt
+    #   deepeval
+    #   langchain-community
+    #   ragas
+langchain-community==0.3.2
+    # via ragas
+langchain-core==0.3.10
+    # via
+    #   -c requirements.txt
+    #   deepeval
+    #   langchain
+    #   langchain-community
+    #   langchain-openai
+    #   langchain-text-splitters
+    #   ragas
+langchain-openai==0.2.2
+    # via
+    #   -c requirements.txt
+    #   deepeval
+    #   ragas
+langchain-text-splitters==0.3.0
+    # via
+    #   -c requirements.txt
+    #   langchain
+langsmith==0.1.132
+    # via
+    #   -c requirements.txt
+    #   langchain
+    #   langchain-community
+    #   langchain-core
 lazy-object-proxy==1.10.0
     # via openapi-spec-validator
 lupa==2.2
@@ -152,6 +271,8 @@ markdown-it-py==3.0.0
     # via rich
 markupsafe==2.1.5
     # via jinja2
+marshmallow==3.23.1
+    # via dataclasses-json
 mdurl==0.1.2
     # via markdown-it-py
 multidict==6.0.2
@@ -159,6 +280,8 @@ multidict==6.0.2
     #   -c requirements.txt
     #   aiohttp
     #   yarl
+multiprocess==0.70.16
+    # via datasets
 mypy==1.11.1
     # via -r requirements-dev.in
 mypy-baseline==0.7.0
@@ -170,18 +293,66 @@ mypy-extensions==1.0.0
     #   -r requirements-dev.in
     #   black
     #   mypy
+    #   typing-inspect
+nest-asyncio==1.6.0
+    # via ragas
+numpy==1.23.3
+    # via
+    #   -c requirements.txt
+    #   datasets
+    #   langchain
+    #   langchain-community
+    #   pandas
+    #   pyarrow
+    #   ragas
+openai==1.51.2
+    # via
+    #   -c requirements.txt
+    #   langchain-openai
+    #   ragas
 openapi-schema-validator==0.6.2
     # via openapi-spec-validator
 openapi-spec-validator==0.7.1
     # via -r requirements-dev.in
+opentelemetry-api==1.24.0
+    # via
+    #   deepeval
+    #   opentelemetry-exporter-otlp-proto-grpc
+    #   opentelemetry-sdk
+opentelemetry-exporter-otlp-proto-common==1.24.0
+    # via opentelemetry-exporter-otlp-proto-grpc
+opentelemetry-exporter-otlp-proto-grpc==1.24.0
+    # via deepeval
+opentelemetry-proto==1.24.0
+    # via
+    #   opentelemetry-exporter-otlp-proto-common
+    #   opentelemetry-exporter-otlp-proto-grpc
+opentelemetry-sdk==1.24.0
+    # via
+    #   deepeval
+    #   opentelemetry-exporter-otlp-proto-grpc
+opentelemetry-semantic-conventions==0.45b0
+    # via opentelemetry-sdk
+orjson==3.10.7
+    # via
+    #   -c requirements.txt
+    #   langsmith
 packaging==24.1
     # via
     #   -c requirements.txt
     #   -r requirements-dev.in
     #   black
     #   datamodel-code-generator
+    #   datasets
+    #   huggingface-hub
+    #   langchain-core
+    #   marshmallow
     #   prance
     #   pytest
+pandas==2.2.0
+    # via
+    #   -c requirements.txt
+    #   datasets
 parameterized==0.9.0
     # via -r requirements-dev.in
 pathable==0.4.3
@@ -196,10 +367,24 @@ pluggy==1.5.0
     # via
     #   -c requirements.txt
     #   pytest
+portalocker==2.10.1
+    # via deepeval
 pprintpp==0.4.0
     # via pytest-icdiff
 prance==23.6.21.0
     # via -r requirements-dev.in
+protobuf==4.22.1
+    # via
+    #   -c requirements.txt
+    #   deepeval
+    #   googleapis-common-protos
+    #   opentelemetry-proto
+pyarrow==17.0.0
+    # via
+    #   -c requirements.txt
+    #   datasets
+pyarrow-hotfix==0.6
+    # via datasets
 pycparser==2.20
     # via
     #   -c requirements.txt
@@ -208,21 +393,34 @@ pydantic==2.9.2
     # via
     #   -c requirements.txt
     #   datamodel-code-generator
+    #   deepeval
+    #   langchain
+    #   langchain-core
+    #   langsmith
+    #   openai
+    #   pydantic-settings
+    #   ragas
 pydantic-core==2.23.4
     # via
     #   -c requirements.txt
     #   pydantic
+pydantic-settings==2.6.1
+    # via langchain-community
 pygments==2.18.0
     # via rich
+pysbd==0.3.4
+    # via ragas
 pytest==8.0.2
     # via
     #   -r requirements-dev.in
+    #   deepeval
     #   pytest-asyncio
     #   pytest-cov
     #   pytest-django
     #   pytest-env
     #   pytest-icdiff
     #   pytest-mock
+    #   pytest-repeat
     #   pytest-split
     #   pytest-watch
     #   pytest-xdist
@@ -239,24 +437,44 @@ pytest-icdiff==0.6
     # via -r requirements-dev.in
 pytest-mock==3.11.1
     # via -r requirements-dev.in
+pytest-repeat==0.9.3
+    # via deepeval
 pytest-split==0.9.0
     # via -r requirements-dev.in
 pytest-watch==4.2.0
     # via -r requirements-dev.in
 pytest-xdist==3.6.1
-    # via -r requirements-dev.in
+    # via
+    #   -r requirements-dev.in
+    #   deepeval
 python-dateutil==2.8.2
     # via
     #   -c requirements.txt
     #   -r requirements-dev.in
     #   faker
     #   freezegun
+    #   pandas
+python-dotenv==0.21.0
+    # via
+    #   -c requirements.txt
+    #   pydantic-settings
+pytz==2023.3
+    # via
+    #   -c requirements.txt
+    #   pandas
 pyyaml==6.0.1
     # via
     #   -c requirements.txt
     #   datamodel-code-generator
+    #   datasets
+    #   huggingface-hub
     #   jsonschema-path
+    #   langchain
+    #   langchain-community
+    #   langchain-core
     #   responses
+ragas==0.2.5
+    # via deepeval
 redis==4.5.4
     # via
     #   -c requirements.txt
@@ -267,19 +485,39 @@ referencing==0.31.1
     #   jsonschema
     #   jsonschema-path
     #   jsonschema-specifications
+regex==2023.12.25
+    # via
+    #   -c requirements.txt
+    #   tiktoken
 requests==2.32.0
     # via
     #   -c requirements.txt
+    #   datasets
+    #   deepeval
     #   djangorestframework-stubs
+    #   fsspec
+    #   huggingface-hub
     #   jsonschema-path
+    #   langchain
+    #   langchain-community
+    #   langsmith
     #   prance
+    #   requests-toolbelt
     #   responses
+    #   tiktoken
+requests-toolbelt==1.0.0
+    # via
+    #   -c requirements.txt
+    #   langsmith
 responses==0.23.1
     # via -r requirements-dev.in
 rfc3339-validator==0.1.4
     # via openapi-schema-validator
 rich==13.7.1
-    # via inline-snapshot
+    # via
+    #   deepeval
+    #   inline-snapshot
+    #   typer
 rpds-py==0.16.2
     # via
     #   -c requirements.txt
@@ -291,6 +529,12 @@ ruamel-yaml-clib==0.2.8
     # via ruamel-yaml
 ruff==0.6.1
     # via -r requirements-dev.in
+sentry-sdk==1.44.1
+    # via
+    #   -c requirements.txt
+    #   deepeval
+shellingham==1.5.4
+    # via typer
 six==1.16.0
     # via
     #   -c requirements.txt
@@ -298,20 +542,54 @@ six==1.16.0
     #   prance
     #   python-dateutil
     #   rfc3339-validator
+sniffio==1.3.1
+    # via
+    #   -c requirements.txt
+    #   anyio
+    #   httpx
+    #   openai
 sortedcontainers==2.4.0
     # via
     #   -c requirements.txt
     #   fakeredis
+sqlalchemy==2.0.31
+    # via
+    #   -c requirements.txt
+    #   langchain
+    #   langchain-community
 sqlparse==0.4.4
     # via
     #   -c requirements.txt
     #   django
 syrupy==4.6.4
     # via -r requirements-dev.in
+tabulate==0.9.0
+    # via deepeval
+tenacity==8.4.2
+    # via
+    #   -c requirements.txt
+    #   deepeval
+    #   langchain
+    #   langchain-community
+    #   langchain-core
+tiktoken==0.8.0
+    # via
+    #   -c requirements.txt
+    #   langchain-openai
+    #   ragas
 toml==0.10.2
     # via
     #   coverage
     #   inline-snapshot
+tqdm==4.64.1
+    # via
+    #   -c requirements.txt
+    #   datasets
+    #   deepeval
+    #   huggingface-hub
+    #   openai
+typer==0.13.0
+    # via deepeval
 types-awscrt==0.20.9
     # via botocore-stubs
 types-freezegun==1.1.10
@@ -355,21 +633,43 @@ typing-extensions==4.12.2
     #   django-stubs
     #   django-stubs-ext
     #   djangorestframework-stubs
+    #   huggingface-hub
     #   inline-snapshot
+    #   langchain-core
     #   mypy
     #   mypy-boto3-s3
+    #   openai
+    #   opentelemetry-sdk
     #   pydantic
     #   pydantic-core
+    #   sqlalchemy
+    #   typer
+    #   typing-inspect
+typing-inspect==0.9.0
+    # via dataclasses-json
+tzdata==2023.3
+    # via
+    #   -c requirements.txt
+    #   pandas
 urllib3==1.26.18
     # via
     #   -c requirements.txt
     #   requests
     #   responses
+    #   sentry-sdk
 watchdog==2.1.8
     # via
     #   -r requirements-dev.in
     #   pytest-watch
+wrapt==1.15.0
+    # via
+    #   -c requirements.txt
+    #   deprecated
+xxhash==3.5.0
+    # via datasets
 yarl==1.9.4
     # via
     #   -c requirements.txt
     #   aiohttp
+zipp==3.21.0
+    # via importlib-metadata
diff --git a/requirements.in b/requirements.in
index 3696df35d43d1..363915666d010 100644
--- a/requirements.in
+++ b/requirements.in
@@ -111,3 +111,5 @@ zxcvbn==4.4.28
 zstd==1.5.5.1
 xmlsec==1.3.13 # Do not change this version - it will break SAML
 lxml==4.9.4 # Do not change this version - it will break SAML
+grpcio~=1.63.2 # Version constrained so that `deepeval` can be installed in dev
+tenacity~=8.4.2  # Version constrained so that `deepeval` can be installed in dev
diff --git a/requirements.txt b/requirements.txt
index c276d7a792904..9da400cf92fb0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -267,8 +267,9 @@ googleapis-common-protos==1.60.0
     # via
     #   google-api-core
     #   grpcio-status
-grpcio==1.57.0
+grpcio==1.63.2
     # via
+    #   -r requirements.in
     #   google-api-core
     #   grpcio-status
     #   sqlalchemy-bigquery
@@ -702,8 +703,9 @@ structlog==23.2.0
     #   django-structlog
 temporalio==1.7.1
     # via -r requirements.in
-tenacity==8.2.3
+tenacity==8.4.2
     # via
+    #   -r requirements.in
     #   celery-redbeat
     #   dlt
     #   langchain