Skip to content

Commit

Permalink
feature: run evaluation prompt
Browse files Browse the repository at this point in the history
This commit implements actions to run evaluation prompts as well as some
internal refactors.
  • Loading branch information
geclos committed Sep 5, 2024
1 parent c031599 commit 19d163e
Show file tree
Hide file tree
Showing 36 changed files with 827 additions and 492 deletions.
10 changes: 1 addition & 9 deletions apps/gateway/src/routes/api/v1/chats/handlers/addMessage.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { zValidator } from '@hono/zod-validator'
import { LogSources } from '@latitude-data/core/browser'
import { addMessages } from '@latitude-data/core/services/documentLogs/index'
import { createProviderLog } from '@latitude-data/core/services/providerLogs/create'
import { messageSchema } from '$/common/messageSchema'
import { pipeToStream } from '$/common/pipeToStream'
import { Factory } from 'hono/factory'
Expand All @@ -22,19 +21,12 @@ export const addMessageHandler = factory.createHandlers(
return streamSSE(c, async (stream) => {
const { documentLogUuid, messages, source } = c.req.valid('json')
const workspace = c.get('workspace')
const apiKey = c.get('apiKey')

const result = await addMessages({
workspace,
documentLogUuid,
messages,
providerLogHandler: async (log) => {
await createProviderLog({
...log,
source,
apiKeyId: apiKey.id,
}).then((r) => r.unwrap())
},
source,
}).then((r) => r.unwrap())

await pipeToStream(stream, result.stream)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { zValidator } from '@hono/zod-validator'
import { LogSources } from '@latitude-data/core/browser'
import { runDocumentAtCommit } from '@latitude-data/core/services/commits/runDocumentAtCommit'
import { createProviderLog } from '@latitude-data/core/services/providerLogs/create'
import { pipeToStream } from '$/common/pipeToStream'
import { queues } from '$/jobs'
import { Factory } from 'hono/factory'
Expand All @@ -28,7 +27,6 @@ export const runHandler = factory.createHandlers(
const { documentPath, parameters, source } = c.req.valid('json')

const workspace = c.get('workspace')
const apiKey = c.get('apiKey')

const { document, commit, project } = await getData({
workspace,
Expand All @@ -42,18 +40,11 @@ export const runHandler = factory.createHandlers(
document,
commit,
parameters,
providerLogHandler: (log) => {
createProviderLog({
...log,
source,
apiKeyId: apiKey.id,
}).then((r) => r.unwrap())
},
source,
}).then((r) => r.unwrap())

await pipeToStream(stream, result.stream)

// TODO: review if this is needed and why it's not in addMessages handler?
queues.defaultQueue.jobs.enqueueCreateDocumentLogJob({
commit,
data: {
Expand Down
56 changes: 56 additions & 0 deletions apps/web/src/actions/prompts/run.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
'use server'

import { LogSources } from '@latitude-data/core/browser'
import { streamToGenerator } from '@latitude-data/core/lib/streamToGenerator'
import { runPrompt } from '@latitude-data/core/services/prompts/run'
import { buildProviderApikeysMap } from '@latitude-data/core/services/providerApiKeys/buildMap'
import { createStreamableValue } from 'ai/rsc'
import { z } from 'zod'

import { authProcedure } from '../procedures'

/**
 * Server action that runs an evaluation prompt and streams the provider
 * response back to the client.
 *
 * Returns immediately with:
 * - `output`: a streamable value the client can consume with
 *   `readStreamableValue` while the run is still in progress.
 * - `response`: the promise of the final provider response.
 *
 * Validation errors from the zod schema and failures from `runPrompt`
 * itself are surfaced both through the stream (so the client UI can react)
 * and by rethrowing (so the action resolves to an error tuple).
 */
export const runPromptAction = authProcedure
  .createServerAction()
  .input(
    z.object({
      prompt: z.string(),
      parameters: z.object({ messages: z.string(), last_message: z.string() }),
    }),
  )
  .handler(async ({ ctx, input }) => {
    const { prompt, parameters } = input
    const stream = createStreamableValue()
    try {
      const result = await runPrompt({
        source: LogSources.Evaluation,
        prompt,
        parameters,
        // Provider API keys are resolved per-workspace for this run.
        apikeys: await buildProviderApikeysMap({
          workspaceId: ctx.workspace.id,
        }),
      }).then((r) => r.unwrap())

      // Deliberately NOT awaited: piping runs in the background so the
      // streamable value can be returned to the client right away. Attach a
      // rejection handler so a mid-stream failure closes the stream with the
      // error instead of becoming an unhandled promise rejection (which
      // would leave the client stream hanging forever).
      void pipeToStream(result.stream, stream).catch((e) => stream.error(e))

      return {
        output: stream.value,
        response: result.response,
      }
    } catch (error) {
      // `.error()` already terminates the streamable value; calling
      // `.done()` afterwards throws on an already-closed stream and would
      // mask the original failure.
      stream.error(error)

      throw error
    }
  })

/**
 * Forwards every chunk from a ReadableStream into a streamable value,
 * closing the target once the source is exhausted.
 */
async function pipeToStream(
  readable: ReadableStream,
  streamable: ReturnType<typeof createStreamableValue>,
) {
  const generator = streamToGenerator(readable)

  // Drive the async generator by hand instead of `for await`, forwarding
  // each produced chunk as an update on the streamable value.
  let step = await generator.next()
  while (!step.done) {
    streamable.update(step.value)
    step = await generator.next()
  }

  streamable.done()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import { useCallback, useEffect, useRef, useState } from 'react'

import {
Conversation,
Message as ConversationMessage,
} from '@latitude-data/compiler'
import {
ChainEventTypes,
EvaluationDto,
StreamEventTypes,
} from '@latitude-data/core/browser'
import {
ErrorMessage,
MessageList,
Text,
useAutoScroll,
} from '@latitude-data/web-ui'
import { runPromptAction } from '$/actions/prompts/run'
import {
StreamMessage,
Timer,
TokenUsage,
} from '$/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat'
import { readStreamableValue } from 'ai/rsc'

// Parameter names every evaluation prompt receives. `as const` preserves the
// literal element types: without it `typeof EVALUATION_PARAMETERS[number]`
// widens to `string`, which collapses `Parameters` to `string` and turns
// `Inputs` into an open string index signature instead of an object that
// requires exactly `messages` and `last_message`. The runtime value is
// unchanged.
export const EVALUATION_PARAMETERS = ['messages', 'last_message'] as const

export type Parameters = (typeof EVALUATION_PARAMETERS)[number]
export type Inputs = { [key in Parameters]: string }

/**
 * Runs an evaluation prompt exactly once on mount and renders the resulting
 * conversation, streaming the provider's text deltas as they arrive.
 *
 * Layout of the transcript:
 * - messages before `chainLength - 1`: the prompt portion of the chain
 * - the message at `chainLength - 1`: the chain's final (accented) response
 * - messages from `chainLength` on: follow-up chat messages (outlined)
 *
 * `chainLength` starts at Infinity so nothing is treated as "chat" until the
 * last chain step is observed in the event stream.
 */
export default function Chat({
  evaluation,
  parameters,
}: {
  evaluation: EvaluationDto
  parameters: Inputs
}) {
  const [error, setError] = useState<Error | undefined>()
  const [tokens, setTokens] = useState<number>(0)
  const [isScrolledToBottom, setIsScrolledToBottom] = useState(false)
  // Wall-clock start of the run, captured once; endTime is set only when the
  // chain reports completion, so the Timer renders only for finished runs.
  const [startTime, _] = useState(performance.now())
  const [endTime, setEndTime] = useState<number>()
  const containerRef = useRef<HTMLDivElement>(null)
  useAutoScroll(containerRef, {
    startAtBottom: true,
    onScrollChange: setIsScrolledToBottom,
  })

  const runChainOnce = useRef(false)
  // Index where the chain ends and the chat begins
  const [chainLength, setChainLength] = useState<number>(Infinity)
  const [conversation, setConversation] = useState<Conversation | undefined>()
  const [responseStream, setResponseStream] = useState<string | undefined>()

  // Appends a message to the conversation via the functional state updater
  // and also returns the new conversation synchronously by capturing it
  // inside the updater (hence the non-null assertion below).
  const addMessageToConversation = useCallback(
    (message: ConversationMessage) => {
      let newConversation: Conversation
      setConversation((prevConversation) => {
        newConversation = {
          ...prevConversation,
          messages: [...(prevConversation?.messages ?? []), message],
        } as Conversation
        return newConversation
      })
      return newConversation!
    },
    [],
  )

  // Kicks off the evaluation run and consumes the resulting event stream,
  // translating Latitude chain events and provider events into local state.
  const runEvaluation = useCallback(async () => {
    setError(undefined)
    setResponseStream(undefined)

    let response = ''
    let messagesCount = 0

    const [data, error] = await runPromptAction({
      prompt: evaluation.metadata.prompt,
      parameters: parameters as { messages: string; last_message: string },
    })
    if (error) {
      setError(error)
      return
    }

    const { output } = data!

    for await (const serverEvent of readStreamableValue(output)) {
      if (!serverEvent) continue

      const { event, data } = serverEvent
      // Any event carrying messages contributes them to the transcript and
      // advances the running count used to locate the end of the chain.
      const hasMessages = 'messages' in data
      if (hasMessages) {
        data.messages.forEach(addMessageToConversation)
        messagesCount += data.messages.length
      }

      switch (event) {
        case StreamEventTypes.Latitude: {
          if (data.type === ChainEventTypes.Step) {
            // The last step's response will be the next message appended,
            // so the chain boundary sits one past the current count.
            if (data.isLastStep) setChainLength(messagesCount + 1)
          } else if (data.type === ChainEventTypes.Complete) {
            setTokens(data.response.usage.totalTokens)
            setEndTime(performance.now())
          } else if (data.type === ChainEventTypes.Error) {
            setError(new Error(data.error.message))
          }
          break
        }

        case StreamEventTypes.Provider: {
          if (data.type === 'text-delta') {
            // Accumulate deltas locally and re-render the partial response.
            response += data.textDelta
            setResponseStream(response)
          } else if (data.type === 'finish') {
            // The finished message arrives via a Latitude event, so the
            // in-flight buffer is cleared rather than committed here.
            setResponseStream(undefined)
            response = ''
          }
          break
        }
        default:
          break
      }
    }
    // NOTE(review): `evaluation` and `addMessageToConversation` are read
    // above but missing from this dependency array, while `runPromptAction`
    // is a stable module import that doesn't need to be listed — confirm
    // intent against the exhaustive-deps lint rule.
  }, [parameters, runPromptAction])

  useEffect(() => {
    if (runChainOnce.current) return

    runChainOnce.current = true // Prevent double-running when StrictMode is enabled
    runEvaluation()
  }, [runEvaluation])

  return (
    <div className='flex flex-col h-full'>
      <div
        ref={containerRef}
        className='flex flex-col gap-3 h-full overflow-y-auto pb-12'
      >
        <Text.H6M>Prompt</Text.H6M>
        {/* Chain messages up to (but excluding) the final chain response. */}
        <MessageList
          messages={conversation?.messages.slice(0, chainLength - 1) ?? []}
        />
        {(conversation?.messages.length ?? 0) >= chainLength && (
          <>
            {/* The chain's final response, highlighted, plus elapsed time. */}
            <MessageList
              messages={
                conversation?.messages.slice(chainLength - 1, chainLength) ?? []
              }
              variant='accent'
            />
            {endTime && <Timer timeMs={endTime - startTime} />}
          </>
        )}
        {(conversation?.messages.length ?? 0) > chainLength && (
          <>
            {/* Follow-up chat messages after the chain finished. */}
            <Text.H6M>Chat</Text.H6M>
            <MessageList
              messages={conversation!.messages.slice(chainLength)}
              variant='outline'
            />
          </>
        )}
        {error ? (
          <ErrorMessage error={error} />
        ) : (
          <StreamMessage
            responseStream={responseStream}
            conversation={conversation}
            chainLength={chainLength}
          />
        )}
      </div>
      <div className='flex relative flex-row w-full items-center justify-center'>
        <TokenUsage
          isScrolledToBottom={isScrolledToBottom}
          tokens={tokens}
          responseStream={responseStream}
        />
      </div>
    </div>
  )
}
Loading

0 comments on commit 19d163e

Please sign in to comment.