Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
aspicer committed Oct 30, 2024
2 parents 250ed21 + 68ddb00 commit 4673c99
Show file tree
Hide file tree
Showing 106 changed files with 3,933 additions and 1,951 deletions.
14 changes: 13 additions & 1 deletion .github/workflows/ci-backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,20 @@ jobs:
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:23.12.6.19'
python-version: '3.11.9'
concurrency: 1
concurrency: 3
group: 1
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:23.12.6.19'
python-version: '3.11.9'
concurrency: 3
group: 2
- segment: 'Temporal'
person-on-events: false
clickhouse-server-image: 'clickhouse/clickhouse-server:23.12.6.19'
python-version: '3.11.9'
concurrency: 3
group: 3

steps:
# The first step is the only one that should run if `needs.changes.outputs.backend == 'false'`.
Expand Down
1,946 changes: 1,222 additions & 724 deletions .test_durations

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion docker-compose.base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ services:
environment:
- TEMPORAL_ADDRESS=temporal:7233
- TEMPORAL_CORS_ORIGINS=http://localhost:3000
image: temporalio/ui:2.10.3
- TEMPORAL_CSRF_COOKIE_INSECURE=true
image: temporalio/ui:2.31.2
ports:
- 8081:8080
temporal-django-worker:
Expand Down
6 changes: 3 additions & 3 deletions ee/api/feature_flag_role_access.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from rest_framework import exceptions, mixins, serializers, viewsets
from rest_framework.permissions import SAFE_METHODS, BasePermission

from ee.api.role import RoleSerializer
from ee.api.rbac.role import RoleSerializer
from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role
from posthog.api.feature_flag import FeatureFlagSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.models import FeatureFlag
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rest_framework import mixins, serializers, viewsets

from ee.api.role import RolePermissions
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.api.rbac.role import RolePermissions
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from posthog.api.routing import TeamAndOrgViewSetMixin


Expand Down
4 changes: 2 additions & 2 deletions ee/api/role.py → ee/api/rbac/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from rest_framework.permissions import SAFE_METHODS, BasePermission

from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role, RoleMembership
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role, RoleMembership
from posthog.api.organization_member import OrganizationMemberSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
Expand Down
4 changes: 2 additions & 2 deletions ee/api/test/test_feature_flag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ee.api.test.base import APILicensedTest
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role, RoleMembership
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role, RoleMembership
from posthog.models.feature_flag import FeatureFlag
from posthog.models.organization import OrganizationMembership

Expand Down
4 changes: 2 additions & 2 deletions ee/api/test/test_feature_flag_role_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from ee.api.test.base import APILicensedTest
from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role
from posthog.models.feature_flag import FeatureFlag
from posthog.models.organization import OrganizationMembership
from posthog.models.user import User
Expand Down
2 changes: 1 addition & 1 deletion ee/api/test/test_organization_resource_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rest_framework import status

from ee.api.test.base import APILicensedTest
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from posthog.models.organization import Organization, OrganizationMembership
from posthog.test.base import QueryMatchingTest, snapshot_postgres_queries, FuzzyInt

Expand Down
4 changes: 2 additions & 2 deletions ee/api/test/test_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from rest_framework import status

from ee.api.test.base import APILicensedTest
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role
from ee.models.rbac.organization_resource_access import OrganizationResourceAccess
from ee.models.rbac.role import Role
from posthog.models.organization import Organization, OrganizationMembership


Expand Down
2 changes: 1 addition & 1 deletion ee/api/test/test_role_membership.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rest_framework import status

from ee.api.test.base import APILicensedTest
from ee.models.role import Role, RoleMembership
from ee.models.rbac.role import Role, RoleMembership
from posthog.models.organization import Organization, OrganizationMembership
from posthog.models.user import User

Expand Down
1 change: 1 addition & 0 deletions ee/billing/billing_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def update_billing_organization_users(self, organization: Organization) -> None:
"distinct_id",
"organization_membership__level",
)
.order_by("email") # Deterministic order for tests
.annotate(role=F("organization_membership__level"))
.filter(role__gte=OrganizationMembership.Level.ADMIN)
.values(
Expand Down
1 change: 1 addition & 0 deletions ee/clickhouse/views/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class Meta:
"created_by",
"created_at",
"updated_at",
"metrics",
]
read_only_fields = [
"id",
Expand Down
59 changes: 45 additions & 14 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import StateGraph
from pydantic import BaseModel

from ee import settings
from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
from ee.hogai.trends.nodes import (
CreateTrendsPlanNode,
CreateTrendsPlanToolsNode,
GenerateTrendsNode,
GenerateTrendsToolsNode,
)
from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
from posthog.models.team.team import Team
from posthog.schema import VisualizationMessage
from posthog.schema import AssistantGenerationStatusEvent, AssistantGenerationStatusType, VisualizationMessage

if settings.LANGFUSE_PUBLIC_KEY:
langfuse_handler = CallbackHandler(
Expand Down Expand Up @@ -39,6 +45,13 @@ def is_message_update(
return len(update) == 2 and update[0] == "messages"


def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]:
"""
Update of the state.
"""
return len(update) == 2 and update[0] == "values"


class Assistant:
_team: Team
_graph: StateGraph
Expand All @@ -59,38 +72,56 @@ def _compile_graph(self):
generate_trends_node = GenerateTrendsNode(self._team)
builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)

generate_trends_tools_node = GenerateTrendsToolsNode(self._team)
builder.add_node(GenerateTrendsToolsNode.name, generate_trends_tools_node.run)
builder.add_edge(GenerateTrendsToolsNode.name, GenerateTrendsNode.name)

builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
builder.add_conditional_edges(GenerateTrendsNode.name, generate_trends_node.router)

return builder.compile()

def stream(self, conversation: Conversation) -> Generator[str, None, None]:
def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
assistant_graph = self._compile_graph()
callbacks = [langfuse_handler] if langfuse_handler else []
messages = [message.root for message in conversation.messages]

chunks = AIMessageChunk(content="")
state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}

generator = assistant_graph.stream(
{"messages": messages},
state,
config={"recursion_limit": 24, "callbacks": callbacks},
stream_mode=["messages", "updates"],
stream_mode=["messages", "values", "updates"],
)

chunks = AIMessageChunk(content="")

# Send a chunk to establish the connection avoiding the worker's timeout.
yield ""
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)

for update in generator:
if is_value_update(update):
if is_state_update(update):
_, new_state = update
state = new_state

elif is_value_update(update):
_, state_update = update
if (
AssistantNodeName.GENERATE_TRENDS in state_update
and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]
):
message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0])
yield message.model_dump_json()

if AssistantNodeName.GENERATE_TRENDS in state_update:
# Reset chunks when schema validation fails.
chunks = AIMessageChunk(content="")

if "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]:
message = cast(
VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0]
)
yield message
elif state_update[AssistantNodeName.GENERATE_TRENDS].get("intermediate_steps", []):
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)

elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance(
Expand All @@ -101,4 +132,4 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]:
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
).model_dump_json()
)
Loading

0 comments on commit 4673c99

Please sign in to comment.