Skip to content

Commit

Permalink
Merge branch 'master' into experiments/improve-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbachhuber committed Dec 20, 2024
2 parents 435b0f6 + 02ad064 commit 8dafec1
Show file tree
Hide file tree
Showing 111 changed files with 3,242 additions and 1,856 deletions.
2 changes: 1 addition & 1 deletion cypress/e2e/persons.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ describe('Persons', () => {
})

it('All tabs work', () => {
cy.get('h1').should('contain', 'People')
cy.get('h1').should('contain', 'Persons')
cy.get('[data-attr=persons-search]').type('marisol').type('{enter}').should('have.value', 'marisol')
cy.wait(200)
cy.get('[data-row-key]').its('length').should('be.gte', 0)
Expand Down
4 changes: 3 additions & 1 deletion ee/api/test/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def test_create_hog_function_via_hook(self):
"target": "https://hooks.zapier.com/{inputs.hook}",
},
},
"order": 2,
},
"debug": {},
"debug": {"order": 1},
"hook": {
"bytecode": [
"_H",
Expand All @@ -149,6 +150,7 @@ def test_create_hog_function_via_hook(self):
"hooks/standard/1234/abcd",
],
"value": "hooks/standard/1234/abcd",
"order": 0,
},
}

Expand Down
24 changes: 6 additions & 18 deletions ee/hogai/assistant.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
from collections.abc import AsyncGenerator, Generator, Iterator
from functools import partial
from collections.abc import Generator, Iterator
from typing import Any, Optional
from uuid import uuid4

from asgiref.sync import sync_to_async
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langfuse.callback import CallbackHandler
Expand All @@ -20,6 +18,7 @@
from ee.hogai.trends.nodes import (
TrendsGeneratorNode,
)
from ee.hogai.utils.asgi import SyncIterableToAsync
from ee.hogai.utils.state import (
GraphMessageUpdateTuple,
GraphTaskStartedUpdateTuple,
Expand Down Expand Up @@ -91,14 +90,8 @@ def stream(self):
return self._astream()
return self._stream()

async def _astream(self) -> AsyncGenerator[str, None]:
generator = self._stream()
while True:
try:
if message := await sync_to_async(partial(next, generator), thread_sensitive=False)():
yield message
except StopIteration:
break
def _astream(self):
return SyncIterableToAsync(self._stream())

def _stream(self) -> Generator[str, None, None]:
state = self._init_or_update_state()
Expand Down Expand Up @@ -155,13 +148,8 @@ def _init_or_update_state(self):
if snapshot.next:
saved_state = validate_state_update(snapshot.values)
self._state = saved_state
if saved_state.intermediate_steps:
intermediate_steps = saved_state.intermediate_steps.copy()
intermediate_steps[-1] = (intermediate_steps[-1][0], self._latest_message.content)
self._graph.update_state(
config,
PartialAssistantState(messages=[self._latest_message], intermediate_steps=intermediate_steps),
)
self._graph.update_state(config, PartialAssistantState(messages=[self._latest_message], resumed=True))

return None
initial_state = self._initial_state
self._state = initial_state
Expand Down
5 changes: 4 additions & 1 deletion ee/hogai/django_checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def _get_checkpoint_channel_values(
query = Q()
for channel, version in loaded_checkpoint["channel_versions"].items():
query |= Q(channel=channel, version=version)
return checkpoint.blobs.filter(query)
return ConversationCheckpointBlob.objects.filter(
Q(thread_id=checkpoint.thread_id, checkpoint_ns=checkpoint.checkpoint_ns) & query
)

def list(
self,
Expand Down Expand Up @@ -238,6 +240,7 @@ def put(
blobs.append(
ConversationCheckpointBlob(
checkpoint=updated_checkpoint,
thread_id=thread_id,
channel=channel,
version=str(version),
type=type,
Expand Down
153 changes: 152 additions & 1 deletion ee/hogai/django_checkpoint/test/test_checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore

from typing import Any, TypedDict
import operator
from typing import Annotated, Any, Optional, TypedDict

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand All @@ -13,6 +14,7 @@
from langgraph.errors import NodeInterrupt
from langgraph.graph import END, START
from langgraph.graph.state import CompiledStateGraph, StateGraph
from pydantic import BaseModel, Field

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.models.assistant import (
Expand Down Expand Up @@ -272,3 +274,152 @@ def test_resuming(self):
self.assertEqual(res, {"val": 3})
snapshot = graph.get_state(config)
self.assertFalse(snapshot.next)

def test_checkpoint_blobs_are_bound_to_thread(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]
string: Optional[str]

graph = StateGraph(State)

def handle_node1(state: State):
return

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_checkpoint_can_save_and_load_pydantic_state(self):
class State(BaseModel):
messages: Annotated[list[str], operator.add]
string: Optional[str]

class PartialState(BaseModel):
messages: Optional[list[str]] = Field(default=None)
string: Optional[str] = Field(default=None)

graph = StateGraph(State)

def handle_node1(state: State):
return PartialState()

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_saved_blobs(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]

graph = StateGraph(State)

def handle_node1(state: State):
return {"messages": ["world"]}

graph.add_node("node1", handle_node1)

graph.add_edge(START, "node1")
graph.add_edge("node1", END)

checkpointer = DjangoCheckpointer()
compiled = graph.compile(checkpointer=checkpointer)

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"]}, config=config)

snapshot = compiled.get_state(config)
self.assertFalse(snapshot.next)
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello", "world"])

blobs = list(ConversationCheckpointBlob.objects.filter(thread=thread))
self.assertEqual(len(blobs), 7)

# Set initial state
self.assertEqual(blobs[0].channel, "__start__")
self.assertEqual(blobs[0].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[0].type, blobs[0].blob)),
{"messages": ["hello"]},
)

# Set first node
self.assertEqual(blobs[1].channel, "__start__")
self.assertEqual(blobs[1].type, "empty")
self.assertIsNone(blobs[1].blob)

# Set value channels before start
self.assertEqual(blobs[2].channel, "messages")
self.assertEqual(blobs[2].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[2].type, blobs[2].blob)),
["hello"],
)

# Transition to node1
self.assertEqual(blobs[3].channel, "start:node1")
self.assertEqual(blobs[3].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[3].type, blobs[3].blob)),
"__start__",
)

# Set new state for messages
self.assertEqual(blobs[4].channel, "messages")
self.assertEqual(blobs[4].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[4].type, blobs[4].blob)),
["hello", "world"],
)

