From 4e0ef309b3745044f188f53f27eb9aae32e8ab6a Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Thu, 19 Dec 2024 12:16:06 +0100 Subject: [PATCH] chore(product-assistant): revert speed up evaluation tests (#26926) (#27047) --- ee/hogai/eval/conftest.py | 118 +----- .../eval/tests/test_eval_funnel_generator.py | 58 ++- .../eval/tests/test_eval_funnel_planner.py | 398 +++++++++--------- ee/hogai/eval/tests/test_eval_router.py | 119 +++--- .../eval/tests/test_eval_trends_generator.py | 95 ++--- .../eval/tests/test_eval_trends_planner.py | 339 +++++++-------- ee/hogai/eval/utils.py | 40 ++ posthog/conftest.py | 23 +- posthog/settings/__init__.py | 1 - 9 files changed, 538 insertions(+), 653 deletions(-) create mode 100644 ee/hogai/eval/utils.py diff --git a/ee/hogai/eval/conftest.py b/ee/hogai/eval/conftest.py index c6f1924485692..d0bc75348eeac 100644 --- a/ee/hogai/eval/conftest.py +++ b/ee/hogai/eval/conftest.py @@ -1,104 +1,28 @@ -import functools -from collections.abc import Generator - import pytest -from django.conf import settings -from django.test import override_settings -from langchain_core.runnables import RunnableConfig - -from ee.models import Conversation -from posthog.demo.matrix.manager import MatrixManager -from posthog.models import Organization, Project, Team, User -from posthog.tasks.demo_create_data import HedgeboxMatrix -from posthog.test.base import BaseTest - - -# Flaky is a handy tool, but it always runs setup fixtures for retries. -# This decorator will just retry without re-running setup. -def retry_test_only(max_retries=3): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - last_error: Exception | None = None - for attempt in range(max_retries): - try: - return func(*args, **kwargs) - except Exception as e: - last_error = e - print(f"\nRetrying test (attempt {attempt + 1}/{max_retries})...") # noqa - if last_error: - raise last_error - - return wrapper - - return decorator - - -# Apply decorators to all tests in the package. 
-def pytest_collection_modifyitems(items): - for item in items: - item.add_marker( - pytest.mark.skipif(not settings.IN_EVAL_TESTING, reason="Only runs for the assistant evaluation") - ) - # Apply our custom retry decorator to the test function - item.obj = retry_test_only(max_retries=3)(item.obj) - - -@pytest.fixture(scope="package") -def team(django_db_blocker) -> Generator[Team, None, None]: - with django_db_blocker.unblock(): - organization = Organization.objects.create(name=BaseTest.CONFIG_ORGANIZATION_NAME) - project = Project.objects.create(id=Team.objects.increment_id_sequence(), organization=organization) - team = Team.objects.create( - id=project.id, - project=project, - organization=organization, - test_account_filters=[ - { - "key": "email", - "value": "@posthog.com", - "operator": "not_icontains", - "type": "person", - } - ], - has_completed_onboarding_for={"product_analytics": True}, - ) - yield team - organization.delete() +from posthog.test.base import run_clickhouse_statement_in_parallel -@pytest.fixture(scope="package") -def user(team, django_db_blocker) -> Generator[User, None, None]: - with django_db_blocker.unblock(): - user = User.objects.create_and_join(team.organization, "eval@posthog.com", "password1234") - yield user - user.delete() +@pytest.fixture(scope="module", autouse=True) +def setup_kafka_tables(django_db_setup): + from posthog.clickhouse.client import sync_execute + from posthog.clickhouse.schema import ( + CREATE_KAFKA_TABLE_QUERIES, + build_query, + ) + from posthog.settings import CLICKHOUSE_CLUSTER, CLICKHOUSE_DATABASE -@pytest.mark.django_db(transaction=True) -@pytest.fixture -def runnable_config(team, user) -> Generator[RunnableConfig, None, None]: - conversation = Conversation.objects.create(team=team, user=user) - yield { - "configurable": { - "thread_id": conversation.id, - } - } - conversation.delete() + kafka_queries = list(map(build_query, CREATE_KAFKA_TABLE_QUERIES)) + run_clickhouse_statement_in_parallel(kafka_queries) + yield -@pytest.fixture(scope="package", autouse=True) -def setup_test_data(django_db_setup, team, user, django_db_blocker): - with django_db_blocker.unblock(): - matrix = HedgeboxMatrix( - seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef", # this seed generates all events - days_past=120, - days_future=30, - n_clusters=500, - group_type_index_offset=0, - ) - matrix_manager = MatrixManager(matrix, print_steps=True) - with override_settings(TEST=False): - # Simulation saving should occur in non-test mode, so that Kafka isn't mocked. 
Normally in tests we don't - # want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed - matrix_manager.run_on_team(team, user) + kafka_tables = sync_execute( + f""" + SELECT name + FROM system.tables + WHERE database = '{CLICKHOUSE_DATABASE}' AND name LIKE 'kafka_%' + """, + ) + kafka_truncate_queries = [f"DROP TABLE {table[0]} ON CLUSTER '{CLICKHOUSE_CLUSTER}'" for table in kafka_tables] + run_clickhouse_statement_in_parallel(kafka_truncate_queries) diff --git a/ee/hogai/eval/tests/test_eval_funnel_generator.py b/ee/hogai/eval/tests/test_eval_funnel_generator.py index 5f0f29243296a..4d7876ca6f73c 100644 --- a/ee/hogai/eval/tests/test_eval_funnel_generator.py +++ b/ee/hogai/eval/tests/test_eval_funnel_generator.py @@ -1,46 +1,40 @@ -from collections.abc import Callable from typing import cast -import pytest from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph +from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import AssistantFunnelsQuery, HumanMessage, VisualizationMessage -@pytest.fixture -def call_node(team, runnable_config) -> Callable[[str, str], AssistantFunnelsQuery]: - graph: CompiledStateGraph = ( - AssistantGraph(team) - .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR) - .add_funnel_generator(AssistantNodeName.END) - .compile() - ) - - def callable(query: str, plan: str) -> AssistantFunnelsQuery: +class TestEvalFunnelGenerator(EvalBaseTest): + def _call_node(self, query: str, plan: str) -> AssistantFunnelsQuery: + graph: CompiledStateGraph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR) + .add_funnel_generator(AssistantNodeName.END) + .compile() + ) state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)], plan=plan), - runnable_config, + self._get_config(), ) return cast(VisualizationMessage, AssistantState.model_validate(state).messages[-1]).answer - return callable - - -def test_node_replaces_equals_with_contains(call_node): - query = "what is the conversion rate from a page view to sign up for users with name John?" - plan = """Sequence: - 1. $pageview - - property filter 1 - - person - - name - - equals - - John - 2. signed_up - """ - actual_output = call_node(query, plan).model_dump_json(exclude_none=True) - assert "exact" not in actual_output - assert "icontains" in actual_output - assert "John" not in actual_output - assert "john" in actual_output + def test_node_replaces_equals_with_contains(self): + query = "what is the conversion rate from a page view to sign up for users with name John?" + plan = """Sequence: + 1. $pageview + - property filter 1 + - person + - name + - equals + - John + 2. 
signed_up + """ + actual_output = self._call_node(query, plan).model_dump_json(exclude_none=True) + assert "exact" not in actual_output + assert "icontains" in actual_output + assert "John" not in actual_output + assert "john" in actual_output diff --git a/ee/hogai/eval/tests/test_eval_funnel_planner.py b/ee/hogai/eval/tests/test_eval_funnel_planner.py index c8bc25bc0b5dc..9adbd75e77c6c 100644 --- a/ee/hogai/eval/tests/test_eval_funnel_planner.py +++ b/ee/hogai/eval/tests/test_eval_funnel_planner.py @@ -1,224 +1,208 @@ -from collections.abc import Callable - -import pytest from deepeval import assert_test from deepeval.metrics import GEval from deepeval.test_case import LLMTestCase, LLMTestCaseParams -from langchain_core.runnables.config import RunnableConfig from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph +from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage -@pytest.fixture(scope="module") -def metric(): - return GEval( - name="Funnel Plan Correctness", - criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a funnel insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about funnel insights.", - evaluation_steps=[ - "A plan must define at least two series in the sequence, but it is not required to define any filters, exclusion steps, or a breakdown.", - "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", - "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", - # The criteria for aggregations must be more specific because there isn't a way to bypass them. - "Check if the math types in 'actual output' match those in 'expected output.' If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", - "If 'expected output' contains exclusion steps, check if 'actual output' contains those, and heavily penalize if the exclusion steps are not present or different.", - "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different. Plans may only have one breakdown.", - # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. - "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. 
For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", - ], - evaluation_params=[ - LLMTestCaseParams.INPUT, - LLMTestCaseParams.EXPECTED_OUTPUT, - LLMTestCaseParams.ACTUAL_OUTPUT, - ], - threshold=0.7, - ) - - -@pytest.fixture -def call_node(team, runnable_config: RunnableConfig) -> Callable[[str], str]: - graph: CompiledStateGraph = ( - AssistantGraph(team) - .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER) - .add_funnel_planner(AssistantNodeName.END) - .compile() - ) +class TestEvalFunnelPlanner(EvalBaseTest): + def _get_plan_correctness_metric(self): + return GEval( + name="Funnel Plan Correctness", + criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a funnel insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about funnel insights.", + evaluation_steps=[ + "A plan must define at least two series in the sequence, but it is not required to define any filters, exclusion steps, or a breakdown.", + "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", + "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", + # The criteria for aggregations must be more specific because there isn't a way to bypass them. + "Check if the math types in 'actual output' match those in 'expected output.' If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", + "If 'expected output' contains exclusion steps, check if 'actual output' contains those, and heavily penalize if the exclusion steps are not present or different.", + "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different. Plans may only have one breakdown.", + # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. + "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", + ], + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + ], + threshold=0.7, + ) - def callable(query: str) -> str: + def _call_node(self, query): + graph: CompiledStateGraph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER) + .add_funnel_planner(AssistantNodeName.END) + .compile() + ) state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)]), - runnable_config, + self._get_config(), ) return AssistantState.model_validate(state).plan or "" - return callable - - -def test_basic_funnel(metric, call_node): - query = "what was the conversion from a page view to sign up?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. $pageview - 2. signed_up - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_outputs_at_least_two_events(metric, call_node): - """ - Ambigious query. The funnel must return at least two events. 
- """ - query = "how many users paid a bill?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. any event - 2. upgrade_plan - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_no_excessive_property_filters(metric, call_node): - query = "Show the user conversion from a sign up to a file download" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. signed_up - 2. downloaded_file - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_basic_filtering(metric, call_node): - query = "What was the conversion from uploading a file to downloading it from Chrome and Safari in the last 30d?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. uploaded_file - - property filter 1: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Chrome - - property filter 2: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Safari - 2. downloaded_file - - property filter 1: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Chrome - - property filter 2: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Safari - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_exclusion_steps(metric, call_node): - query = "What was the conversion from uploading a file to downloading it in the last 30d excluding users that deleted a file?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. uploaded_file - 2. downloaded_file - - Exclusions: - - deleted_file - - start index: 0 - - end index: 1 - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_breakdown(metric, call_node): - query = "Show a conversion from uploading a file to downloading it segmented by a browser" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. uploaded_file - 2. downloaded_file - - Breakdown by: - - entity: event - - property name: $browser - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_needle_in_a_haystack(metric, call_node): - query = "What was the conversion from a sign up to a paying customer on the personal-pro plan?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. signed_up - 2. paid_bill - - property filter 1: - - entity: event - - property name: plan - - property type: String - - operator: equals - - property value: personal/pro - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_planner_outputs_multiple_series_from_a_single_series_question(metric, call_node): - query = "What's our sign-up funnel?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. $pageview - 2. signed_up - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - + def test_basic_funnel(self): + query = "what was the conversion from a page view to sign up?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. $pageview + 2. signed_up + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_outputs_at_least_two_events(self): + """ + Ambigious query. 
The funnel must return at least two events. + """ + query = "how many users paid a bill?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. any event + 2. upgrade_plan + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_no_excessive_property_filters(self): + query = "Show the user conversion from a sign up to a file download" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. signed_up + 2. downloaded_file + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) -def test_funnel_does_not_include_timeframe(metric, call_node): - query = "what was the conversion from a page view to sign up for event time before 2024-01-01?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Sequence: - 1. $pageview - 2. signed_up - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) + def test_basic_filtering(self): + query = ( + "What was the conversion from uploading a file to downloading it from Chrome and Safari in the last 30d?" + ) + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. uploaded_file + - property filter 1: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Chrome + - property filter 2: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Safari + 2. downloaded_file + - property filter 1: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Chrome + - property filter 2: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Safari + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_exclusion_steps(self): + query = "What was the conversion from uploading a file to downloading it in the last 30d excluding users that deleted a file?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. uploaded_file + 2. downloaded_file + + Exclusions: + - deleted_file + - start index: 0 + - end index: 1 + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_breakdown(self): + query = "Show a conversion from uploading a file to downloading it segmented by a browser" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. uploaded_file + 2. downloaded_file + + Breakdown by: + - entity: event + - property name: $browser + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_needle_in_a_haystack(self): + query = "What was the conversion from a sign up to a paying customer on the personal-pro plan?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. signed_up + 2. paid_bill + - property filter 1: + - entity: event + - property name: plan + - property type: String + - operator: equals + - property value: personal/pro + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_planner_outputs_multiple_series_from_a_single_series_question(self): + query = "What's our sign-up funnel?" 
+ test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. $pageview + 2. signed_up + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_funnel_does_not_include_timeframe(self): + query = "what was the conversion from a page view to sign up for event time before 2024-01-01?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Sequence: + 1. $pageview + 2. signed_up + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) diff --git a/ee/hogai/eval/tests/test_eval_router.py b/ee/hogai/eval/tests/test_eval_router.py index 84e5c4c809972..c1307e9d40f00 100644 --- a/ee/hogai/eval/tests/test_eval_router.py +++ b/ee/hogai/eval/tests/test_eval_router.py @@ -1,80 +1,69 @@ -from collections.abc import Callable from typing import cast -import pytest from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph +from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage, RouterMessage -@pytest.fixture -def call_node(team, runnable_config) -> Callable[[str | list], str]: - graph: CompiledStateGraph = ( - AssistantGraph(team) - .add_start() - .add_router(path_map={"trends": AssistantNodeName.END, "funnel": AssistantNodeName.END}) - .compile() - ) - - def callable(query: str | list) -> str: +class TestEvalRouter(EvalBaseTest): + def _call_node(self, query: str | list): + graph: CompiledStateGraph = ( + AssistantGraph(self.team) + .add_start() + .add_router(path_map={"trends": AssistantNodeName.END, "funnel": AssistantNodeName.END}) + .compile() + ) messages = [HumanMessage(content=query)] if isinstance(query, str) else query state = graph.invoke( AssistantState(messages=messages), - runnable_config, + self._get_config(), ) return cast(RouterMessage, AssistantState.model_validate(state).messages[-1]).content - return callable - - -def test_outputs_basic_trends_insight(call_node): - query = "Show the $pageview trend" - res = call_node(query) - assert res == "trends" - - -def test_outputs_basic_funnel_insight(call_node): - query = "What is the conversion rate of users who uploaded a file to users who paid for a plan?" - res = call_node(query) - assert res == "funnel" - - -def test_converts_trends_to_funnel(call_node): - conversation = [ - HumanMessage(content="Show trends of $pageview and $identify"), - RouterMessage(content="trends"), - HumanMessage(content="Convert this insight to a funnel"), - ] - res = call_node(conversation[:1]) - assert res == "trends" - res = call_node(conversation) - assert res == "funnel" - - -def test_converts_funnel_to_trends(call_node): - conversation = [ - HumanMessage(content="What is the conversion from a page view to a sign up?"), - RouterMessage(content="funnel"), - HumanMessage(content="Convert this insight to a trends"), - ] - res = call_node(conversation[:1]) - assert res == "funnel" - res = call_node(conversation) - assert res == "trends" - - -def test_outputs_single_trends_insight(call_node): - """ - Must display a trends insight because it's not possible to build a funnel with a single series. - """ - query = "how many users upgraded their plan to personal pro?" - res = call_node(query) - assert res == "trends" - - -def test_classifies_funnel_with_single_series(call_node): - query = "What's our sign-up funnel?" 
- res = call_node(query) - assert res == "funnel" + def test_outputs_basic_trends_insight(self): + query = "Show the $pageview trend" + res = self._call_node(query) + self.assertEqual(res, "trends") + + def test_outputs_basic_funnel_insight(self): + query = "What is the conversion rate of users who uploaded a file to users who paid for a plan?" + res = self._call_node(query) + self.assertEqual(res, "funnel") + + def test_converts_trends_to_funnel(self): + conversation = [ + HumanMessage(content="Show trends of $pageview and $identify"), + RouterMessage(content="trends"), + HumanMessage(content="Convert this insight to a funnel"), + ] + res = self._call_node(conversation[:1]) + self.assertEqual(res, "trends") + res = self._call_node(conversation) + self.assertEqual(res, "funnel") + + def test_converts_funnel_to_trends(self): + conversation = [ + HumanMessage(content="What is the conversion from a page view to a sign up?"), + RouterMessage(content="funnel"), + HumanMessage(content="Convert this insight to a trends"), + ] + res = self._call_node(conversation[:1]) + self.assertEqual(res, "funnel") + res = self._call_node(conversation) + self.assertEqual(res, "trends") + + def test_outputs_single_trends_insight(self): + """ + Must display a trends insight because it's not possible to build a funnel with a single series. + """ + query = "how many users upgraded their plan to personal pro?" + res = self._call_node(query) + self.assertEqual(res, "trends") + + def test_classifies_funnel_with_single_series(self): + query = "What's our sign-up funnel?" + res = self._call_node(query) + self.assertEqual(res, "funnel") diff --git a/ee/hogai/eval/tests/test_eval_trends_generator.py b/ee/hogai/eval/tests/test_eval_trends_generator.py index c8491957c868f..496bbf0100b51 100644 --- a/ee/hogai/eval/tests/test_eval_trends_generator.py +++ b/ee/hogai/eval/tests/test_eval_trends_generator.py @@ -1,65 +1,58 @@ -from collections.abc import Callable from typing import cast -import pytest from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph +from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import AssistantTrendsQuery, HumanMessage, VisualizationMessage -@pytest.fixture -def call_node(team, runnable_config) -> Callable[[str, str], AssistantTrendsQuery]: - graph: CompiledStateGraph = ( - AssistantGraph(team) - .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR) - .add_trends_generator(AssistantNodeName.END) - .compile() - ) - - def callable(query: str, plan: str) -> AssistantTrendsQuery: +class TestEvalTrendsGenerator(EvalBaseTest): + def _call_node(self, query: str, plan: str) -> AssistantTrendsQuery: + graph: CompiledStateGraph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_GENERATOR) + .add_trends_generator(AssistantNodeName.END) + .compile() + ) state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)], plan=plan), - runnable_config, + self._get_config(), ) return cast(VisualizationMessage, AssistantState.model_validate(state).messages[-1]).answer - return callable - - -def test_node_replaces_equals_with_contains(call_node): - query = "what is pageview trend for users with name John?" 
- plan = """Events: - - $pageview - - math operation: total count - - property filter 1 - - person - - name - - equals - - John - """ - actual_output = call_node(query, plan).model_dump_json(exclude_none=True) - assert "exact" not in actual_output - assert "icontains" in actual_output - assert "John" not in actual_output - assert "john" in actual_output - - -def test_node_leans_towards_line_graph(call_node): - query = "How often do users download files?" - # We ideally want to consider both total count of downloads per period, as well as how often a median user downloads - plan = """Events: - - downloaded_file - - math operation: total count - - downloaded_file - - math operation: median count per user - """ - actual_output = call_node(query, plan) - assert actual_output.trendsFilter.display == "ActionsLineGraph" - assert actual_output.series[0].kind == "EventsNode" - assert actual_output.series[0].event == "downloaded_file" - assert actual_output.series[0].math == "total" - assert actual_output.series[1].kind == "EventsNode" - assert actual_output.series[1].event == "downloaded_file" - assert actual_output.series[1].math == "median_count_per_actor" + def test_node_replaces_equals_with_contains(self): + query = "what is pageview trend for users with name John?" + plan = """Events: + - $pageview + - math operation: total count + - property filter 1 + - person + - name + - equals + - John + """ + actual_output = self._call_node(query, plan).model_dump_json(exclude_none=True) + assert "exact" not in actual_output + assert "icontains" in actual_output + assert "John" not in actual_output + assert "john" in actual_output + + def test_node_leans_towards_line_graph(self): + query = "How often do users download files?" + # We ideally want to consider both total count of downloads per period, as well as how often a median user downloads + plan = """Events: + - downloaded_file + - math operation: total count + - downloaded_file + - math operation: median count per user + """ + actual_output = self._call_node(query, plan) + assert actual_output.trendsFilter.display == "ActionsLineGraph" + assert actual_output.series[0].kind == "EventsNode" + assert actual_output.series[0].event == "downloaded_file" + assert actual_output.series[0].math == "total" + assert actual_output.series[1].kind == "EventsNode" + assert actual_output.series[1].event == "downloaded_file" + assert actual_output.series[1].math == "median_count_per_actor" diff --git a/ee/hogai/eval/tests/test_eval_trends_planner.py b/ee/hogai/eval/tests/test_eval_trends_planner.py index 4d4ea4c41dfbf..d4fbff456a91c 100644 --- a/ee/hogai/eval/tests/test_eval_trends_planner.py +++ b/ee/hogai/eval/tests/test_eval_trends_planner.py @@ -1,196 +1,179 @@ -from collections.abc import Callable - -import pytest from deepeval import assert_test from deepeval.metrics import GEval from deepeval.test_case import LLMTestCase, LLMTestCaseParams -from langchain_core.runnables.config import RunnableConfig from langgraph.graph.state import CompiledStateGraph from ee.hogai.assistant import AssistantGraph +from ee.hogai.eval.utils import EvalBaseTest from ee.hogai.utils.types import AssistantNodeName, AssistantState from posthog.schema import HumanMessage -@pytest.fixture(scope="module") -def metric(): - return GEval( - name="Trends Plan Correctness", - criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a trends insight. 
Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about trends insights.", - evaluation_steps=[ - "A plan must define at least one event and a math type, but it is not required to define any filters, breakdowns, or formulas.", - "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", - "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", - # The criteria for aggregations must be more specific because there isn't a way to bypass them. - "Check if the math types in 'actual output' match those in 'expected output'. Math types sometimes are interchangeable, so use your judgement. If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", - "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different.", - "If 'expected output' contains a formula, check if 'actual output' contains a similar formula, and heavily penalize if the formula is not present or different.", - # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. - "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", - ], - evaluation_params=[ - LLMTestCaseParams.INPUT, - LLMTestCaseParams.EXPECTED_OUTPUT, - LLMTestCaseParams.ACTUAL_OUTPUT, - ], - threshold=0.7, - ) - - -@pytest.fixture -def call_node(team, runnable_config: RunnableConfig) -> Callable[[str], str]: - graph: CompiledStateGraph = ( - AssistantGraph(team) - .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) - .add_trends_planner(AssistantNodeName.END) - .compile() - ) +class TestEvalTrendsPlanner(EvalBaseTest): + def _get_plan_correctness_metric(self): + return GEval( + name="Trends Plan Correctness", + criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a trends insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about trends insights.", + evaluation_steps=[ + "A plan must define at least one event and a math type, but it is not required to define any filters, breakdowns, or formulas.", + "Compare events, properties, math types, and property values of 'expected output' and 'actual output'. Do not penalize if the actual output does not include a timeframe.", + "Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.", + # The criteria for aggregations must be more specific because there isn't a way to bypass them. + "Check if the math types in 'actual output' match those in 'expected output'. Math types sometimes are interchangeable, so use your judgement. 
If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.", + "If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different.", + "If 'expected output' contains a formula, check if 'actual output' contains a similar formula, and heavily penalize if the formula is not present or different.", + # We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time. + "Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.", + ], + evaluation_params=[ + LLMTestCaseParams.INPUT, + LLMTestCaseParams.EXPECTED_OUTPUT, + LLMTestCaseParams.ACTUAL_OUTPUT, + ], + threshold=0.7, + ) - def callable(query: str) -> str: + def _call_node(self, query): + graph: CompiledStateGraph = ( + AssistantGraph(self.team) + .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER) + .add_trends_planner(AssistantNodeName.END) + .compile() + ) state = graph.invoke( AssistantState(messages=[HumanMessage(content=query)]), - runnable_config, + self._get_config(), ) return AssistantState.model_validate(state).plan or "" - return callable - - -def test_no_excessive_property_filters(metric, call_node): - query = "Show the $pageview trend" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: total count - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_no_excessive_property_filters_for_a_defined_math_type(metric, call_node): - query = "What is the MAU?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: unique users - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_basic_filtering(metric, call_node): - query = "can you compare how many Chrome vs Safari users uploaded a file in the last 30d?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - uploaded_file - - math operation: total count - - property filter 1: - - entity: event - - property name: $browser - - property type: String - - operator: equals - - property value: Chrome - - property filter 2: + def test_no_excessive_property_filters(self): + query = "Show the $pageview trend" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: total count + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_no_excessive_property_filters_for_a_defined_math_type(self): + query = "What is the MAU?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: unique users + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_basic_filtering(self): + query = "can you compare how many Chrome vs Safari users uploaded a file in the last 30d?" 
+ test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - uploaded_file + - math operation: total count + - property filter 1: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Chrome + - property filter 2: + - entity: event + - property name: $browser + - property type: String + - operator: equals + - property value: Safari + + Breakdown by: + - breakdown 1: - entity: event - property name: $browser - - property type: String - - operator: equals - - property value: Safari - - Breakdown by: - - breakdown 1: - - entity: event - - property name: $browser - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_formula_mode(metric, call_node): - query = "i want to see a ratio of identify divided by page views" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $identify - - math operation: total count - - $pageview - - math operation: total count - - Formula: - `A/B`, where `A` is the total count of `$identify` and `B` is the total count of `$pageview` - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_math_type_by_a_property(metric, call_node): - query = "what is the average session duration?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - All Events - - math operation: average by `$session_duration` - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_math_type_by_a_user(metric, call_node): - query = "What is the median page view count for a user?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: median by users - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_needle_in_a_haystack(metric, call_node): - query = "How frequently do people pay for a personal-pro plan?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - paid_bill - - math operation: total count - - property filter 1: - - entity: event - - property name: plan - - property type: String - - operator: contains - - property value: personal/pro - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) - - -def test_trends_does_not_include_timeframe(metric, call_node): - query = "what is the pageview trend for event time before 2024-01-01?" - test_case = LLMTestCase( - input=query, - expected_output=""" - Events: - - $pageview - - math operation: total count - """, - actual_output=call_node(query), - ) - assert_test(test_case, [metric]) + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_formula_mode(self): + query = "i want to see a ratio of identify divided by page views" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $identify + - math operation: total count + - $pageview + - math operation: total count + + Formula: + `A/B`, where `A` is the total count of `$identify` and `B` is the total count of `$pageview` + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_math_type_by_a_property(self): + query = "what is the average session duration?" 
+ test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - All Events + - math operation: average by `$session_duration` + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_math_type_by_a_user(self): + query = "What is the median page view count for a user?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: median by users + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_needle_in_a_haystack(self): + query = "How frequently do people pay for a personal-pro plan?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - paid_bill + - math operation: total count + - property filter 1: + - entity: event + - property name: plan + - property type: String + - operator: contains + - property value: personal/pro + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) + + def test_funnel_does_not_include_timeframe(self): + query = "what is the pageview trend for event time before 2024-01-01?" + test_case = LLMTestCase( + input=query, + expected_output=""" + Events: + - $pageview + - math operation: total count + """, + actual_output=self._call_node(query), + ) + assert_test(test_case, [self._get_plan_correctness_metric()]) diff --git a/ee/hogai/eval/utils.py b/ee/hogai/eval/utils.py new file mode 100644 index 0000000000000..6e03c4cfafa9f --- /dev/null +++ b/ee/hogai/eval/utils.py @@ -0,0 +1,40 @@ +import os + +import pytest +from django.test import override_settings +from flaky import flaky +from langchain_core.runnables import RunnableConfig + +from ee.models.assistant import Conversation +from posthog.demo.matrix.manager import MatrixManager +from posthog.tasks.demo_create_data import HedgeboxMatrix +from posthog.test.base import NonAtomicBaseTest + + +@pytest.mark.skipif(os.environ.get("DEEPEVAL") != "YES", reason="Only runs for the assistant evaluation") +@flaky(max_runs=3, min_passes=1) +class EvalBaseTest(NonAtomicBaseTest): + def _get_config(self) -> RunnableConfig: + conversation = Conversation.objects.create(team=self.team, user=self.user) + return { + "configurable": { + "thread_id": conversation.id, + } + } + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + matrix = HedgeboxMatrix( + seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef", # this seed generates all events + days_past=120, + days_future=30, + n_clusters=500, + group_type_index_offset=0, + ) + matrix_manager = MatrixManager(matrix, print_steps=True) + existing_user = cls.team.organization.members.first() + with override_settings(TEST=False): + # Simulation saving should occur in non-test mode, so that Kafka isn't mocked. 
Normally in tests we don't + # want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed + matrix_manager.run_on_team(cls.team, existing_user) diff --git a/posthog/conftest.py b/posthog/conftest.py index e9804d25eff42..c27dbec43955e 100644 --- a/posthog/conftest.py +++ b/posthog/conftest.py @@ -14,7 +14,6 @@ def create_clickhouse_tables(num_tables: int): CREATE_DATA_QUERIES, CREATE_DICTIONARY_QUERIES, CREATE_DISTRIBUTED_TABLE_QUERIES, - CREATE_KAFKA_TABLE_QUERIES, CREATE_MERGETREE_TABLE_QUERIES, CREATE_MV_TABLE_QUERIES, CREATE_VIEW_QUERIES, @@ -29,18 +28,10 @@ def create_clickhouse_tables(num_tables: int): + len(CREATE_DICTIONARY_QUERIES) ) - # Evaluation tests use Kafka for faster data ingestion. - if settings.IN_EVAL_TESTING: - total_tables += len(CREATE_KAFKA_TABLE_QUERIES) - # Check if all the tables have already been created. Views, materialized views, and dictionaries also count if num_tables == total_tables: return - if settings.IN_EVAL_TESTING: - kafka_table_queries = list(map(build_query, CREATE_KAFKA_TABLE_QUERIES)) - run_clickhouse_statement_in_parallel(kafka_table_queries) - table_queries = list(map(build_query, CREATE_MERGETREE_TABLE_QUERIES + CREATE_DISTRIBUTED_TABLE_QUERIES)) run_clickhouse_statement_in_parallel(table_queries) @@ -71,7 +62,7 @@ def reset_clickhouse_tables(): from posthog.models.channel_type.sql import TRUNCATE_CHANNEL_DEFINITION_TABLE_SQL from posthog.models.cohort.sql import TRUNCATE_COHORTPEOPLE_TABLE_SQL from posthog.models.error_tracking.sql import TRUNCATE_ERROR_TRACKING_ISSUE_FINGERPRINT_OVERRIDES_TABLE_SQL - from posthog.models.event.sql import TRUNCATE_EVENTS_RECENT_TABLE_SQL, TRUNCATE_EVENTS_TABLE_SQL + from posthog.models.event.sql import TRUNCATE_EVENTS_TABLE_SQL, TRUNCATE_EVENTS_RECENT_TABLE_SQL from posthog.models.group.sql import TRUNCATE_GROUPS_TABLE_SQL from posthog.models.performance.sql import TRUNCATE_PERFORMANCE_EVENTS_TABLE_SQL from posthog.models.person.sql import ( @@ -109,18 +100,6 @@ def reset_clickhouse_tables(): TRUNCATE_HEATMAPS_TABLE_SQL(), ] - # Drop created Kafka tables because some tests don't expect it. - if settings.IN_EVAL_TESTING: - kafka_tables = sync_execute( - f""" - SELECT name - FROM system.tables - WHERE database = '{settings.CLICKHOUSE_DATABASE}' AND name LIKE 'kafka_%' - """, - ) - # Using `ON CLUSTER` takes x20 more time to drop the tables: https://github.com/ClickHouse/ClickHouse/issues/15473. - TABLES_TO_CREATE_DROP += [f"DROP TABLE {table[0]}" for table in kafka_tables] - run_clickhouse_statement_in_parallel(TABLES_TO_CREATE_DROP) from posthog.clickhouse.schema import ( diff --git a/posthog/settings/__init__.py b/posthog/settings/__init__.py index 3e7ebc0b7c984..c6067fd19c1f7 100644 --- a/posthog/settings/__init__.py +++ b/posthog/settings/__init__.py @@ -108,7 +108,6 @@ PROM_PUSHGATEWAY_ADDRESS: str | None = os.getenv("PROM_PUSHGATEWAY_ADDRESS", None) IN_UNIT_TESTING: bool = get_from_env("IN_UNIT_TESTING", False, type_cast=str_to_bool) -IN_EVAL_TESTING: bool = get_from_env("DEEPEVAL", False, type_cast=str_to_bool) HOGQL_INCREASED_MAX_EXECUTION_TIME: int = get_from_env("HOGQL_INCREASED_MAX_EXECUTION_TIME", 600, type_cast=int)
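
For reference, below is a minimal sketch of how a new eval module would plug into the class-based setup restored by this patch. It mirrors the router graph wiring from ee/hogai/eval/tests/test_eval_router.py and relies on the EvalBaseTest helpers from ee/hogai/eval/utils.py; the class name, the query string, and the expected routing label are illustrative assumptions only, not part of the patch. Like the modules above, it is skipped unless the DEEPEVAL=YES environment variable is set, and it retries via flaky through the base class.

# Hypothetical sketch (not part of the patch): a new eval module following the
# reverted class-based pattern. Graph wiring and assertions mirror
# test_eval_router.py; the query and expected label are illustrative only.
from typing import cast

from langgraph.graph.state import CompiledStateGraph

from ee.hogai.assistant import AssistantGraph
from ee.hogai.eval.utils import EvalBaseTest
from ee.hogai.utils.types import AssistantNodeName, AssistantState
from posthog.schema import HumanMessage, RouterMessage


class TestEvalExampleRouter(EvalBaseTest):
    def _call_node(self, query: str) -> str:
        # Compile a graph that stops right after the router node, as the eval tests above do.
        graph: CompiledStateGraph = (
            AssistantGraph(self.team)
            .add_start()
            .add_router(path_map={"trends": AssistantNodeName.END, "funnel": AssistantNodeName.END})
            .compile()
        )
        state = graph.invoke(
            AssistantState(messages=[HumanMessage(content=query)]),
            self._get_config(),  # creates a Conversation and returns a config carrying its thread_id
        )
        return cast(RouterMessage, AssistantState.model_validate(state).messages[-1]).content

    def test_illustrative_single_series_question_routes_to_trends(self):
        # Illustrative assertion: a single-series question is expected to route to a trends insight.
        res = self._call_node("how many users signed up last week?")
        self.assertEqual(res, "trends")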