diff --git a/ee/hogai/django_checkpoint/checkpointer.py b/ee/hogai/django_checkpoint/checkpointer.py index 78817dca9df76..a57140fecdc13 100644 --- a/ee/hogai/django_checkpoint/checkpointer.py +++ b/ee/hogai/django_checkpoint/checkpointer.py @@ -94,7 +94,9 @@ def _get_checkpoint_channel_values( query = Q() for channel, version in loaded_checkpoint["channel_versions"].items(): query |= Q(channel=channel, version=version) - return checkpoint.blobs.filter(query) + return ConversationCheckpointBlob.objects.filter( + Q(thread_id=checkpoint.thread_id, checkpoint_ns=checkpoint.checkpoint_ns) & query + ) def list( self, @@ -238,6 +240,7 @@ def put( blobs.append( ConversationCheckpointBlob( checkpoint=updated_checkpoint, + thread_id=thread_id, channel=channel, version=str(version), type=type, diff --git a/ee/hogai/django_checkpoint/test/test_checkpointer.py b/ee/hogai/django_checkpoint/test/test_checkpointer.py index 2f8fd7f4a60ed..d7c7a9117862d 100644 --- a/ee/hogai/django_checkpoint/test/test_checkpointer.py +++ b/ee/hogai/django_checkpoint/test/test_checkpointer.py @@ -1,6 +1,7 @@ # type: ignore -from typing import Any, TypedDict +import operator +from typing import Annotated, Any, Optional, TypedDict from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( @@ -13,6 +14,7 @@ from langgraph.errors import NodeInterrupt from langgraph.graph import END, START from langgraph.graph.state import CompiledStateGraph, StateGraph +from pydantic import BaseModel, Field from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer from ee.models.assistant import ( @@ -272,3 +274,152 @@ def test_resuming(self): self.assertEqual(res, {"val": 3}) snapshot = graph.get_state(config) self.assertFalse(snapshot.next) + + def test_checkpoint_blobs_are_bound_to_thread(self): + class State(TypedDict, total=False): + messages: Annotated[list[str], operator.add] + string: Optional[str] + + graph = StateGraph(State) + + def handle_node1(state: State): + return + + def handle_node2(state: State): + raise NodeInterrupt("test") + + graph.add_node("node1", handle_node1) + graph.add_node("node2", handle_node2) + + graph.add_edge(START, "node1") + graph.add_edge("node1", "node2") + graph.add_edge("node2", END) + + compiled = graph.compile(checkpointer=DjangoCheckpointer()) + + thread = Conversation.objects.create(user=self.user, team=self.team) + config = {"configurable": {"thread_id": str(thread.id)}} + compiled.invoke({"messages": ["hello"], "string": "world"}, config=config) + + snapshot = compiled.get_state(config) + self.assertIsNotNone(snapshot.next) + self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test") + saved_state = snapshot.values + self.assertEqual(saved_state["messages"], ["hello"]) + self.assertEqual(saved_state["string"], "world") + + def test_checkpoint_can_save_and_load_pydantic_state(self): + class State(BaseModel): + messages: Annotated[list[str], operator.add] + string: Optional[str] + + class PartialState(BaseModel): + messages: Optional[list[str]] = Field(default=None) + string: Optional[str] = Field(default=None) + + graph = StateGraph(State) + + def handle_node1(state: State): + return PartialState() + + def handle_node2(state: State): + raise NodeInterrupt("test") + + graph.add_node("node1", handle_node1) + graph.add_node("node2", handle_node2) + + graph.add_edge(START, "node1") + graph.add_edge("node1", "node2") + graph.add_edge("node2", END) + + compiled = graph.compile(checkpointer=DjangoCheckpointer()) + + thread = Conversation.objects.create(user=self.user, team=self.team) + config = {"configurable": {"thread_id": str(thread.id)}} + compiled.invoke({"messages": ["hello"], "string": "world"}, config=config) + + snapshot = compiled.get_state(config) + self.assertIsNotNone(snapshot.next) + self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test") + saved_state = snapshot.values + self.assertEqual(saved_state["messages"], ["hello"]) + self.assertEqual(saved_state["string"], "world") + + def test_saved_blobs(self): + class State(TypedDict, total=False): + messages: Annotated[list[str], operator.add] + + graph = StateGraph(State) + + def handle_node1(state: State): + return {"messages": ["world"]} + + graph.add_node("node1", handle_node1) + + graph.add_edge(START, "node1") + graph.add_edge("node1", END) + + checkpointer = DjangoCheckpointer() + compiled = graph.compile(checkpointer=checkpointer) + + thread = Conversation.objects.create(user=self.user, team=self.team) + config = {"configurable": {"thread_id": str(thread.id)}} + compiled.invoke({"messages": ["hello"]}, config=config) + + snapshot = compiled.get_state(config) + self.assertFalse(snapshot.next) + saved_state = snapshot.values + self.assertEqual(saved_state["messages"], ["hello", "world"]) + + blobs = list(ConversationCheckpointBlob.objects.filter(thread=thread)) + self.assertEqual(len(blobs), 7) + + # Set initial state + self.assertEqual(blobs[0].channel, "__start__") + self.assertEqual(blobs[0].type, "msgpack") + self.assertEqual( + checkpointer.serde.loads_typed((blobs[0].type, blobs[0].blob)), + {"messages": ["hello"]}, + ) + + # Set first node + self.assertEqual(blobs[1].channel, "__start__") + self.assertEqual(blobs[1].type, "empty") + self.assertIsNone(blobs[1].blob) + + # Set value channels before start + self.assertEqual(blobs[2].channel, "messages") + self.assertEqual(blobs[2].type, "msgpack") + self.assertEqual( + checkpointer.serde.loads_typed((blobs[2].type, blobs[2].blob)), + ["hello"], + ) + + # Transition to node1 + self.assertEqual(blobs[3].channel, "start:node1") + self.assertEqual(blobs[3].type, "msgpack") + self.assertEqual( + checkpointer.serde.loads_typed((blobs[3].type, blobs[3].blob)), + "__start__", + ) + + # Set new state for messages + self.assertEqual(blobs[4].channel, "messages") + self.assertEqual(blobs[4].type, "msgpack") + self.assertEqual( + checkpointer.serde.loads_typed((blobs[4].type, blobs[4].blob)), + ["hello", "world"], + ) + + # After setting a state + self.assertEqual(blobs[5].channel, "start:node1") + self.assertEqual(blobs[5].type, "empty") + self.assertIsNone(blobs[5].blob) + + # Set last step + self.assertEqual(blobs[6].channel, "node1") + self.assertEqual(blobs[6].type, "msgpack") + self.assertEqual( + checkpointer.serde.loads_typed((blobs[6].type, blobs[6].blob)), + "node1", + ) diff --git a/ee/migrations/0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more.py b/ee/migrations/0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more.py new file mode 100644 index 0000000000000..377f85b3d29c2 --- /dev/null +++ b/ee/migrations/0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2.15 on 2024-12-19 11:00 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("ee", "0018_conversation_conversationcheckpoint_and_more"), + ] + + operations = [ + migrations.RemoveConstraint( + model_name="conversationcheckpointblob", + name="unique_checkpoint_blob", + ), + migrations.AddField( + model_name="conversationcheckpointblob", + name="checkpoint_ns", + field=models.TextField( + default="", + help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).', + ), + ), + migrations.AddField( + model_name="conversationcheckpointblob", + name="thread", + field=models.ForeignKey( + null=True, on_delete=django.db.models.deletion.CASCADE, related_name="blobs", to="ee.conversation" + ), + ), + migrations.AddConstraint( + model_name="conversationcheckpointblob", + constraint=models.UniqueConstraint( + fields=("thread_id", "checkpoint_ns", "channel", "version"), name="unique_checkpoint_blob" + ), + ), + ] diff --git a/ee/migrations/max_migration.txt b/ee/migrations/max_migration.txt index fb889f1cc34cf..aec0628d960c8 100644 --- a/ee/migrations/max_migration.txt +++ b/ee/migrations/max_migration.txt @@ -1 +1 @@ -0018_conversation_conversationcheckpoint_and_more +0019_remove_conversationcheckpointblob_unique_checkpoint_blob_and_more diff --git a/ee/models/assistant.py b/ee/models/assistant.py index 390a7ab7a117f..f2a31d938f5d0 100644 --- a/ee/models/assistant.py +++ b/ee/models/assistant.py @@ -46,6 +46,14 @@ def pending_writes(self) -> Iterable["ConversationCheckpointWrite"]: class ConversationCheckpointBlob(UUIDModel): checkpoint = models.ForeignKey(ConversationCheckpoint, on_delete=models.CASCADE, related_name="blobs") + """ + The checkpoint that created the blob. Do not use this field to query blobs. + """ + thread = models.ForeignKey(Conversation, on_delete=models.CASCADE, related_name="blobs", null=True) + checkpoint_ns = models.TextField( + default="", + help_text='Checkpoint namespace. Denotes the path to the subgraph node the checkpoint originates from, separated by `|` character, e.g. `"child|grandchild"`. Defaults to "" (root graph).', + ) channel = models.TextField( help_text="An arbitrary string defining the channel name. For example, it can be a node name or a reserved LangGraph's enum." ) @@ -56,7 +64,7 @@ class ConversationCheckpointBlob(UUIDModel): class Meta: constraints = [ models.UniqueConstraint( - fields=["checkpoint_id", "channel", "version"], + fields=["thread_id", "checkpoint_ns", "channel", "version"], name="unique_checkpoint_blob", ) ]