Skip to content

Commit

Permalink
feature: batch evaluations backend #2
Browse files Browse the repository at this point in the history
This commit implements the full cycle of a batch evaluation. Now users
can run batch evaluations on top of their prompts with dataset data.
  • Loading branch information
geclos committed Sep 11, 2024
1 parent 50707ee commit 248afdc
Show file tree
Hide file tree
Showing 48 changed files with 982 additions and 332 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { zValidator } from '@hono/zod-validator'
import { LogSources } from '@latitude-data/core/browser'
import { runDocumentAtCommit } from '@latitude-data/core/services/commits/runDocumentAtCommit'
import { pipeToStream } from '$/common/pipeToStream'
import { queues } from '$/jobs'
import { Factory } from 'hono/factory'
import { streamSSE } from 'hono/streaming'
import { z } from 'zod'
Expand All @@ -21,8 +20,6 @@ export const runHandler = factory.createHandlers(
zValidator('json', runSchema),
async (c) => {
return streamSSE(c, async (stream) => {
const startTime = Date.now()

const { projectId, commitUuid } = c.req.param()
const { documentPath, parameters, source } = c.req.valid('json')

Expand All @@ -44,17 +41,6 @@ export const runHandler = factory.createHandlers(
}).then((r) => r.unwrap())

await pipeToStream(stream, result.stream)

queues.defaultQueue.jobs.enqueueCreateDocumentLogJob({
commit,
data: {
uuid: result.documentLogUuid,
documentUuid: document.documentUuid,
resolvedContent: result.resolvedContent,
parameters,
duration: Date.now() - startTime,
},
})
})
},
)
12 changes: 6 additions & 6 deletions apps/web/src/actions/evaluations/runBatch.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
import * as factories from '@latitude-data/core/factories'
import { beforeEach, describe, expect, it, vi } from 'vitest'

import { runBatchAction } from './runBatch'
import { runBatchEvaluationAction } from './runBatch'

