Skip to content

Commit

Permalink
Merge branch 'master' of github.com:PostHog/posthog into feat/improve…
Browse files Browse the repository at this point in the history
…d-failover-for-taxonomy-agent
  • Loading branch information
skoob13 committed Oct 30, 2024
2 parents 03af295 + 8de5762 commit ef95d09
Show file tree
Hide file tree
Showing 167 changed files with 9,539 additions and 7,730 deletions.
6 changes: 3 additions & 3 deletions ee/api/feature_flag_role_access.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from rest_framework import exceptions, mixins, serializers, viewsets
from rest_framework.permissions import SAFE_METHODS, BasePermission

from ee.api.role import RoleSerializer
from ee.api.rbac.role import RoleSerializer
from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role
from posthog.api.feature_flag import FeatureFlagSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.models import FeatureFlag
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rest_framework import mixins, serializers, viewsets

from ee.api.role import RolePermissions
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.api.rbac.role import RolePermissions
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from posthog.api.routing import TeamAndOrgViewSetMixin


Expand Down
4 changes: 2 additions & 2 deletions ee/api/role.py → ee/api/rbac/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from rest_framework.permissions import SAFE_METHODS, BasePermission

from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role, RoleMembership
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role, RoleMembership
from posthog.api.organization_member import OrganizationMemberSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
Expand Down
4 changes: 2 additions & 2 deletions ee/api/test/test_feature_flag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ee.api.test.base import APILicensedTest
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role, RoleMembership
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role, RoleMembership
from posthog.models.feature_flag import FeatureFlag
from posthog.models.organization import OrganizationMembership

Expand Down
4 changes: 2 additions & 2 deletions ee/api/test/test_feature_flag_role_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from ee.api.test.base import APILicensedTest
from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role
from posthog.models.feature_flag import FeatureFlag
from posthog.models.organization import OrganizationMembership
from posthog.models.user import User
Expand Down
2 changes: 1 addition & 1 deletion ee/api/test/test_organization_resource_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rest_framework import status

from ee.api.test.base import APILicensedTest
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from posthog.models.organization import Organization, OrganizationMembership
from posthog.test.base import QueryMatchingTest, snapshot_postgres_queries, FuzzyInt

Expand Down
4 changes: 2 additions & 2 deletions ee/api/test/test_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from rest_framework import status

from ee.api.test.base import APILicensedTest
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role
from posthog.models.organization import Organization, OrganizationMembership


Expand Down
2 changes: 1 addition & 1 deletion ee/api/test/test_role_membership.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rest_framework import status

from ee.api.test.base import APILicensedTest
from ee.models.role import Role, RoleMembership
from ee.models.rbac.role import Role, RoleMembership
from posthog.models.organization import Organization, OrganizationMembership
from posthog.models.user import User

Expand Down
1 change: 1 addition & 0 deletions ee/clickhouse/views/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class Meta:
"created_by",
"created_at",
"updated_at",
"metrics",
]
read_only_fields = [
"id",
Expand Down
59 changes: 45 additions & 14 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import StateGraph
from pydantic import BaseModel

from ee import settings
from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
from ee.hogai.trends.nodes import (
CreateTrendsPlanNode,
CreateTrendsPlanToolsNode,
GenerateTrendsNode,
GenerateTrendsToolsNode,
)
from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
from posthog.models.team.team import Team
from posthog.schema import VisualizationMessage
from posthog.schema import AssistantGenerationStatusEvent, AssistantGenerationStatusType, VisualizationMessage

if settings.LANGFUSE_PUBLIC_KEY:
langfuse_handler = CallbackHandler(
Expand Down Expand Up @@ -39,6 +45,13 @@ def is_message_update(
return len(update) == 2 and update[0] == "messages"


def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]:
"""
Update of the state.
"""
return len(update) == 2 and update[0] == "values"


class Assistant:
_team: Team
_graph: StateGraph
Expand All @@ -59,38 +72,56 @@ def _compile_graph(self):
generate_trends_node = GenerateTrendsNode(self._team)
builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)

generate_trends_tools_node = GenerateTrendsToolsNode(self._team)
builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run)
builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name)

builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
builder.add_conditional_edges(GenerateTrendsNode.name, generate_trends_node.router)

return builder.compile()

def stream(self, conversation: Conversation) -> Generator[str, None, None]:
def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
assistant_graph = self._compile_graph()
callbacks = [langfuse_handler] if langfuse_handler else []
messages = [message.root for message in conversation.messages]

chunks = AIMessageChunk(content="")
state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}

generator = assistant_graph.stream(
{"messages": messages},
state,
config={"recursion_limit": 24, "callbacks": callbacks},
stream_mode=["messages", "updates"],
stream_mode=["messages", "values", "updates"],
)

chunks = AIMessageChunk(content="")

# Send a chunk to establish the connection avoiding the worker's timeout.
yield ""
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)

for update in generator:
if is_value_update(update):
if is_state_update(update):
_, new_state = update
state = new_state

elif is_value_update(update):
_, state_update = update
if (
AssistantNodeName.GENERATE_TRENDS in state_update
and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]
):
message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0])
yield message.model_dump_json()

if AssistantNodeName.GENERATE_TRENDS in state_update:
# Reset chunks when schema validation fails.
chunks = AIMessageChunk(content="")

if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]:
message = cast(
VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]
)
yield message
elif state_update[AssistantNodeName.GENERATE_TRENDS].get("intermediate_steps", []):
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)

elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance(
Expand All @@ -101,4 +132,4 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]:
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
).model_dump_json()
)
Loading

0 comments on commit ef95d09

Please sign in to comment.