-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
9 changed files
with
538 additions
and
653 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,104 +1,28 @@ | ||
import functools | ||
from collections.abc import Generator | ||
|
||
import pytest | ||
from django.conf import settings | ||
from django.test import override_settings | ||
from langchain_core.runnables import RunnableConfig | ||
|
||
from ee.models import Conversation | ||
from posthog.demo.matrix.manager import MatrixManager | ||
from posthog.models import Organization, Project, Team, User | ||
from posthog.tasks.demo_create_data import HedgeboxMatrix | ||
from posthog.test.base import BaseTest | ||
|
||
|
||
# Flaky is a handy tool, but it always runs setup fixtures for retries. | ||
# This decorator will just retry without re-running setup. | ||
def retry_test_only(max_retries=3): | ||
def decorator(func): | ||
@functools.wraps(func) | ||
def wrapper(*args, **kwargs): | ||
last_error: Exception | None = None | ||
for attempt in range(max_retries): | ||
try: | ||
return func(*args, **kwargs) | ||
except Exception as e: | ||
last_error = e | ||
print(f"\nRetrying test (attempt {attempt + 1}/{max_retries})...") # noqa | ||
if last_error: | ||
raise last_error | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
# Apply decorators to all tests in the package. | ||
def pytest_collection_modifyitems(items): | ||
for item in items: | ||
item.add_marker( | ||
pytest.mark.skipif(not settings.IN_EVAL_TESTING, reason="Only runs for the assistant evaluation") | ||
) | ||
# Apply our custom retry decorator to the test function | ||
item.obj = retry_test_only(max_retries=3)(item.obj) | ||
|
||
|
||
@pytest.fixture(scope="package") | ||
def team(django_db_blocker) -> Generator[Team, None, None]: | ||
with django_db_blocker.unblock(): | ||
organization = Organization.objects.create(name=BaseTest.CONFIG_ORGANIZATION_NAME) | ||
project = Project.objects.create(id=Team.objects.increment_id_sequence(), organization=organization) | ||
team = Team.objects.create( | ||
id=project.id, | ||
project=project, | ||
organization=organization, | ||
test_account_filters=[ | ||
{ | ||
"key": "email", | ||
"value": "@posthog.com", | ||
"operator": "not_icontains", | ||
"type": "person", | ||
} | ||
], | ||
has_completed_onboarding_for={"product_analytics": True}, | ||
) | ||
yield team | ||
organization.delete() | ||
|
||
from posthog.test.base import run_clickhouse_statement_in_parallel | ||
|
||
@pytest.fixture(scope="package") | ||
def user(team, django_db_blocker) -> Generator[User, None, None]: | ||
with django_db_blocker.unblock(): | ||
user = User.objects.create_and_join(team.organization, "[email protected]", "password1234") | ||
yield user | ||
user.delete() | ||
|
||
@pytest.fixture(scope="module", autouse=True) | ||
def setup_kafka_tables(django_db_setup): | ||
from posthog.clickhouse.client import sync_execute | ||
from posthog.clickhouse.schema import ( | ||
CREATE_KAFKA_TABLE_QUERIES, | ||
build_query, | ||
) | ||
from posthog.settings import CLICKHOUSE_CLUSTER, CLICKHOUSE_DATABASE | ||
|
||
@pytest.mark.django_db(transaction=True) | ||
@pytest.fixture | ||
def runnable_config(team, user) -> Generator[RunnableConfig, None, None]: | ||
conversation = Conversation.objects.create(team=team, user=user) | ||
yield { | ||
"configurable": { | ||
"thread_id": conversation.id, | ||
} | ||
} | ||
conversation.delete() | ||
kafka_queries = list(map(build_query, CREATE_KAFKA_TABLE_QUERIES)) | ||
run_clickhouse_statement_in_parallel(kafka_queries) | ||
|
||
yield | ||
|
||
@pytest.fixture(scope="package", autouse=True) | ||
def setup_test_data(django_db_setup, team, user, django_db_blocker): | ||
with django_db_blocker.unblock(): | ||
matrix = HedgeboxMatrix( | ||
seed="b1ef3c66-5f43-488a-98be-6b46d92fbcef", # this seed generates all events | ||
days_past=120, | ||
days_future=30, | ||
n_clusters=500, | ||
group_type_index_offset=0, | ||
) | ||
matrix_manager = MatrixManager(matrix, print_steps=True) | ||
with override_settings(TEST=False): | ||
# Simulation saving should occur in non-test mode, so that Kafka isn't mocked. Normally in tests we don't | ||
# want to ingest via Kafka, but simulation saving is specifically designed to use that route for speed | ||
matrix_manager.run_on_team(team, user) | ||
kafka_tables = sync_execute( | ||
f""" | ||
SELECT name | ||
FROM system.tables | ||
WHERE database = '{CLICKHOUSE_DATABASE}' AND name LIKE 'kafka_%' | ||
""", | ||
) | ||
kafka_truncate_queries = [f"DROP TABLE {table[0]} ON CLUSTER '{CLICKHOUSE_CLUSTER}'" for table in kafka_tables] | ||
run_clickhouse_statement_in_parallel(kafka_truncate_queries) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,40 @@ | ||
from collections.abc import Callable | ||
from typing import cast | ||
|
||
import pytest | ||
from langgraph.graph.state import CompiledStateGraph | ||
|
||
from ee.hogai.assistant import AssistantGraph | ||
from ee.hogai.eval.utils import EvalBaseTest | ||
from ee.hogai.utils.types import AssistantNodeName, AssistantState | ||
from posthog.schema import AssistantFunnelsQuery, HumanMessage, VisualizationMessage | ||
|
||
|
||
@pytest.fixture | ||
def call_node(team, runnable_config) -> Callable[[str, str], AssistantFunnelsQuery]: | ||
graph: CompiledStateGraph = ( | ||
AssistantGraph(team) | ||
.add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR) | ||
.add_funnel_generator(AssistantNodeName.END) | ||
.compile() | ||
) | ||
|
||
def callable(query: str, plan: str) -> AssistantFunnelsQuery: | ||
class TestEvalFunnelGenerator(EvalBaseTest): | ||
def _call_node(self, query: str, plan: str) -> AssistantFunnelsQuery: | ||
graph: CompiledStateGraph = ( | ||
AssistantGraph(self.team) | ||
.add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_GENERATOR) | ||
.add_funnel_generator(AssistantNodeName.END) | ||
.compile() | ||
) | ||
state = graph.invoke( | ||
AssistantState(messages=[HumanMessage(content=query)], plan=plan), | ||
runnable_config, | ||
self._get_config(), | ||
) | ||
return cast(VisualizationMessage, AssistantState.model_validate(state).messages[-1]).answer | ||
|
||
return callable | ||
|
||
|
||
def test_node_replaces_equals_with_contains(call_node): | ||
query = "what is the conversion rate from a page view to sign up for users with name John?" | ||
plan = """Sequence: | ||
1. $pageview | ||
- property filter 1 | ||
- person | ||
- name | ||
- equals | ||
- John | ||
2. signed_up | ||
""" | ||
actual_output = call_node(query, plan).model_dump_json(exclude_none=True) | ||
assert "exact" not in actual_output | ||
assert "icontains" in actual_output | ||
assert "John" not in actual_output | ||
assert "john" in actual_output | ||
def test_node_replaces_equals_with_contains(self): | ||
query = "what is the conversion rate from a page view to sign up for users with name John?" | ||
plan = """Sequence: | ||
1. $pageview | ||
- property filter 1 | ||
- person | ||
- name | ||
- equals | ||
- John | ||
2. signed_up | ||
""" | ||
actual_output = self._call_node(query, plan).model_dump_json(exclude_none=True) | ||
assert "exact" not in actual_output | ||
assert "icontains" in actual_output | ||
assert "John" not in actual_output | ||
assert "john" in actual_output |
Oops, something went wrong.