fix(product-assistant): correctly establish a connection for streaming (
skoob13 authored Oct 25, 2024
1 parent 6013e43 commit 544e239
Showing 7 changed files with 46 additions and 55 deletions.
3 changes: 3 additions & 0 deletions ee/hogai/assistant.py
@@ -79,6 +79,9 @@ def stream(self, conversation: Conversation) -> Generator[str, None, None]:

chunks = AIMessageChunk(content="")

# Send a chunk to establish the connection avoiding the worker's timeout.
yield ""

for update in generator:
if is_value_update(update):
_, state_update = update
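The empty first chunk exists only to get bytes onto the wire before the worker's idle timeout can fire. By the time it reaches the browser it has been wrapped as a "data:" frame with no payload (see posthog/api/query.py below), which a spec-compliant SSE parser drops outright; even if it surfaced, a JSON guard on the client would discard it. A minimal sketch of that behaviour, assuming the "data: ...\n\n" framing, not code from this commit:

    import { createParser } from 'eventsource-parser'

    const received: unknown[] = []
    const parser = createParser({
        onEvent: (event) => {
            try {
                received.push(JSON.parse(event.data))
            } catch {
                // Non-JSON payloads, if any slip through, are ignored.
            }
        },
    })

    parser.feed('data: \n\n') // keep-alive frame: opens the response, produces no event
    parser.feed('data: {"content":"Hi"}\n\n') // a real assistant chunk

    console.log(received) // [{ content: 'Hi' }]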
4 changes: 2 additions & 2 deletions frontend/src/scenes/max/Max.stories.tsx
@@ -4,7 +4,7 @@ import { useEffect } from 'react'

import { mswDecorator, useStorybookMocks } from '~/mocks/browser'

import chatResponse from './__mocks__/chatResponse.json'
import { chatResponseChunk } from './__mocks__/chatResponse.mocks'
import { MaxInstance } from './Max'
import { maxLogic } from './maxLogic'

@@ -13,7 +13,7 @@ const meta: Meta = {
decorators: [
mswDecorator({
post: {
'/api/environments/:team_id/query/chat/': chatResponse,
'/api/environments/:team_id/query/chat/': (_, res, ctx) => res(ctx.text(chatResponseChunk)),
},
}),
],
3 changes: 3 additions & 0 deletions frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts
@@ -0,0 +1,3 @@
import chatResponse from './chatResponse.json'

export const chatResponseChunk = `data: ${JSON.stringify(chatResponse)}\n\n`
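The new mock reproduces the wire format the endpoint now emits: the JSON fixture wrapped in a single "data: ...\n\n" server-sent-event frame. If a story or test ever needs to exercise the batched case, frames can simply be concatenated: a hypothetical helper along these lines (not part of this commit):

    import chatResponse from './chatResponse.json'

    // One frame per message; the browser may still deliver several frames in a single read.
    const asSseFrame = (payload: unknown): string => `data: ${JSON.stringify(payload)}\n\n`

    export const batchedChatResponseChunk = [chatResponse, chatResponse].map(asSseFrame).join('')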
77 changes: 26 additions & 51 deletions frontend/src/scenes/max/maxLogic.ts
@@ -1,4 +1,5 @@
import { shuffle } from 'd3'
import { createParser } from 'eventsource-parser'
import { actions, kea, key, listeners, path, props, reducers, selectors } from 'kea'
import { loaders } from 'kea-loaders'
import api from 'lib/api'
@@ -118,20 +119,22 @@ export const maxLogic = kea<maxLogicType>([
messages: values.thread.map(({ status, ...message }) => message),
})
const reader = response.body?.getReader()

if (!reader) {
return
}

const decoder = new TextDecoder()

if (reader) {
let firstChunk = true
let firstChunk = true

while (true) {
const { done, value } = await reader.read()
if (done) {
actions.setMessageStatus(newIndex, 'completed')
break
}
const parser = createParser({
onEvent: (event) => {
const parsedResponse = parseResponse(event.data)

const text = decoder.decode(value)
const parsedResponse = parseResponse(text)
if (!parsedResponse) {
return
}

if (firstChunk) {
firstChunk = false
@@ -145,6 +148,17 @@ export const maxLogic = kea<maxLogicType>([
status: 'loading',
})
}
},
})

while (true) {
const { done, value } = await reader.read()

parser.feed(decoder.decode(value))

if (done) {
actions.setMessageStatus(newIndex, 'completed')
break
}
}
} catch {
@@ -163,50 +177,11 @@ export const maxLogic = kea<maxLogicType>([
* Parses the generation result from the API. Some generation chunks might be sent in batches.
* @param response
*/
function parseResponse(response: string, recursive = true): RootAssistantMessage | null {
function parseResponse(response: string): RootAssistantMessage | null | undefined {
try {
const parsed = JSON.parse(response)
return parsed as RootAssistantMessage
return parsed as RootAssistantMessage | null | undefined
} catch {
if (!recursive) {
return null
}

const results: [number, number][] = []
let pair: [number, number] = [0, 0]
let seq = 0

for (let i = 0; i < response.length; i++) {
const char = response[i]

if (char === '{') {
if (seq === 0) {
pair[0] = i
}

seq += 1
}

if (char === '}') {
seq -= 1
if (seq === 0) {
pair[1] = i
}
}

if (seq === 0) {
results.push(pair)
pair = [0, 0]
}
}

const lastPair = results.pop()

if (lastPair) {
const [left, right] = lastPair
return parseResponse(response.slice(left, right + 1), false)
}

return null
}
}
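Taken together, the new flow in maxLogic.ts is: take a reader from the fetch body, bail out early if streaming is unavailable, and push every decoded chunk through eventsource-parser, which reassembles the byte stream into complete "data:" events no matter how the network split or merged them. A condensed, self-contained sketch of the same pattern (streamChat and onMessage are illustrative names, not from the repo):

    import { createParser } from 'eventsource-parser'

    async function streamChat(url: string, body: unknown, onMessage: (message: unknown) => void): Promise<void> {
        const response = await fetch(url, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(body),
        })

        const reader = response.body?.getReader()
        if (!reader) {
            return
        }

        const decoder = new TextDecoder()
        const parser = createParser({
            onEvent: (event) => {
                try {
                    // Each event.data is one complete JSON payload, however the chunks were split.
                    onMessage(JSON.parse(event.data))
                } catch {
                    // Ignore frames that are not valid JSON, e.g. the initial keep-alive.
                }
            },
        })

        while (true) {
            const { done, value } = await reader.read()
            parser.feed(decoder.decode(value))
            if (done) {
                break
            }
        }
    }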
1 change: 1 addition & 0 deletions package.json
@@ -124,6 +124,7 @@
"esbuild-plugin-less": "^1.3.1",
"esbuild-plugin-polyfill-node": "^0.3.0",
"esbuild-sass-plugin": "^3.0.0",
"eventsource-parser": "^3.0.0",
"expr-eval": "^2.0.2",
"express": "^4.17.1",
"fast-deep-equal": "^3.1.3",
11 changes: 10 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion posthog/api/query.py
@@ -185,7 +185,7 @@ def generate():
last_message = None
for message in assistant.stream(validated_body):
last_message = message
yield last_message
yield f"data: {message}\n\n"

human_message = validated_body.messages[-1].root
if isinstance(human_message, HumanMessage):
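Framing every generator message as "data: ...\n\n" on the server is what allowed the brace-matching fallback in parseResponse to be deleted: even when several messages are flushed in one network chunk, the SSE parser still hands them to the client as separate events. A quick illustration with made-up payloads (not code from the repo):

    import { createParser } from 'eventsource-parser'

    const events: string[] = []
    const parser = createParser({ onEvent: (event) => events.push(event.data) })

    // Two payloads coalesced into a single chunk still come out as two events.
    parser.feed('data: {"content":"Hel"}\n\ndata: {"content":"Hello"}\n\n')

    console.log(events) // ['{"content":"Hel"}', '{"content":"Hello"}']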
