From e3653cc278b5342a977979dd7fcae9b58803f9cf Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Thu, 19 Dec 2024 18:24:55 +0100 Subject: [PATCH] chore(product-assistant): speed up evaluation tests v2 (#27062) --- ee/hogai/eval/conftest.py | 121 +++++- .../eval/tests/test_eval_funnel_generator.py | 58 +-- .../eval/tests/test_eval_funnel_planner.py | 398 +++++++++--------- ee/hogai/eval/tests/test_eval_router.py | 119 +++--- .../eval/tests/test_eval_trends_generator.py | 95 +++-- .../eval/tests/test_eval_trends_planner.py | 339 ++++++++------- ee/hogai/eval/utils.py | 40 -- posthog/conftest.py | 23 +- posthog/settings/__init__.py | 1 + 9 files changed, 656 insertions(+), 538 deletions(-) delete mode 100644 ee/hogai/eval/utils.py diff --git a/ee/hogai/eval/conftest.py b/ee/hogai/eval/conftest.py index d0bc75348eeac..1a88ebffa2e33 100644 --- a/ee/hogai/eval/conftest.py +++ b/ee/hogai/eval/conftest.py @@ -1,28 +1,107 @@ +import functools +from collections.abc import Generator +from pathlib import Path + import pytest +from django.conf import settings +from django.test import override_settings +from langchain_core.runnables import RunnableConfig + +from ee.models import Conversation +from posthog.demo.matrix.manager import MatrixManager +from posthog.models import Organization, Project, Team, User +from posthog.tasks.demo_create_data import HedgeboxMatrix +from posthog.test.base import BaseTest + + +# Flaky is a handy tool, but it always runs setup fixtures for retries. +# This decorator will just retry without re-running setup. +def retry_test_only(max_retries=3): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + last_error: Exception | None = None + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + last_error = e + print(f"\nRetrying test (attempt {attempt + 1}/{max_retries})...") # noqa + if last_error: + raise last_error + + return wrapper + + return decorator + + +# Apply decorators to all tests in the package. +def pytest_collection_modifyitems(items): + current_dir = Path(__file__).parent + for item in items: + if Path(item.fspath).is_relative_to(current_dir): + item.add_marker( + pytest.mark.skipif(not settings.IN_EVAL_TESTING, reason="Only runs for the assistant evaluation") + ) + # Apply our custom retry decorator to the test function + item.obj = retry_test_only(max_retries=3)(item.obj) + + +@pytest.fixture(scope="package") +def team(django_db_blocker) -> Generator[Team, None, None]: + with django_db_blocker.unblock(): + organization = Organization.objects.create(name=BaseTest.CONFIG_ORGANIZATION_NAME) + project = Project.objects.create(id=Team.objects.increment_id_sequence(), organization=organization) + team = Team.objects.create( + id=project.id, + project=project, + organization=organization, + test_account_filters=[ + { + "key": "email", + "value": "@posthog.com", + "operator": "not_icontains", + "type": "person", + } + ], + has_completed_onboarding_for={"product_analytics": True}, + ) + yield team + organization.delete() -from posthog.test.base import run_clickhouse_statement_in_parallel +@pytest.fixture(scope="package") +def user(team, django_db_blocker) -> Generator[User, None, None]: + with django_db_blocker.unblock(): + user = User.objects.create_and_join(team.organization, "eval@posthog.com", "password1234") + yield user + user.delete() -@pytest.fixture(scope="module", autouse=True) -def setup_kafka_tables(django_db_setup): - from posthog.clickhouse.client import sync_execute - from posthog.clickhouse.schema import ( - CREATE_KAFKA_TABLE_QUERIES, - build_query, - ) - from posthog.settings import CLICKHOUSE_CLUSTER, CLICKHOUSE_DATABASE - kafka_queries = list(map(build_query, CREATE_KAFKA_TABLE_QUERIES)) - run_clickhouse_statement_in_parallel(kafka_queries) +@pytest.mark.django_db(transaction=True) +@pytest.fixture +def runnable_config(team, user) -> Generator[RunnableConfig, None, None]: + conversation = Conversation.objects.create(team=team, user=user) + yield { + "configurable": { + "thread_id": conversation.id, + } + } + conversation.delete() - yield - kafka_tables = sync_execute( - f""" - SELECT name - FROM system.tables - WHERE database = '{CLICKHOUSE_DATABASE}' AND name LIKE 'kafka_%' - """, - ) - kafka_truncate_queries = [f"DROP TABLE {table[0]} ON CLUSTER '{CLICKHOUSE_CLUSTER}'" for table in kafka_tables] - run_clickhouse_statement_in_parallel(kafka_truncate_queries) +@pytest.fixture(scope="package", autouse=True) +def setup_test_data(django_db_setup, team, user, django_db_blocker): + with django_db_blocker.unblock(): + matrix = HedgeboxMatrix( + seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef", # this seed generates all events + days_past=120, + days_future=30, + n_clusters=500, + group_type_index_offset=0, + ) + matrix_manager = MatrixManager(matrix, print_steps=True) + with override_settings(TEST=False): + # Simulation saving should occur in non-test mode, so that Kafka isn't mocked. Normally in tests we don't + # want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed + matrix_manager.run_on_team(team, user) diff --git a/ee/hogai/eval/tests/test_eval_funnel_generator.py b/ee/hogai/eval/tests/test_eval_funnel_generator.py index 4d7876ca6f73c..5f0f29243296a 100644 --- a/ee/hogai/eval/tests/test_eval_funnel_generator.py +++ b/ee/hogai/eval/tests/test_eval_funnel_generator.py @@ -1,40 +1,46 @@ +from collections.abc import Callable from typing import cast +import pytest from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph -from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import AssistantFunnelsQuery, HumanMessage, VisualizationMessage -class TestEvalFunnelGenerator(EvalBaseTest): - def _call_node(self, query: str, plan: str) -> AssistantFunnelsQuery: - graph: CompiledStateGraph = ( - AssistantGraph(self.team) - .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR) - .add_funnel_generator(AssistantNodeName.END) - .compile() - ) +@pytest.fixture +def call_node(team, runnable_config) -> Callable[[str, str], AssistantFunnelsQuery]: + graph: CompiledStateGraph = ( + AssistantGraph(team) + .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR) + .add_funnel_generator(AssistantNodeName.END) + .compile() + ) + + def callable(query: str, plan: str) -> AssistantFunnelsQuery: state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)], plan=plan), - self._get_config(), + runnable_config, ) return cast(VisualizationMessage, AssistantState.model_validate(state).messages[-1]).answer - def test_node_replaces_equals_with_contains(self): - query = "what is the conversion rate from a page view to sign up for users with name John?" - plan = """Sequence: - 1. $pageview - - property filter 1 - - person - - name - - equals - - John - 2. signed_up - """ - actual_output = self._call_node(query, plan).model_dump_json(exclude_none=True) - assert "exact" not in actual_output - assert "icontains" in actual_output - assert "John" not in actual_output - assert "john" in actual_output + return callable + + +def test_node_replaces_equals_with_contains(call_node): + query = "what is the conversion rate from a page view to sign up for users with name John?" + plan = """Sequence: + 1. $pageview + - property filter 1 + - person + - name + - equals + - John + 2. signed_up + """ + actual_output = call_node(query, plan).model_dump_json(exclude_none=True) + assert "exact" not in actual_output + assert "icontains" in actual_output + assert "John" not in actual_output + assert "john" in actual_output diff --git a/ee/hogai/eval/tests/test_eval_funnel_planner.py b/ee/hogai/eval/tests/test_eval_funnel_planner.py index 9adbd75e77c6c..c8bc25bc0b5dc 100644 --- a/ee/hogai/eval/tests/test_eval_funnel_planner.py +++ b/ee/hogai/eval/tests/test_eval_funnel_planner.py @@ -1,208 +1,224 @@ +from collections.abc import Callable + +import pytest from deepeval import assert_test from deepeval.metrics import GEval from deepeval.test_case import LLMTestCase, LLMTestCaseParams +from langchain_core.runnables.config import RunnableConfig from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph -from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage -class TestEvalFunnelPlanner(EvalBaseTest): - def _get_plan_correctness_metric(self): - return GEval( - name="Funnel Plan Correctness", - criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a funnel insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about funnel insights.", - evaluation_steps=[ - "A plan must define at least two series in the sequence, but it is not required to define any filters, exclusion steps, or a breakdown.", - "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", - "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", - # The criteria for aggregations must be more specific because there isn't a way to bypass them. - "Check if the math types in 'actual output' match those in 'expected output.' If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", - "If 'expected output' contains exclusion steps, check if 'actual output' contains those, and heavily penalize if the exclusion steps are not present or different.", - "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different. Plans may only have one breakdown.", - # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. - "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", - ], - evaluation_params=[ - LLMTestCaseParams.INPUT, - LLMTestCaseParams.EXPECTED_OUTPUT, - LLMTestCaseParams.ACTUAL_OUTPUT, - ], - threshold=0.7, - ) +@pytest.fixture(scope="module") +def metric(): + return GEval( + name="Funnel Plan Correctness", + criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a funnel insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about funnel insights.", + evaluation_steps=[ + "A plan must define at least two series in the sequence, but it is not required to define any filters, exclusion steps, or a breakdown.", + "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", + "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", + # The criteria for aggregations must be more specific because there isn't a way to bypass them. + "Check if the math types in 'actual output' match those in 'expected output.' If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", + "If 'expected output' contains exclusion steps, check if 'actual output' contains those, and heavily penalize if the exclusion steps are not present or different.", + "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different. Plans may only have one breakdown.", + # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. + "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", + ], + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + ], + threshold=0.7, + ) - def _call_node(self, query): - graph: CompiledStateGraph = ( - AssistantGraph(self.team) - .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER) - .add_funnel_planner(AssistantNodeName.END) - .compile() - ) + +@pytest.fixture +def call_node(team, runnable_config: RunnableConfig) -> Callable[[str], str]: + graph: CompiledStateGraph = ( + AssistantGraph(team) + .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER) + .add_funnel_planner(AssistantNodeName.END) + .compile() + ) + + def callable(query: str) -> str: state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)]), - self._get_config(), + runnable_config, ) return AssistantState.model_validate(state).plan or "" - def test_basic_funnel(self): - query = "what was the conversion from a page view to sign up?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. $pageview - 2. signed_up - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_outputs_at_least_two_events(self): - """ - Ambigious query. The funnel must return at least two events. - """ - query = "how many users paid a bill?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. any event - 2. upgrade_plan - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_no_excessive_property_filters(self): - query = "Show the user conversion from a sign up to a file download" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. signed_up - 2. downloaded_file - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) + return callable - def test_basic_filtering(self): - query = ( - "What was the conversion from uploading a file to downloading it from Chrome and Safari in the last 30d?" - ) - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. uploaded_file - - property filter 1: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Chrome - - property filter 2: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Safari - 2. downloaded_file - - property filter 1: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Chrome - - property filter 2: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Safari - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_exclusion_steps(self): - query = "What was the conversion from uploading a file to downloading it in the last 30d excluding users that deleted a file?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. uploaded_file - 2. downloaded_file - - Exclusions: - - deleted_file - - start index: 0 - - end index: 1 - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_breakdown(self): - query = "Show a conversion from uploading a file to downloading it segmented by a browser" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. uploaded_file - 2. downloaded_file - - Breakdown by: - - entity: event - - property name: $browser - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_needle_in_a_haystack(self): - query = "What was the conversion from a sign up to a paying customer on the personal-pro plan?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. signed_up - 2. paid_bill - - property filter 1: - - entity: event - - property name: plan - - property type: String - - operator: equals - - property value: personal/pro - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_planner_outputs_multiple_series_from_a_single_series_question(self): - query = "What's our sign-up funnel?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. $pageview - 2. signed_up - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_funnel_does_not_include_timeframe(self): - query = "what was the conversion from a page view to sign up for event time before 2024-01-01?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. $pageview - 2. signed_up - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) + +def test_basic_funnel(metric, call_node): + query = "what was the conversion from a page view to sign up?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. $pageview + 2. signed_up + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_outputs_at_least_two_events(metric, call_node): + """ + Ambigious query. The funnel must return at least two events. + """ + query = "how many users paid a bill?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. any event + 2. upgrade_plan + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_no_excessive_property_filters(metric, call_node): + query = "Show the user conversion from a sign up to a file download" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. signed_up + 2. downloaded_file + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_basic_filtering(metric, call_node): + query = "What was the conversion from uploading a file to downloading it from Chrome and Safari in the last 30d?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. uploaded_file + - property filter 1: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Chrome + - property filter 2: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Safari + 2. downloaded_file + - property filter 1: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Chrome + - property filter 2: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Safari + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_exclusion_steps(metric, call_node): + query = "What was the conversion from uploading a file to downloading it in the last 30d excluding users that deleted a file?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. uploaded_file + 2. downloaded_file + + Exclusions: + - deleted_file + - start index: 0 + - end index: 1 + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_breakdown(metric, call_node): + query = "Show a conversion from uploading a file to downloading it segmented by a browser" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. uploaded_file + 2. downloaded_file + + Breakdown by: + - entity: event + - property name: $browser + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_needle_in_a_haystack(metric, call_node): + query = "What was the conversion from a sign up to a paying customer on the personal-pro plan?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. signed_up + 2. paid_bill + - property filter 1: + - entity: event + - property name: plan + - property type: String + - operator: equals + - property value: personal/pro + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_planner_outputs_multiple_series_from_a_single_series_question(metric, call_node): + query = "What's our sign-up funnel?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. $pageview + 2. signed_up + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_funnel_does_not_include_timeframe(metric, call_node): + query = "what was the conversion from a page view to sign up for event time before 2024-01-01?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. $pageview + 2. signed_up + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) diff --git a/ee/hogai/eval/tests/test_eval_router.py b/ee/hogai/eval/tests/test_eval_router.py index c1307e9d40f00..84e5c4c809972 100644 --- a/ee/hogai/eval/tests/test_eval_router.py +++ b/ee/hogai/eval/tests/test_eval_router.py @@ -1,69 +1,80 @@ +from collections.abc import Callable from typing import cast +import pytest from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph -from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage, RouterMessage -class TestEvalRouter(EvalBaseTest): - def _call_node(self, query: str | list): - graph: CompiledStateGraph = ( - AssistantGraph(self.team) - .add_start() - .add_router(path_map={"trends": AssistantNodeName.END, "funnel": AssistantNodeName.END}) - .compile() - ) +@pytest.fixture +def call_node(team, runnable_config) -> Callable[[str | list], str]: + graph: CompiledStateGraph = ( + AssistantGraph(team) + .add_start() + .add_router(path_map={"trends": AssistantNodeName.END, "funnel": AssistantNodeName.END}) + .compile() + ) + + def callable(query: str | list) -> str: messages = [HumanMessage(content=query)] if isinstance(query, str) else query state = graph.invoke( AssistantState(messages=messages), - self._get_config(), + runnable_config, ) return cast(RouterMessage, AssistantState.model_validate(state).messages[-1]).content - def test_outputs_basic_trends_insight(self): - query = "Show the $pageview trend" - res = self._call_node(query) - self.assertEqual(res, "trends") - - def test_outputs_basic_funnel_insight(self): - query = "What is the conversion rate of users who uploaded a file to users who paid for a plan?" - res = self._call_node(query) - self.assertEqual(res, "funnel") - - def test_converts_trends_to_funnel(self): - conversation = [ - HumanMessage(content="Show trends of $pageview and $identify"), - RouterMessage(content="trends"), - HumanMessage(content="Convert this insight to a funnel"), - ] - res = self._call_node(conversation[:1]) - self.assertEqual(res, "trends") - res = self._call_node(conversation) - self.assertEqual(res, "funnel") - - def test_converts_funnel_to_trends(self): - conversation = [ - HumanMessage(content="What is the conversion from a page view to a sign up?"), - RouterMessage(content="funnel"), - HumanMessage(content="Convert this insight to a trends"), - ] - res = self._call_node(conversation[:1]) - self.assertEqual(res, "funnel") - res = self._call_node(conversation) - self.assertEqual(res, "trends") - - def test_outputs_single_trends_insight(self): - """ - Must display a trends insight because it's not possible to build a funnel with a single series. - """ - query = "how many users upgraded their plan to personal pro?" - res = self._call_node(query) - self.assertEqual(res, "trends") - - def test_classifies_funnel_with_single_series(self): - query = "What's our sign-up funnel?" - res = self._call_node(query) - self.assertEqual(res, "funnel") + return callable + + +def test_outputs_basic_trends_insight(call_node): + query = "Show the $pageview trend" + res = call_node(query) + assert res == "trends" + + +def test_outputs_basic_funnel_insight(call_node): + query = "What is the conversion rate of users who uploaded a file to users who paid for a plan?" + res = call_node(query) + assert res == "funnel" + + +def test_converts_trends_to_funnel(call_node): + conversation = [ + HumanMessage(content="Show trends of $pageview and $identify"), + RouterMessage(content="trends"), + HumanMessage(content="Convert this insight to a funnel"), + ] + res = call_node(conversation[:1]) + assert res == "trends" + res = call_node(conversation) + assert res == "funnel" + + +def test_converts_funnel_to_trends(call_node): + conversation = [ + HumanMessage(content="What is the conversion from a page view to a sign up?"), + RouterMessage(content="funnel"), + HumanMessage(content="Convert this insight to a trends"), + ] + res = call_node(conversation[:1]) + assert res == "funnel" + res = call_node(conversation) + assert res == "trends" + + +def test_outputs_single_trends_insight(call_node): + """ + Must display a trends insight because it's not possible to build a funnel with a single series. + """ + query = "how many users upgraded their plan to personal pro?" + res = call_node(query) + assert res == "trends" + + +def test_classifies_funnel_with_single_series(call_node): + query = "What's our sign-up funnel?" + res = call_node(query) + assert res == "funnel" diff --git a/ee/hogai/eval/tests/test_eval_trends_generator.py b/ee/hogai/eval/tests/test_eval_trends_generator.py index 496bbf0100b51..c8491957c868f 100644 --- a/ee/hogai/eval/tests/test_eval_trends_generator.py +++ b/ee/hogai/eval/tests/test_eval_trends_generator.py @@ -1,58 +1,65 @@ +from collections.abc import Callable from typing import cast +import pytest from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph -from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import AssistantTrendsQuery, HumanMessage, VisualizationMessage -class TestEvalTrendsGenerator(EvalBaseTest): - def _call_node(self, query: str, plan: str) -> AssistantTrendsQuery: - graph: CompiledStateGraph = ( - AssistantGraph(self.team) - .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR) - .add_trends_generator(AssistantNodeName.END) - .compile() - ) +@pytest.fixture +def call_node(team, runnable_config) -> Callable[[str, str], AssistantTrendsQuery]: + graph: CompiledStateGraph = ( + AssistantGraph(team) + .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR) + .add_trends_generator(AssistantNodeName.END) + .compile() + ) + + def callable(query: str, plan: str) -> AssistantTrendsQuery: state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)], plan=plan), - self._get_config(), + runnable_config, ) return cast(VisualizationMessage, AssistantState.model_validate(state).messages[-1]).answer - def test_node_replaces_equals_with_contains(self): - query = "what is pageview trend for users with name John?" - plan = """Events: - - $pageview - - math operation: total count - - property filter 1 - - person - - name - - equals - - John - """ - actual_output = self._call_node(query, plan).model_dump_json(exclude_none=True) - assert "exact" not in actual_output - assert "icontains" in actual_output - assert "John" not in actual_output - assert "john" in actual_output - - def test_node_leans_towards_line_graph(self): - query = "How often do users download files?" - # We ideally want to consider both total count of downloads per period, as well as how often a median user downloads - plan = """Events: - - downloaded_file - - math operation: total count - - downloaded_file - - math operation: median count per user - """ - actual_output = self._call_node(query, plan) - assert actual_output.trendsFilter.display == "ActionsLineGraph" - assert actual_output.series[0].kind == "EventsNode" - assert actual_output.series[0].event == "downloaded_file" - assert actual_output.series[0].math == "total" - assert actual_output.series[1].kind == "EventsNode" - assert actual_output.series[1].event == "downloaded_file" - assert actual_output.series[1].math == "median_count_per_actor" + return callable + + +def test_node_replaces_equals_with_contains(call_node): + query = "what is pageview trend for users with name John?" + plan = """Events: + - $pageview + - math operation: total count + - property filter 1 + - person + - name + - equals + - John + """ + actual_output = call_node(query, plan).model_dump_json(exclude_none=True) + assert "exact" not in actual_output + assert "icontains" in actual_output + assert "John" not in actual_output + assert "john" in actual_output + + +def test_node_leans_towards_line_graph(call_node): + query = "How often do users download files?" + # We ideally want to consider both total count of downloads per period, as well as how often a median user downloads + plan = """Events: + - downloaded_file + - math operation: total count + - downloaded_file + - math operation: median count per user + """ + actual_output = call_node(query, plan) + assert actual_output.trendsFilter.display == "ActionsLineGraph" + assert actual_output.series[0].kind == "EventsNode" + assert actual_output.series[0].event == "downloaded_file" + assert actual_output.series[0].math == "total" + assert actual_output.series[1].kind == "EventsNode" + assert actual_output.series[1].event == "downloaded_file" + assert actual_output.series[1].math == "median_count_per_actor" diff --git a/ee/hogai/eval/tests/test_eval_trends_planner.py b/ee/hogai/eval/tests/test_eval_trends_planner.py index d4fbff456a91c..4d4ea4c41dfbf 100644 --- a/ee/hogai/eval/tests/test_eval_trends_planner.py +++ b/ee/hogai/eval/tests/test_eval_trends_planner.py @@ -1,179 +1,196 @@ +from collections.abc import Callable + +import pytest from deepeval import assert_test from deepeval.metrics import GEval from deepeval.test_case import LLMTestCase, LLMTestCaseParams +from langchain_core.runnables.config import RunnableConfig from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph -from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage -class TestEvalTrendsPlanner(EvalBaseTest): - def _get_plan_correctness_metric(self): - return GEval( - name="Trends Plan Correctness", - criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a trends insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about trends insights.", - evaluation_steps=[ - "A plan must define at least one event and a math type, but it is not required to define any filters, breakdowns, or formulas.", - "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", - "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", - # The criteria for aggregations must be more specific because there isn't a way to bypass them. - "Check if the math types in 'actual output' match those in 'expected output'. Math types sometimes are interchangeable, so use your judgement. If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", - "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different.", - "If 'expected output' contains a formula, check if 'actual output' contains a similar formula, and heavily penalize if the formula is not present or different.", - # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. - "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", - ], - evaluation_params=[ - LLMTestCaseParams.INPUT, - LLMTestCaseParams.EXPECTED_OUTPUT, - LLMTestCaseParams.ACTUAL_OUTPUT, - ], - threshold=0.7, - ) +@pytest.fixture(scope="module") +def metric(): + return GEval( + name="Trends Plan Correctness", + criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a trends insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about trends insights.", + evaluation_steps=[ + "A plan must define at least one event and a math type, but it is not required to define any filters, breakdowns, or formulas.", + "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", + "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", + # The criteria for aggregations must be more specific because there isn't a way to bypass them. + "Check if the math types in 'actual output' match those in 'expected output'. Math types sometimes are interchangeable, so use your judgement. If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", + "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different.", + "If 'expected output' contains a formula, check if 'actual output' contains a similar formula, and heavily penalize if the formula is not present or different.", + # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. + "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", + ], + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + ], + threshold=0.7, + ) - def _call_node(self, query): - graph: CompiledStateGraph = ( - AssistantGraph(self.team) - .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) - .add_trends_planner(AssistantNodeName.END) - .compile() - ) + +@pytest.fixture +def call_node(team, runnable_config: RunnableConfig) -> Callable[[str], str]: + graph: CompiledStateGraph = ( + AssistantGraph(team) + .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) + .add_trends_planner(AssistantNodeName.END) + .compile() + ) + + def callable(query: str) -> str: state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)]), - self._get_config(), + runnable_config, ) return AssistantState.model_validate(state).plan or "" - def test_no_excessive_property_filters(self): - query = "Show the $pageview trend" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: total count - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_no_excessive_property_filters_for_a_defined_math_type(self): - query = "What is the MAU?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: unique users - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_basic_filtering(self): - query = "can you compare how many Chrome vs Safari users uploaded a file in the last 30d?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - uploaded_file - - math operation: total count - - property filter 1: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Chrome - - property filter 2: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Safari - - Breakdown by: - - breakdown 1: + return callable + + +def test_no_excessive_property_filters(metric, call_node): + query = "Show the $pageview trend" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: total count + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_no_excessive_property_filters_for_a_defined_math_type(metric, call_node): + query = "What is the MAU?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: unique users + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_basic_filtering(metric, call_node): + query = "can you compare how many Chrome vs Safari users uploaded a file in the last 30d?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - uploaded_file + - math operation: total count + - property filter 1: - entity: event - property name: $browser - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_formula_mode(self): - query = "i want to see a ratio of identify divided by page views" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $identify - - math operation: total count - - $pageview - - math operation: total count - - Formula: - `A/B`, where `A` is the total count of `$identify` and `B` is the total count of `$pageview` - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_math_type_by_a_property(self): - query = "what is the average session duration?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - All Events - - math operation: average by `$session_duration` - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_math_type_by_a_user(self): - query = "What is the median page view count for a user?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: median by users - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_needle_in_a_haystack(self): - query = "How frequently do people pay for a personal-pro plan?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - paid_bill - - math operation: total count - - property filter 1: - - entity: event - - property name: plan - - property type: String - - operator: contains - - property value: personal/pro - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) - - def test_funnel_does_not_include_timeframe(self): - query = "what is the pageview trend for event time before 2024-01-01?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: total count - """, - actual_output=self._call_node(query), - ) - assert_test(test_case, [self._get_plan_correctness_metric()]) + - property type: String + - operator: equals + - property value: Chrome + - property filter 2: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Safari + + Breakdown by: + - breakdown 1: + - entity: event + - property name: $browser + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_formula_mode(metric, call_node): + query = "i want to see a ratio of identify divided by page views" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $identify + - math operation: total count + - $pageview + - math operation: total count + + Formula: + `A/B`, where `A` is the total count of `$identify` and `B` is the total count of `$pageview` + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_math_type_by_a_property(metric, call_node): + query = "what is the average session duration?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - All Events + - math operation: average by `$session_duration` + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_math_type_by_a_user(metric, call_node): + query = "What is the median page view count for a user?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: median by users + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_needle_in_a_haystack(metric, call_node): + query = "How frequently do people pay for a personal-pro plan?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - paid_bill + - math operation: total count + - property filter 1: + - entity: event + - property name: plan + - property type: String + - operator: contains + - property value: personal/pro + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) + + +def test_trends_does_not_include_timeframe(metric, call_node): + query = "what is the pageview trend for event time before 2024-01-01?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: total count + """, + actual_output=call_node(query), + ) + assert_test(test_case, [metric]) diff --git a/ee/hogai/eval/utils.py b/ee/hogai/eval/utils.py deleted file mode 100644 index 6e03c4cfafa9f..0000000000000 --- a/ee/hogai/eval/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import os - -import pytest -from django.test import override_settings -from flaky import flaky -from langchain_core.runnables import RunnableConfig - -from ee.models.assistant import Conversation -from posthog.demo.matrix.manager import MatrixManager -from posthog.tasks.demo_create_data import HedgeboxMatrix -from posthog.test.base import NonAtomicBaseTest - - -@pytest.mark.skipif(os.environ.get("DEEPEVAL") != "YES", reason="Only runs for the assistant evaluation") -@flaky(max_runs=3, min_passes=1) -class EvalBaseTest(NonAtomicBaseTest): - def _get_config(self) -> RunnableConfig: - conversation = Conversation.objects.create(team=self.team, user=self.user) - return { - "configurable": { - "thread_id": conversation.id, - } - } - - @classmethod - def setUpTestData(cls): - super().setUpTestData() - matrix = HedgeboxMatrix( - seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef", # this seed generates all events - days_past=120, - days_future=30, - n_clusters=500, - group_type_index_offset=0, - ) - matrix_manager = MatrixManager(matrix, print_steps=True) - existing_user = cls.team.organization.members.first() - with override_settings(TEST=False): - # Simulation saving should occur in non-test mode, so that Kafka isn't mocked. Normally in tests we don't - # want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed - matrix_manager.run_on_team(cls.team, existing_user) diff --git a/posthog/conftest.py b/posthog/conftest.py index c27dbec43955e..e9804d25eff42 100644 --- a/posthog/conftest.py +++ b/posthog/conftest.py @@ -14,6 +14,7 @@ def create_clickhouse_tables(num_tables: int): CREATE_DATA_QUERIES, CREATE_DICTIONARY_QUERIES, CREATE_DISTRIBUTED_TABLE_QUERIES, + CREATE_KAFKA_TABLE_QUERIES, CREATE_MERGETREE_TABLE_QUERIES, CREATE_MV_TABLE_QUERIES, CREATE_VIEW_QUERIES, @@ -28,10 +29,18 @@ def create_clickhouse_tables(num_tables: int): + len(CREATE_DICTIONARY_QUERIES) ) + # Evaluation tests use Kafka for faster data ingestion. + if settings.IN_EVAL_TESTING: + total_tables += len(CREATE_KAFKA_TABLE_QUERIES) + # Check if all the tables have already been created. Views, materialized views, and dictionaries also count if num_tables == total_tables: return + if settings.IN_EVAL_TESTING: + kafka_table_queries = list(map(build_query, CREATE_KAFKA_TABLE_QUERIES)) + run_clickhouse_statement_in_parallel(kafka_table_queries) + table_queries = list(map(build_query, CREATE_MERGETREE_TABLE_QUERIES + CREATE_DISTRIBUTED_TABLE_QUERIES)) run_clickhouse_statement_in_parallel(table_queries) @@ -62,7 +71,7 @@ def reset_clickhouse_tables(): from posthog.models.channel_type.sql import TRUNCATE_CHANNEL_DEFINITION_TABLE_SQL from posthog.models.cohort.sql import TRUNCATE_COHORTPEOPLE_TABLE_SQL from posthog.models.error_tracking.sql import TRUNCATE_ERROR_TRACKING_ISSUE_FINGERPRINT_OVERRIDES_TABLE_SQL - from posthog.models.event.sql import TRUNCATE_EVENTS_TABLE_SQL, TRUNCATE_EVENTS_RECENT_TABLE_SQL + from posthog.models.event.sql import TRUNCATE_EVENTS_RECENT_TABLE_SQL, TRUNCATE_EVENTS_TABLE_SQL from posthog.models.group.sql import TRUNCATE_GROUPS_TABLE_SQL from posthog.models.performance.sql import TRUNCATE_PERFORMANCE_EVENTS_TABLE_SQL from posthog.models.person.sql import ( @@ -100,6 +109,18 @@ def reset_clickhouse_tables(): TRUNCATE_HEATMAPS_TABLE_SQL(), ] + # Drop created Kafka tables because some tests don't expect it. + if settings.IN_EVAL_TESTING: + kafka_tables = sync_execute( + f""" + SELECT name + FROM system.tables + WHERE database = '{settings.CLICKHOUSE_DATABASE}' AND name LIKE 'kafka_%' + """, + ) + # Using `ON CLUSTER` takes x20 more time to drop the tables: https://github.com/ClickHouse/ClickHouse/issues/15473. + TABLES_TO_CREATE_DROP += [f"DROP TABLE {table[0]}" for table in kafka_tables] + run_clickhouse_statement_in_parallel(TABLES_TO_CREATE_DROP) from posthog.clickhouse.schema import ( diff --git a/posthog/settings/__init__.py b/posthog/settings/__init__.py index c6067fd19c1f7..3e7ebc0b7c984 100644 --- a/posthog/settings/__init__.py +++ b/posthog/settings/__init__.py @@ -108,6 +108,7 @@ PROM_PUSHGATEWAY_ADDRESS: str | None = os.getenv("PROM_PUSHGATEWAY_ADDRESS", None) IN_UNIT_TESTING: bool = get_from_env("IN_UNIT_TESTING", False, type_cast=str_to_bool) +IN_EVAL_TESTING: bool = get_from_env("DEEPEVAL", False, type_cast=str_to_bool) HOGQL_INCREASED_MAX_EXECUTION_TIME: int = get_from_env("HOGQL_INCREASED_MAX_EXECUTION_TIME", 600, type_cast=int)