Skip to content

Commit

Permalink
Merge branch 'master' into memoize-funnel-path-url
Browse files Browse the repository at this point in the history
  • Loading branch information
thmsobrmlr authored Dec 20, 2024
2 parents 41612b5 + 3bb4386 commit c278e16
Show file tree
Hide file tree
Showing 120 changed files with 3,165 additions and 1,164 deletions.
6 changes: 3 additions & 3 deletions bin/mprocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ procs:
shell: 'bin/check_kafka_clickhouse_up && bin/check_temporal_up && python manage.py start_temporal_worker'

docker-compose:
shell: 'docker compose -f docker-compose.dev.yml up'
stop:
send-keys: ['<C-c>']
# docker-compose makes sure the stack is up, and then follows its logs - but doesn't tear down on exit for speed
shell: 'docker compose -f docker-compose.dev.yml up -d && docker compose -f docker-compose.dev.yml logs --tail=0 -f'

mouse_scroll_speed: 1
scrollback: 10000
10 changes: 10 additions & 0 deletions bin/start
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,14 @@ export HOG_HOOK_URL=${HOG_HOOK_URL:-http://localhost:3300/hoghook}

[ ! -f ./share/GeoLite2-City.mmdb ] && ( curl -L "https://mmdbcdn.posthog.net/" --http1.1 | brotli --decompress --output=./share/GeoLite2-City.mmdb )

if ! command -v mprocs &> /dev/null; then
if command -v brew &> /dev/null; then
echo "🔁 Installing mprocs via Homebrew..."
brew install mprocs
else
echo "👉 To run bin/start, install mprocs: https://github.com/pvolok/mprocs#installation"
exit 1
fi
fi

exec mprocs --config bin/mprocs.yaml
8 changes: 2 additions & 6 deletions cypress/e2e/experiments.cy.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import { setupFeatureFlags } from '../support/decide'

describe('Experiments', () => {
let randomNum
let experimentName
Expand Down Expand Up @@ -47,10 +45,6 @@ describe('Experiments', () => {
})

const createExperimentInNewUi = (): void => {
setupFeatureFlags({
'new-experiments-ui': true,
})

cy.visit('/experiments')

// Name, flag key, description
Expand Down Expand Up @@ -96,6 +90,7 @@ describe('Experiments', () => {
cy.get('[data-attr="experiment-creation-date"]').contains('a few seconds ago').should('be.visible')
cy.get('[data-attr="experiment-start-date"]').should('not.exist')

cy.wait(1000)
cy.get('[data-attr="launch-experiment"]').first().click()
cy.get('[data-attr="experiment-creation-date"]').should('not.exist')
cy.get('[data-attr="experiment-start-date"]').contains('a few seconds ago').should('be.visible')
Expand All @@ -114,6 +109,7 @@ describe('Experiments', () => {
it('move start date', () => {
createExperimentInNewUi()

cy.wait(1000)
cy.get('[data-attr="launch-experiment"]').first().click()

cy.get('[data-attr="move-experiment-start-date"]').first().click()
Expand Down
5 changes: 1 addition & 4 deletions cypress/e2e/insights-reload-query.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import JSONCrush from 'jsoncrush'

describe('ReloadInsight component', () => {
beforeEach(() => {
// Clear local storage before each test to ensure a clean state
Expand All @@ -21,8 +19,7 @@ describe('ReloadInsight component', () => {
const draftQuery = window.localStorage.getItem(`draft-query-${currentTeamId}`)
expect(draftQuery).to.not.be.null

const draftQueryObjUncrushed = JSONCrush.uncrush(draftQuery)
const draftQueryObj = JSON.parse(draftQueryObjUncrushed)
const draftQueryObj = JSON.parse(draftQuery)

expect(draftQueryObj).to.have.property('query')

Expand Down
4 changes: 3 additions & 1 deletion ee/api/test/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def test_create_hog_function_via_hook(self):
"target": "https://hooks.zapier.com/{inputs.hook}",
},
},
"order": 2,
},
"debug": {},
"debug": {"order": 1},
"hook": {
"bytecode": [
"_H",
Expand All @@ -149,6 +150,7 @@ def test_create_hog_function_via_hook(self):
"hooks/standard/1234/abcd",
],
"value": "hooks/standard/1234/abcd",
"order": 0,
},
}

Expand Down
5 changes: 4 additions & 1 deletion ee/hogai/django_checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def _get_checkpoint_channel_values(
query = Q()
for channel, version in loaded_checkpoint["channel_versions"].items():
query |= Q(channel=channel, version=version)
return checkpoint.blobs.filter(query)
return ConversationCheckpointBlob.objects.filter(
Q(thread_id=checkpoint.thread_id, checkpoint_ns=checkpoint.checkpoint_ns) & query
)

def list(
self,
Expand Down Expand Up @@ -238,6 +240,7 @@ def put(
blobs.append(
ConversationCheckpointBlob(
checkpoint=updated_checkpoint,
thread_id=thread_id,
channel=channel,
version=str(version),
type=type,
Expand Down
153 changes: 152 additions & 1 deletion ee/hogai/django_checkpoint/test/test_checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore

from typing import Any, TypedDict
import operator
from typing import Annotated, Any, Optional, TypedDict

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand All @@ -13,6 +14,7 @@
from langgraph.errors import NodeInterrupt
from langgraph.graph import END, START
from langgraph.graph.state import CompiledStateGraph, StateGraph
from pydantic import BaseModel, Field

from ee.hogai.django_checkpoint.checkpointer import DjangoCheckpointer
from ee.models.assistant import (
Expand Down Expand Up @@ -272,3 +274,152 @@ def test_resuming(self):
self.assertEqual(res, {"val": 3})
snapshot = graph.get_state(config)
self.assertFalse(snapshot.next)

def test_checkpoint_blobs_are_bound_to_thread(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]
string: Optional[str]

graph = StateGraph(State)

def handle_node1(state: State):
return

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_checkpoint_can_save_and_load_pydantic_state(self):
class State(BaseModel):
messages: Annotated[list[str], operator.add]
string: Optional[str]

class PartialState(BaseModel):
messages: Optional[list[str]] = Field(default=None)
string: Optional[str] = Field(default=None)

graph = StateGraph(State)

def handle_node1(state: State):
return PartialState()

def handle_node2(state: State):
raise NodeInterrupt("test")

graph.add_node("node1", handle_node1)
graph.add_node("node2", handle_node2)

graph.add_edge(START, "node1")
graph.add_edge("node1", "node2")
graph.add_edge("node2", END)

compiled = graph.compile(checkpointer=DjangoCheckpointer())

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"], "string": "world"}, config=config)

snapshot = compiled.get_state(config)
self.assertIsNotNone(snapshot.next)
self.assertEqual(snapshot.tasks[0].interrupts[0].value, "test")
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello"])
self.assertEqual(saved_state["string"], "world")

def test_saved_blobs(self):
class State(TypedDict, total=False):
messages: Annotated[list[str], operator.add]

graph = StateGraph(State)

def handle_node1(state: State):
return {"messages": ["world"]}

graph.add_node("node1", handle_node1)

graph.add_edge(START, "node1")
graph.add_edge("node1", END)

checkpointer = DjangoCheckpointer()
compiled = graph.compile(checkpointer=checkpointer)

thread = Conversation.objects.create(user=self.user, team=self.team)
config = {"configurable": {"thread_id": str(thread.id)}}
compiled.invoke({"messages": ["hello"]}, config=config)

snapshot = compiled.get_state(config)
self.assertFalse(snapshot.next)
saved_state = snapshot.values
self.assertEqual(saved_state["messages"], ["hello", "world"])

blobs = list(ConversationCheckpointBlob.objects.filter(thread=thread))
self.assertEqual(len(blobs), 7)

# Set initial state
self.assertEqual(blobs[0].channel, "__start__")
self.assertEqual(blobs[0].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[0].type, blobs[0].blob)),
{"messages": ["hello"]},
)

# Set first node
self.assertEqual(blobs[1].channel, "__start__")
self.assertEqual(blobs[1].type, "empty")
self.assertIsNone(blobs[1].blob)

# Set value channels before start
self.assertEqual(blobs[2].channel, "messages")
self.assertEqual(blobs[2].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[2].type, blobs[2].blob)),
["hello"],
)

# Transition to node1
self.assertEqual(blobs[3].channel, "start:node1")
self.assertEqual(blobs[3].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[3].type, blobs[3].blob)),
"__start__",
)

# Set new state for messages
self.assertEqual(blobs[4].channel, "messages")
self.assertEqual(blobs[4].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[4].type, blobs[4].blob)),
["hello", "world"],
)

# After setting a state
self.assertEqual(blobs[5].channel, "start:node1")
self.assertEqual(blobs[5].type, "empty")
self.assertIsNone(blobs[5].blob)

# Set last step
self.assertEqual(blobs[6].channel, "node1")
self.assertEqual(blobs[6].type, "msgpack")
self.assertEqual(
checkpointer.serde.loads_typed((blobs[6].type, blobs[6].blob)),
"node1",
)
Loading

0 comments on commit c278e16

Please sign in to comment.