# After setting a state
self.assertEqual(blobs[5].channel, "start:node1")
self.assertEqual(blobs[5].type, "empty")
self.assertIsNone(blobs[5].blob)

# Set last step
self.assertEqual(blobs[6].channel, "node1")
self.assertEqual(blobs[6].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[6].type, blobs[6].blob)),
"node1",
)
3 changes: 2 additions & 1 deletion ee/hogai/funnels/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_node_runs(self):
new_state,
PartialAssistantState(
messages=[VisualizationMessage(answer=self.schema, plan="Plan", id=new_state.messages[0].id)],
intermediate_steps=None,
intermediate_steps=[],
plan="",
),
)
22 changes: 12 additions & 10 deletions ee/hogai/schema_generator/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def _run_with_prompt(
content=f"Oops! It looks like I’m having trouble generating this {self.INSIGHT_NAME} insight. Could you please try again?"
)
],
intermediate_steps=None,
intermediate_steps=[],
plan="",
)

return PartialAssistantState(
Expand All @@ -106,16 +107,17 @@ def _run_with_prompt(
],
)

final_message = VisualizationMessage(
plan=generated_plan,
answer=message.query,
initiator=start_id,
id=str(uuid4()),
)

return PartialAssistantState(
messages=[
VisualizationMessage(
plan=generated_plan,
answer=message.query,
initiator=start_id,
id=str(uuid4()),
)
],
intermediate_steps=None,
messages=[final_message],
intermediate_steps=[],
plan="",
)

def router(self, state: AssistantState):
Expand Down
10 changes: 6 additions & 4 deletions ee/hogai/schema_generator/test/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def test_node_runs(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])
self.assertEqual(new_state.plan, "")
self.assertEqual(len(new_state.messages), 1)
self.assertEqual(new_state.messages[0].type, "ai/viz")
self.assertEqual(new_state.messages[0].answer, self.schema)
Expand Down Expand Up @@ -316,7 +317,7 @@ def test_node_leaves_failover(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])

new_state = node.run(
AssistantState(
Expand All @@ -328,7 +329,7 @@ def test_node_leaves_failover(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])

def test_node_leaves_failover_after_second_unsuccessful_attempt(self):
node = DummyGeneratorNode(self.team)
Expand All @@ -348,9 +349,10 @@ def test_node_leaves_failover_after_second_unsuccessful_attempt(self):
),
{},
)
self.assertIsNone(new_state.intermediate_steps)
self.assertEqual(new_state.intermediate_steps, [])
self.assertEqual(len(new_state.messages), 1)
self.assertIsInstance(new_state.messages[0], FailureMessage)
self.assertEqual(new_state.plan, "")

def test_agent_reconstructs_conversation_with_failover(self):
action = AgentAction(tool="fix", tool_input="validation error", log="exception")
Expand Down
8 changes: 8 additions & 0 deletions ee/hogai/summarizer/nodes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import json
from time import sleep
from uuid import uuid4

from django.conf import settings
from django.utils import timezone
from django.core.serializers.json import DjangoJSONEncoder
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -76,11 +78,17 @@ def run(self, state: AssistantState, config: RunnableConfig) -> PartialAssistant

chain = summarization_prompt | self._model

utc_now = timezone.now().astimezone(datetime.UTC)
project_now = utc_now.astimezone(self._team.timezone_info)

message = chain.invoke(
{
"query_kind": viz_message.answer.kind,
"product_description": self._team.project.product_description,
"results": json.dumps(results_response["results"], cls=DjangoJSONEncoder),
"utc_datetime_display": utc_now.strftime("%Y-%m-%d %H:%M:%S"),
"project_datetime_display": project_now.strftime("%Y-%m-%d %H:%M:%S"),
"project_timezone": self._team.timezone_info.tzname(utc_now),
},
config,
)
Expand Down
Loading

0 comments on commit 8dafec1

Please sign in to comment.