diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 4727ff07f4f78..3c740a822598e 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -136,7 +136,11 @@ def _events_prompt(self) -> str: if not isinstance(response, CachedTeamTaxonomyQueryResponse): raise ValueError("Failed to generate events prompt.") - events = [item.event for item in response.results] + events: list[str] = [] + for item in response.results: + if len(response.results) > 25 and item.count <= 3: + continue + events.append(item.event) # default for null in the tags: list[str] = ["all events"] diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 6e878e80ffbc7..dc297570c1fd1 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -2,15 +2,13 @@ from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage -from posthog.test.base import ( - APIBaseTest, - ClickhouseTestMixin, -) +from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @override_settings(IN_UNIT_TESTING=True) class TestPlanAgentNode(ClickhouseTestMixin, APIBaseTest): def setUp(self): + super().setUp() self.schema = ExperimentalAITrendsQuery(series=[]) def test_agent_reconstructs_conversation(self): @@ -70,6 +68,24 @@ def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) + def test_agent_filters_out_low_count_events(self): + _create_person(distinct_ids=["test"], team=self.team) + for i in range(26): + _create_event(event=f"event{i}", distinct_id="test", team=self.team) + _create_event(event="distinctevent", distinct_id="test", team=self.team) + node = CreateTrendsPlanNode(self.team) + self.assertEqual( + node._events_prompt, + "\nall events\ndistinctevent\n", + ) + + def test_agent_preserves_low_count_events_for_smaller_teams(self): + _create_person(distinct_ids=["test"], team=self.team) + _create_event(event="distinctevent", distinct_id="test", team=self.team) + node = CreateTrendsPlanNode(self.team) + self.assertIn("distinctevent", node._events_prompt) + self.assertIn("all events", node._events_prompt) + @override_settings(IN_UNIT_TESTING=True) class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):