Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(product-assistant): checkpoint blob queries must not rely on checkpoint #27048

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ee/hogai/django_checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def _get_checkpoint_channel_values(
query = Q()
for channel, version in loaded_checkpoint["channel_versions"].items():
query |= Q(channel=channel, version=version)
return checkpoint.blobs.filter(query)
return ConversationCheckpointBlob.objects.filter(
Q(thread_id=checkpoint.thread_id, checkpoint_ns=checkpoint.checkpoint_ns) & query
)

def list(
self,
Expand Down Expand Up @@ -238,6 +240,7 @@ def put(
blobs.append(
ConversationCheckpointBlob(
checkpoint=updated_checkpoint,
thread_id=thread_id,
channel=channel,
version=str(version),
type=type,
Expand Down
153 changes: 152 additions & 1 deletion ee/hogai/django_checkpoint/test/test_checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore

from typing import Any, TypedDict
import operator
from typing import Annotated, Any, Optional, TypedDict

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand All @@ -13,6 +14,7 @@
from langgraph.errors import NodeInterrupt
from langgraph.graph import END, START
from langgraph.graph.state import CompiledStateGraph, StateGraph
from pydantic import BaseModel, Field

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.models.assistant import (
Expand Down Expand Up @@ -272,3 +274,152 @@ def test_resuming(self):
self.assertEqual(res, {"val": 3})
snapshot = graph.get_state(config)
self.assertFalse(snapshot.next)

def test_checkpoint_blobs_are_bound_to_thread(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]
string: Optional[str]

graph = StateGraph(State)

def handle_node1(state: State):
return

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_checkpoint_can_save_and_load_pydantic_state(self):
class State(BaseModel):
messages: Annotated[list[str], operator.add]
string: Optional[str]

class PartialState(BaseModel):
messages: Optional[list[str]] = Field(default=None)
string: Optional[str] = Field(default=None)

graph = StateGraph(State)

def handle_node1(state: State):
return PartialState()

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_saved_blobs(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]

graph = StateGraph(State)

def handle_node1(state: State):
return {"messages": ["world"]}

graph.add_node("node1", handle_node1)

graph.add_edge(START, "node1")
graph.add_edge("node1", END)

checkpointer = DjangoCheckpointer()
compiled = graph.compile(checkpointer=checkpointer)

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"]}, config=config)

snapshot = compiled.get_state(config)
self.assertFalse(snapshot.next)
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello", "world"])

blobs = list(ConversationCheckpointBlob.objects.filter(thread=thread))
self.assertEqual(len(blobs), 7)

# Set initial state
self.assertEqual(blobs[0].channel, "__start__")
self.assertEqual(blobs[0].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[0].type, blobs[0].blob)),
{"messages": ["hello"]},
)

# Set first node
self.assertEqual(blobs[1].channel, "__start__")
self.assertEqual(blobs[1].type, "empty")
self.assertIsNone(blobs[1].blob)

# Set value channels before start
self.assertEqual(blobs[2].channel, "messages")
self.assertEqual(blobs[2].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[2].type, blobs[2].blob)),
["hello"],
)

# Transition to node1
self.assertEqual(blobs[3].channel, "start:node1")
self.assertEqual(blobs[3].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[3].type, blobs[3].blob)),
"__start__",
)

# Set new state for messages
self.assertEqual(blobs[4].channel, "messages")
self.assertEqual(blobs[4].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[4].type, blobs[4].blob)),
["hello", "world"],
)

# After setting a state
self.assertEqual(blobs[5].channel, "start:node1")
self.assertEqual(blobs[5].type, "empty")
self.assertIsNone(blobs[5].blob)

# Set last step
self.assertEqual(blobs[6].channel, "node1")
self.assertEqual(blobs[6].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[6].type, blobs[6].blob)),
"node1",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Generated by Django 4.2.15 on 2024-12-19 11:00

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
dependencies = [
("ee", "0018_conversation_conversationcheckpoint_and_more"),
]

operations = [
migrations.RemoveConstraint(
model_name="conversationcheckpointblob",
name="unique_checkpoint_blob",
),
migrations.AddField(
model_name="conversationcheckpointblob",
name="checkpoint_ns",
field=models.TextField(
default="",
help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).',
),
),
migrations.AddField(
model_name="conversationcheckpointblob",
name="thread",
field=models.ForeignKey(
null=True, on_delete=django.db.models.deletion.CASCADE, related_name="blobs", to="ee.conversation"
),
),
migrations.AddConstraint(
model_name="conversationcheckpointblob",
constraint=models.UniqueConstraint(
fields=("thread_id", "checkpoint_ns", "channel", "version"), name="unique_checkpoint_blob"
),
),
]
2 changes: 1 addition & 1 deletion ee/migrations/max_migration.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0018_conversation_conversationcheckpoint_and_more
0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more
10 changes: 9 additions & 1 deletion ee/models/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def pending_writes(self) -> Iterable["ConversationCheckpointWrite"]:

class ConversationCheckpointBlob(UUIDModel):
checkpoint = models.ForeignKey(ConversationCheckpoint, on_delete=models.CASCADE, related_name="blobs")
"""
The checkpoint that created the blob. Do not use this field to query blobs.
"""
thread = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name="blobs", null=True)
checkpoint_ns = models.TextField(
default="",
help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).',
)
channel = models.TextField(
help_text="An arbitrary string defining the channel name. For example, it can be a node name or a reserved LangGraph's enum."
)
Expand All @@ -56,7 +64,7 @@ class ConversationCheckpointBlob(UUIDModel):
class Meta:
constraints = [
models.UniqueConstraint(
fields=["checkpoint_id", "channel", "version"],
fields=["thread_id", "checkpoint_ns", "channel", "version"],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will lock the table anyway, so I didn't set checkpoint_ns to null. The table was just created, so a short lock shouldn't be a problem.

name="unique_checkpoint_blob",
)
]
Expand Down
Loading