const mocks = vi.hoisted(() => ({
getSession: vi.fn(),
Expand Down Expand Up @@ -40,7 +40,7 @@ vi.mock('@latitude-data/jobs', () => ({
describe('runBatchAction', () => {
describe('unauthorized', () => {
it('errors when the user is not authenticated', async () => {
const [_, error] = await runBatchAction({
const [_, error] = await runBatchEvaluationAction({
datasetId: 1,
projectId: 1,
documentUuid: 'doc-uuid',
Expand Down Expand Up @@ -95,7 +95,7 @@ describe('runBatchAction', () => {
})

it('successfully enqueues a batch evaluation job', async () => {
const [result, error] = await runBatchAction({
const [result, error] = await runBatchEvaluationAction({
datasetId: dataset.id,
projectId: project.id,
documentUuid: document.documentUuid,
Expand Down Expand Up @@ -128,7 +128,7 @@ describe('runBatchAction', () => {
})

it('handles optional parameters', async () => {
const [result, error] = await runBatchAction({
const [result, error] = await runBatchEvaluationAction({
datasetId: dataset.id,
projectId: project.id,
documentUuid: document.documentUuid,
Expand Down Expand Up @@ -156,7 +156,7 @@ describe('runBatchAction', () => {
})

it('handles errors when resources are not found', async () => {
const [_, error] = await runBatchAction({
const [_, error] = await runBatchEvaluationAction({
datasetId: 999999,
projectId: project.id,
documentUuid: document.documentUuid,
Expand All @@ -176,7 +176,7 @@ describe('runBatchAction', () => {
name: 'Test Evaluation 2',
})

const [result, error] = await runBatchAction({
const [result, error] = await runBatchEvaluationAction({
datasetId: dataset.id,
projectId: project.id,
documentUuid: document.documentUuid,
Expand Down
8 changes: 5 additions & 3 deletions apps/web/src/actions/evaluations/runBatch.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
'use server'

import {
DatasetsRepository,
DocumentVersionsRepository,
Expand All @@ -9,16 +11,16 @@ import { z } from 'zod'

import { authProcedure } from '../procedures'

export const runBatchAction = authProcedure
export const runBatchEvaluationAction = authProcedure
.createServerAction()
.input(
z.object({
datasetId: z.number(),
projectId: z.number(),
documentUuid: z.string(),
commitUuid: z.string(),
fromLine: z.number(),
toLine: z.number(),
fromLine: z.number().optional(),
toLine: z.number().optional(),
parameters: z.record(z.number()).optional(),
evaluationIds: z.array(z.number()),
}),
Expand Down
2 changes: 1 addition & 1 deletion apps/web/src/actions/prompts/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export const runPromptAction = authProcedure
.input(
z.object({
prompt: z.string(),
parameters: z.object({ messages: z.string(), last_message: z.string() }),
parameters: z.record(z.any()),
}),
)
.handler(async ({ ctx, input }) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,18 @@ import {
} from '$/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground/Chat'
import { readStreamableValue } from 'ai/rsc'

export const EVALUATION_PARAMETERS = ['messages', 'last_message']
export const EVALUATION_PARAMETERS = [
'messages',
'context',
'response',
'prompt',
'parameters',
'cost',
'latency',
'config',
]

export type Parameters = (typeof EVALUATION_PARAMETERS)[number]
export type Inputs = { [key in Parameters]: string }

export default function Chat({
clearChat,
Expand All @@ -36,7 +44,7 @@ export default function Chat({
}: {
clearChat: () => void
evaluation: EvaluationDto
parameters: Inputs
parameters: Record<string, string>
}) {
const [error, setError] = useState<Error | undefined>()
const [tokens, setTokens] = useState<number>(0)
Expand Down Expand Up @@ -79,7 +87,7 @@ export default function Chat({

const [data, error] = await runPromptAction({
prompt: evaluation.metadata.prompt,
parameters: parameters as { messages: string; last_message: string },
parameters,
})
if (error) {
setError(error)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,50 +1,24 @@
'use client'

import { useCallback, useEffect, useMemo, useState } from 'react'
import { capitalize } from 'lodash-es'

import {
ConversationMetadata,
Message,
MessageContent,
TextContent,
} from '@latitude-data/compiler'
import { ConversationMetadata } from '@latitude-data/compiler'
import { EvaluationDto } from '@latitude-data/core/browser'
import { Badge, Icon, Text, TextArea } from '@latitude-data/web-ui'
import {
formatContext,
formatConversation,
} from '@latitude-data/core/services/providerLogs/formatForEvaluation'
import { Badge, Icon, Input, Text } from '@latitude-data/web-ui'
import { convertParams } from '$/app/(private)/projects/[projectId]/versions/[commitUuid]/documents/[documentUuid]/_components/DocumentEditor/Editor/Playground'
import { ROUTES } from '$/services/routes'
import useProviderLogs from '$/stores/providerLogs'
import Link from 'next/link'
import { useSearchParams } from 'next/navigation'

import { Header } from '../Header'
import Chat, { EVALUATION_PARAMETERS, Inputs } from './Chat'
import Chat, { EVALUATION_PARAMETERS } from './Chat'
import Preview from './Preview'

function convertMessage(message: Message) {
if (typeof message.content === 'string') {
return `${capitalize(message.role)}: \n ${message.content}`
} else {
const content = message.content[0] as MessageContent
if (content.type === 'text') {
return `${capitalize(message.role)}: \n ${(content as TextContent).text}`
} else {
return `${capitalize(message.role)}: <${content.type} message>`
}
}
}

function convertMessages(messages: Message[]) {
return messages.map((message) => convertMessage(message)).join('\n')
}

function convertParams(inputs: Inputs) {
return Object.fromEntries(
Object.entries(inputs).map(([key, value]) => {
return [key, value]
}),
)
}

export default function Playground({
evaluation,
metadata,
Expand All @@ -53,7 +27,7 @@ export default function Playground({
metadata: ConversationMetadata
}) {
const [mode, setMode] = useState<'preview' | 'chat'>('preview')
const [inputs, setInputs] = useState<Inputs>(
const [inputs, setInputs] = useState<Record<string, string>>(
Object.fromEntries(
EVALUATION_PARAMETERS.map((param: string) => [param, '']),
),
Expand All @@ -74,8 +48,14 @@ export default function Playground({
useEffect(() => {
if (providerLog) {
setInputs({
messages: convertMessages(providerLog.messages),
last_message: `Assistant: ${providerLog.responseText}`,
messages: JSON.stringify(formatConversation(providerLog)),
context: JSON.stringify(formatContext(providerLog)),
response: providerLog.responseText,
prompt: '',
parameters: '',
config: '',
duration: '',
cost: '',
})
}
}, [setInput, providerLog])
Expand Down Expand Up @@ -106,7 +86,7 @@ export default function Playground({
>
<Badge variant='accent'>&#123;&#123;{param}&#125;&#125;</Badge>
<div className='flex flex-grow w-full'>
<TextArea
<Input
value={value}
onChange={(e) => setInput(param, e.currentTarget.value)}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import { Header } from '../Header'
import Chat from './Chat'
import Preview from './Preview'

function convertParams(
inputs: Record<string, string>,
): Record<string, unknown> {
export function convertParams(inputs: Record<string, string>) {
return Object.fromEntries(
Object.entries(inputs).map(([key, value]) => {
return [key, value]
try {
return [key, JSON.parse(value)]
} catch (e) {
return [key, value]
}
}),
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
'use client'

import { Button, TableWithHeader } from '@latitude-data/web-ui'
import { runBatchEvaluationAction } from '$/actions/evaluations/runBatch'
import useLatitudeAction from '$/hooks/useLatitudeAction'
import { ROUTES } from '$/services/routes'
import Link from 'next/link'

export function Actions({
projectId,
commitUuid,
documentUuid,
}: {
projectId: string
commitUuid: string
documentUuid: string
}) {
const href = ROUTES.projects
.detail({ id: Number(projectId) })
.commits.detail({ uuid: commitUuid })
.documents.detail({ uuid: documentUuid }).evaluations.dashboard.connect.root
const { execute: executeBatchEvaluation } = useLatitudeAction(
runBatchEvaluationAction,
)

return (
<>
<Link href={href}>
<TableWithHeader.Button>Connect evaluation</TableWithHeader.Button>
</Link>
<Button
fancy
onClick={() =>
executeBatchEvaluation({
projectId: Number(projectId),
commitUuid,
documentUuid,
evaluationIds: [1],
datasetId: 1,
parameters: {
nombre_usuario: 0,
nombre: 2,
},
})
}
>
Execute batch evaluation
</Button>
</>
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ import { ReactNode } from 'react'

import { TableWithHeader } from '@latitude-data/web-ui'
import { getEvaluationsByDocumentUuidCached } from '$/app/(private)/_data-access'
import { ROUTES } from '$/services/routes'
import Link from 'next/link'

import { Actions } from './_components/Actions'
import EvaluationsLayoutClient from './_components/Layout'

export default async function EvaluationsLayout({
Expand All @@ -15,20 +14,17 @@ export default async function EvaluationsLayout({
params: { projectId: string; commitUuid: string; documentUuid: string }
}) {
const evaluations = await getEvaluationsByDocumentUuidCached(documentUuid)
const href = ROUTES.projects
.detail({ id: Number(projectId) })
.commits.detail({ uuid: commitUuid })
.documents.detail({ uuid: documentUuid }).evaluations.dashboard.connect.root

return (
<div className='w-full p-6'>
{children}
<TableWithHeader
title='Evaluations'
actions={
<Link href={href}>
<TableWithHeader.Button>Connect evaluation</TableWithHeader.Button>
</Link>
<Actions
projectId={projectId}
commitUuid={commitUuid}
documentUuid={documentUuid}
/>
}
table={<EvaluationsLayoutClient evaluations={evaluations} />}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ export function DocumentLogMetadata({
trigger={
<div className='flex flex-row items-center gap-x-1'>
<Text.H5 color='foregroundMuted'>
{formatCostInMillicents(documentLog.cost_in_millicents ?? 0)}
{formatCostInMillicents(documentLog.costInMillicents ?? 0)}
</Text.H5>
<Icon name='info' className='text-muted-foreground' />
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export const DocumentLogsTable = ({
</TableCell>
<TableCell>
<Text.H4 noWrap>
{formatCostInMillicents(documentLog.cost_in_millicents || 0)}
{formatCostInMillicents(documentLog.costInMillicents || 0)}
</Text.H4>
</TableCell>
</TableRow>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
'use server'

import { ReactNode } from 'react'

import {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Username, Identifier,First name,Last name
booker12,9012,Rachel,Booker
grey07,2070,Laura,Grey
johnson81,4081,Craig,Johnson
jenkins46,9346,Mary,Jenkins
smith79,5079,Jamie,Smith

Loading

0 comments on commit 248afdc

Please sign in to comment.