-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit implements actions to run evaluation prompts as well as some internal refactors.
- Loading branch information
Showing
36 changed files
with
827 additions
and
492 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
'use server' | ||
|
||
import { LogSources } from '@latitude-data/core/browser' | ||
import { streamToGenerator } from '@latitude-data/core/lib/streamToGenerator' | ||
import { runPrompt } from '@latitude-data/core/services/prompts/run' | ||
import { buildProviderApikeysMap } from '@latitude-data/core/services/providerApiKeys/buildMap' | ||
import { createStreamableValue } from 'ai/rsc' | ||
import { z } from 'zod' | ||
|
||
import { authProcedure } from '../procedures' | ||
|
||
export const runPromptAction = authProcedure | ||
.createServerAction() | ||
.input( | ||
z.object({ | ||
prompt: z.string(), | ||
parameters: z.object({ messages: z.string(), last_message: z.string() }), | ||
}), | ||
) | ||
.handler(async ({ ctx, input }) => { | ||
const { prompt, parameters } = input | ||
const stream = createStreamableValue() | ||
try { | ||
const result = await runPrompt({ | ||
source: LogSources.Evaluation, | ||
prompt, | ||
parameters, | ||
apikeys: await buildProviderApikeysMap({ | ||
workspaceId: ctx.workspace.id, | ||
}), | ||
}).then((r) => r.unwrap()) | ||
|
||
pipeToStream(result.stream, stream) | ||
|
||
return { | ||
output: stream.value, | ||
response: result.response, | ||
} | ||
} catch (error) { | ||
stream.error(error) | ||
stream.done() | ||
|
||
throw error | ||
} | ||
}) | ||
|
||
async function pipeToStream( | ||
source: ReadableStream, | ||
target: ReturnType<typeof createStreamableValue>, | ||
) { | ||
for await (const chunk of streamToGenerator(source)) { | ||
target.update(chunk) | ||
} | ||
|
||
target.done() | ||
} |
183 changes: 183 additions & 0 deletions
183
...ons/[evaluationUuid]/editor/_components/EvaluationEditor/Editor/Playground/Chat/index.tsx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import { useCallback, useEffect, useRef, useState } from 'react' | ||
|
||
import { | ||
Conversation, | ||
Message as ConversationMessage, | ||
} from '@latitude-data/compiler' | ||
import { | ||
ChainEventTypes, | ||
EvaluationDto, | ||
StreamEventTypes, | ||
} from '@latitude-data/core/browser' | ||
import { | ||
ErrorMessage, | ||
MessageList, | ||
Text, | ||
useAutoScroll, | ||
} from '@latitude-data/web-ui' | ||
import { runPromptAction } from '$/actions/prompts/run' | ||
import { | ||
StreamMessage, | ||
Timer, | ||
TokenUsage, | ||
} from '$/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat' | ||
import { readStreamableValue } from 'ai/rsc' | ||
|
||
export const EVALUATION_PARAMETERS = ['messages', 'last_message'] | ||
|
||
export type Parameters = (typeof EVALUATION_PARAMETERS)[number] | ||
export type Inputs = { [key in Parameters]: string } | ||
|
||
export default function Chat({ | ||
evaluation, | ||
parameters, | ||
}: { | ||
evaluation: EvaluationDto | ||
parameters: Inputs | ||
}) { | ||
const [error, setError] = useState<Error | undefined>() | ||
const [tokens, setTokens] = useState<number>(0) | ||
const [isScrolledToBottom, setIsScrolledToBottom] = useState(false) | ||
const [startTime, _] = useState(performance.now()) | ||
const [endTime, setEndTime] = useState<number>() | ||
const containerRef = useRef<HTMLDivElement>(null) | ||
useAutoScroll(containerRef, { | ||
startAtBottom: true, | ||
onScrollChange: setIsScrolledToBottom, | ||
}) | ||
|
||
const runChainOnce = useRef(false) | ||
// Index where the chain ends and the chat begins | ||
const [chainLength, setChainLength] = useState<number>(Infinity) | ||
const [conversation, setConversation] = useState<Conversation | undefined>() | ||
const [responseStream, setResponseStream] = useState<string | undefined>() | ||
|
||
const addMessageToConversation = useCallback( | ||
(message: ConversationMessage) => { | ||
let newConversation: Conversation | ||
setConversation((prevConversation) => { | ||
newConversation = { | ||
...prevConversation, | ||
messages: [...(prevConversation?.messages ?? []), message], | ||
} as Conversation | ||
return newConversation | ||
}) | ||
return newConversation! | ||
}, | ||
[], | ||
) | ||
|
||
const runEvaluation = useCallback(async () => { | ||
setError(undefined) | ||
setResponseStream(undefined) | ||
|
||
let response = '' | ||
let messagesCount = 0 | ||
|
||
const [data, error] = await runPromptAction({ | ||
prompt: evaluation.metadata.prompt, | ||
parameters: parameters as { messages: string; last_message: string }, | ||
}) | ||
if (error) { | ||
setError(error) | ||
return | ||
} | ||
|
||
const { output } = data! | ||
|
||
for await (const serverEvent of readStreamableValue(output)) { | ||
if (!serverEvent) continue | ||
|
||
const { event, data } = serverEvent | ||
const hasMessages = 'messages' in data | ||
if (hasMessages) { | ||
data.messages.forEach(addMessageToConversation) | ||
messagesCount += data.messages.length | ||
} | ||
|
||
switch (event) { | ||
case StreamEventTypes.Latitude: { | ||
if (data.type === ChainEventTypes.Step) { | ||
if (data.isLastStep) setChainLength(messagesCount + 1) | ||
} else if (data.type === ChainEventTypes.Complete) { | ||
setTokens(data.response.usage.totalTokens) | ||
setEndTime(performance.now()) | ||
} else if (data.type === ChainEventTypes.Error) { | ||
setError(new Error(data.error.message)) | ||
} | ||
break | ||
} | ||
|
||
case StreamEventTypes.Provider: { | ||
if (data.type === 'text-delta') { | ||
response += data.textDelta | ||
setResponseStream(response) | ||
} else if (data.type === 'finish') { | ||
setResponseStream(undefined) | ||
response = '' | ||
} | ||
break | ||
} | ||
default: | ||
break | ||
} | ||
} | ||
}, [parameters, runPromptAction]) | ||
|
||
useEffect(() => { | ||
if (runChainOnce.current) return | ||
|
||
runChainOnce.current = true // Prevent double-running when StrictMode is enabled | ||
runEvaluation() | ||
}, [runEvaluation]) | ||
|
||
return ( | ||
<div className='flex flex-col h-full'> | ||
<div | ||
ref={containerRef} | ||
className='flex flex-col gap-3 h-full overflow-y-auto pb-12' | ||
> | ||
<Text.H6M>Prompt</Text.H6M> | ||
<MessageList | ||
messages={conversation?.messages.slice(0, chainLength - 1) ?? []} | ||
/> | ||
{(conversation?.messages.length ?? 0) >= chainLength && ( | ||
<> | ||
<MessageList | ||
messages={ | ||
conversation?.messages.slice(chainLength - 1, chainLength) ?? [] | ||
} | ||
variant='accent' | ||
/> | ||
{endTime && <Timer timeMs={endTime - startTime} />} | ||
</> | ||
)} | ||
{(conversation?.messages.length ?? 0) > chainLength && ( | ||
<> | ||
<Text.H6M>Chat</Text.H6M> | ||
<MessageList | ||
messages={conversation!.messages.slice(chainLength)} | ||
variant='outline' | ||
/> | ||
</> | ||
)} | ||
{error ? ( | ||
<ErrorMessage error={error} /> | ||
) : ( | ||
<StreamMessage | ||
responseStream={responseStream} | ||
conversation={conversation} | ||
chainLength={chainLength} | ||
/> | ||
)} | ||
</div> | ||
<div className='flex relative flex-row w-full items-center justify-center'> | ||
<TokenUsage | ||
isScrolledToBottom={isScrolledToBottom} | ||
tokens={tokens} | ||
responseStream={responseStream} | ||
/> | ||
</div> | ||
</div> | ||
) | ||
} |
Oops, something went wrong.