feat(product-assistant): evaluation pipeline (#26179)
Co-authored-by: Michael Matloka <[email protected]>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
3 people authored Nov 20, 2024
1 parent bc98fda commit d836bc8
Showing 14 changed files with 867 additions and 33 deletions.
8 changes: 7 additions & 1 deletion .gitignore
@@ -69,4 +69,10 @@ plugin-transpiler/dist
*.log
# pyright config (keep this until we have a standardized one)
pyrightconfig.json
.temporal-worker-settings
# Assistant Evaluation with Deepeval
.deepeval
.deepeval-cache.json
.deepeval_telemtry.txt
.temporal-worker-settings
temp_test_run_data.json
.temp-deepeval-cache.json
112 changes: 87 additions & 25 deletions ee/hogai/assistant.py
@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Any, Literal, TypedDict, TypeGuard, Union
from collections.abc import Generator, Hashable, Iterator
from typing import Any, Literal, Optional, TypedDict, TypeGuard, Union, cast

from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import StateGraph
from langgraph.graph.state import CompiledStateGraph, StateGraph
from pydantic import BaseModel
from sentry_sdk import capture_exception

@@ -74,25 +74,49 @@ def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], As
}


class Assistant:
class AssistantGraph:
_team: Team
_graph: StateGraph

def __init__(self, team: Team):
self._team = team
self._graph = StateGraph(AssistantState)

def _compile_graph(self):
self._has_start_node = False

def add_edge(self, from_node: AssistantNodeName, to_node: AssistantNodeName):
if from_node == AssistantNodeName.START:
self._has_start_node = True
self._graph.add_edge(from_node, to_node)
return self

def compile(self):
if not self._has_start_node:
raise ValueError("Start node not added to the graph")
return self._graph.compile()

def add_start(self):
return self.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)

def add_router(
self,
path_map: Optional[dict[Hashable, AssistantNodeName]] = None,
):
builder = self._graph

path_map = path_map or {
"trends": AssistantNodeName.TRENDS_PLANNER,
"funnel": AssistantNodeName.FUNNEL_PLANNER,
}
router_node = RouterNode(self._team)
builder.add_node(AssistantNodeName.ROUTER, router_node.run)
builder.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
builder.add_conditional_edges(
AssistantNodeName.ROUTER,
router_node.router,
path_map={"trends": AssistantNodeName.TRENDS_PLANNER, "funnel": AssistantNodeName.FUNNEL_PLANNER},
path_map=cast(dict[Hashable, str], path_map),
)
return self

def add_trends_planner(self, next_node: AssistantNodeName = AssistantNodeName.TRENDS_GENERATOR):
builder = self._graph

create_trends_plan_node = TrendsPlannerNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_PLANNER, create_trends_plan_node.run)
@@ -111,26 +135,36 @@ def _compile_graph(self):
create_trends_plan_tools_node.router,
path_map={
"continue": AssistantNodeName.TRENDS_PLANNER,
"plan_found": AssistantNodeName.TRENDS_GENERATOR,
"plan_found": next_node,
},
)

generate_trends_node = TrendsGeneratorNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_GENERATOR, generate_trends_node.run)
return self

def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.SUMMARIZER):
builder = self._graph

trends_generator = TrendsGeneratorNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator.run)

generate_trends_tools_node = TrendsGeneratorToolsNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, generate_trends_tools_node.run)
trends_generator_tools = TrendsGeneratorToolsNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools.run)

builder.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR)
builder.add_conditional_edges(
AssistantNodeName.TRENDS_GENERATOR,
generate_trends_node.router,
trends_generator.router,
path_map={
"tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
"next": AssistantNodeName.SUMMARIZER,
"next": next_node,
},
)

return self

def add_funnel_planner(self, next_node: AssistantNodeName = AssistantNodeName.FUNNEL_GENERATOR):
builder = self._graph

funnel_planner = FunnelPlannerNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_PLANNER, funnel_planner.run)
builder.add_conditional_edges(
@@ -148,41 +182,69 @@ def _compile_graph(self):
funnel_planner_tools.router,
path_map={
"continue": AssistantNodeName.FUNNEL_PLANNER,
"plan_found": AssistantNodeName.FUNNEL_GENERATOR,
"plan_found": next_node,
},
)

return self

def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.SUMMARIZER):
builder = self._graph

funnel_generator = FunnelGeneratorNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator.run)

funnel_generator_tools_node = FunnelGeneratorToolsNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools_node.run)
funnel_generator_tools = FunnelGeneratorToolsNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools.run)

builder.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR)
builder.add_conditional_edges(
AssistantNodeName.FUNNEL_GENERATOR,
generate_trends_node.router,
funnel_generator.router,
path_map={
"tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
"next": AssistantNodeName.SUMMARIZER,
"next": next_node,
},
)

return self

def add_summarizer(self, next_node: AssistantNodeName = AssistantNodeName.END):
builder = self._graph
summarizer_node = SummarizerNode(self._team)
builder.add_node(AssistantNodeName.SUMMARIZER, summarizer_node.run)
builder.add_edge(AssistantNodeName.SUMMARIZER, AssistantNodeName.END)
builder.add_edge(AssistantNodeName.SUMMARIZER, next_node)
return self

def compile_full_graph(self):
return (
self.add_start()
.add_router()
.add_trends_planner()
.add_trends_generator()
.add_funnel_planner()
.add_funnel_generator()
.add_summarizer()
.compile()
)


return builder.compile()
class Assistant:
_team: Team
_graph: CompiledStateGraph

def __init__(self, team: Team):
self._team = team
self._graph = AssistantGraph(team).compile_full_graph()

def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
assistant_graph = self._compile_graph()
callbacks = [langfuse_handler] if langfuse_handler else []
messages = [message.root for message in conversation.messages]

chunks = AIMessageChunk(content="")
state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}

generator = assistant_graph.stream(
generator: Iterator[Any] = self._graph.stream(
state,
config={"recursion_limit": 24, "callbacks": callbacks},
stream_mode=["messages", "values", "updates"],
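
For context, a minimal usage sketch of the new builder-style API (not part of this commit's diff); `team` is assumed to be an existing posthog.models.Team, `conversation` an existing Conversation, and the import paths mirror the files changed above:

    from ee.hogai.assistant import Assistant, AssistantGraph
    from ee.hogai.utils import AssistantNodeName

    # Full graph, as Assistant.__init__ now builds internally via compile_full_graph().
    assistant = Assistant(team)
    # stream() yields pydantic models as the compiled graph runs.
    for chunk in assistant.stream(conversation):
        ...

    # Partial graph: compile() raises ValueError unless an edge out of START
    # has been added first, so wire START directly to the node under test.
    trends_planner_only = (
        AssistantGraph(team)
        .add_edge(AssistantNodeName.START, AssistantNodeName.TRENDS_PLANNER)
        .add_trends_planner(AssistantNodeName.END)
        .compile()
    )

The evaluation tests below rely on exactly this partial-graph pattern, pointed at the funnel planner instead.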
Empty file added ee/hogai/eval/__init__.py
179 changes: 179 additions & 0 deletions ee/hogai/eval/test_eval_funnel_planner.py
@@ -0,0 +1,179 @@
from deepeval import assert_test
from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams
from langgraph.graph.state import CompiledStateGraph

from ee.hogai.assistant import AssistantGraph
from ee.hogai.eval.utils import EvalBaseTest
from ee.hogai.utils import AssistantNodeName
from posthog.schema import HumanMessage


class TestEvalFunnelPlanner(EvalBaseTest):
def _get_plan_correctness_metric(self):
return GEval(
name="Funnel Plan Correctness",
criteria="You will be given expected and actual generated plans to provide a taxonomy to answer a user's question with a funnel insight. Compare the plans to determine whether the taxonomy of the actual plan matches the expected plan. Do not apply general knowledge about funnel insights.",
evaluation_steps=[
"A plan must define at least two series in the sequence, but it is not required to define any filters, exclusion steps, or a breakdown.",
"Compare events, properties, math types, and property values of 'expected output' and 'actual output'.",
"Check if the combination of events, properties, and property values in 'actual output' can answer the user's question according to the 'expected output'.",
# The criteria for aggregations must be more specific because there isn't a way to bypass them.
"Check if the math types in 'actual output' match those in 'expected output.' If the aggregation type is specified by a property, user, or group in 'expected output', the same property, user, or group must be used in 'actual output'.",
"If 'expected output' contains exclusion steps, check if 'actual output' contains those, and heavily penalize if the exclusion steps are not present or different.",
"If 'expected output' contains a breakdown, check if 'actual output' contains a similar breakdown, and heavily penalize if the breakdown is not present or different. Plans may only have one breakdown.",
# We don't want to see in the output unnecessary property filters. The assistant tries to use them all the time.
"Heavily penalize if the 'actual output' contains any excessive output not present in the 'expected output'. For example, the `is set` operator in filters should not be used unless the user explicitly asks for it.",
],
evaluation_params=[
LLMTestCaseParams.INPUT,
LLMTestCaseParams.EXPECTED_OUTPUT,
LLMTestCaseParams.ACTUAL_OUTPUT,
],
threshold=0.7,
)

def _call_node(self, query):
graph: CompiledStateGraph = (
AssistantGraph(self.team)
.add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER)
.add_funnel_planner(AssistantNodeName.END)
.compile()
)
state = graph.invoke({"messages": [HumanMessage(content=query)]})
return state["plan"]

def test_basic_funnel(self):
query = "what was the conversion from a page view to sign up?"
test_case = LLMTestCase(
input=query,
expected_output="""
Sequence:
1. $pageview
2. signed_up
""",
actual_output=self._call_node(query),
)
assert_test(test_case, [self._get_plan_correctness_metric()])

def test_outputs_at_least_two_events(self):
"""
Ambiguous query. The funnel must return at least two events.
"""
query = "how many users paid a bill?"
test_case = LLMTestCase(
input=query,
expected_output="""
Sequence:
1. any event
2. upgrade_plan
""",
actual_output=self._call_node(query),
)
assert_test(test_case, [self._get_plan_correctness_metric()])

def test_no_excessive_property_filters(self):
query = "Show the user conversion from a sign up to a file download"
test_case = LLMTestCase(
input=query,
expected_output="""
Sequence:
1. signed_up
2. downloaded_file
""",
actual_output=self._call_node(query),
)
assert_test(test_case, [self._get_plan_correctness_metric()])

def test_basic_filtering(self):
query = (
"What was the conversion from uploading a file to downloading it from Chrome and Safari in the last 30d?"
)
test_case = LLMTestCase(
input=query,
expected_output="""
Sequence:
1. uploaded_file
- property filter 1:
- entity: event
- property name: $browser
- property type: String
- operator: equals
- property value: Chrome
- property filter 2:
- entity: event
- property name: $browser
- property type: String
- operator: equals
- property value: Safari
2. downloaded_file
- property filter 1:
- entity: event
- property name: $browser
- property type: String
- operator: equals
- property value: Chrome
- property filter 2:
- entity: event
- property name: $browser
- property type: String
- operator: equals
- property value: Safari
""",
actual_output=self._call_node(query),
)
assert_test(test_case, [self._get_plan_correctness_metric()])

def test_exclusion_steps(self):
query = "What was the conversion from uploading a file to downloading it in the last 30d excluding users that deleted a file?"
test_case = LLMTestCase(
input=query,
expected_output="""
Sequence:
1. uploaded_file
2. downloaded_file
Exclusions:
- deleted_file
- start index: 0
- end index: 1
""",
actual_output=self._call_node(query),
)
assert_test(test_case, [self._get_plan_correctness_metric()])

def test_breakdown(self):
query = "Show a conversion from uploading a file to downloading it segmented by a user's email"
test_case = LLMTestCase(
input=query,
expected_output="""
Sequence:
1. uploaded_file
2. downloaded_file
Breakdown by:
- entity: person
- property name: email
""",
actual_output=self._call_node(query),
)
assert_test(test_case, [self._get_plan_correctness_metric()])

def test_needle_in_a_haystack(self):
query = "What was the conversion from a sign up to a paying customer on the personal-pro plan?"
test_case = LLMTestCase(
input=query,
expected_output="""
Sequence:
1. signed_up
2. paid_bill
- property filter 1:
- entity: event
- property name: plan
- property type: String
- operator: equals
- property value: personal/pro
""",
actual_output=self._call_node(query),
)
assert_test(test_case, [self._get_plan_correctness_metric()])
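
Pulled out of the test class, each case above boils down to the same recipe: compile the planner-only subgraph, invoke it with a single HumanMessage, and have a GEval judge score the generated plan against the expected one. A sketch under the same assumptions as before (`team` is an existing posthog.models.Team; deepeval needs an LLM API key configured):

    from deepeval import assert_test
    from deepeval.metrics import GEval
    from deepeval.test_case import LLMTestCase, LLMTestCaseParams

    from ee.hogai.assistant import AssistantGraph
    from ee.hogai.utils import AssistantNodeName
    from posthog.schema import HumanMessage

    # START -> FUNNEL_PLANNER -> (planner tools loop) -> END; nothing else is compiled.
    graph = (
        AssistantGraph(team)
        .add_edge(AssistantNodeName.START, AssistantNodeName.FUNNEL_PLANNER)
        .add_funnel_planner(AssistantNodeName.END)
        .compile()
    )

    query = "what was the conversion from a page view to sign up?"
    state = graph.invoke({"messages": [HumanMessage(content=query)]})

    metric = GEval(
        name="Funnel Plan Correctness",
        criteria="Compare whether the taxonomy of the actual plan matches the expected plan.",
        evaluation_params=[
            LLMTestCaseParams.INPUT,
            LLMTestCaseParams.EXPECTED_OUTPUT,
            LLMTestCaseParams.ACTUAL_OUTPUT,
        ],
        threshold=0.7,  # scores below 0.7 fail the assertion
    )
    assert_test(
        LLMTestCase(
            input=query,
            expected_output="Sequence:\n1. $pageview\n2. signed_up",
            actual_output=state["plan"],
        ),
        [metric],
    )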