Skip to content

Commit

Permalink
feature: structured outputs (#104)
Browse files Browse the repository at this point in the history
Adds support for structured outputs responses from providers and
implements it in evaluations and across our UI

Co-authored-by: Gerard Clos <[email protected]>
  • Loading branch information
csansoon and geclos authored Sep 16, 2024
1 parent 1ee8ef1 commit 6240403
Show file tree
Hide file tree
Showing 39 changed files with 3,145 additions and 389 deletions.
8 changes: 6 additions & 2 deletions apps/web/src/actions/providerLogs/fetch.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'use server'

import { ProviderLogsRepository } from '@latitude-data/core/repositories'
import providerLogPresenter from '$/presenters/providerLogPresenter'
import { z } from 'zod'

import { authProcedure } from '../procedures'
Expand Down Expand Up @@ -30,7 +31,7 @@ export const getProviderLogsAction = authProcedure
result = await scope.findAll({ limit: 1000 }).then((r) => r.unwrap())
}

return result
return result.map(providerLogPresenter)
})

export const getProviderLogAction = authProcedure
Expand All @@ -39,5 +40,8 @@ export const getProviderLogAction = authProcedure
.handler(async ({ input, ctx }) => {
const { providerLogId } = input
const scope = new ProviderLogsRepository(ctx.workspace.id)
return await scope.find(providerLogId).then((r) => r.unwrap())
return await scope
.find(providerLogId)
.then((r) => r.unwrap())
.then(providerLogPresenter)
})
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export default function Playground({
setInputs({
messages: JSON.stringify(formatConversation(providerLog)),
context: JSON.stringify(formatContext(providerLog)),
response: providerLog.responseText,
response: providerLog.response,
prompt: '',
parameters: '',
config: '',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { useState } from 'react'
import { capitalize } from 'lodash-es'

import { MessageContent, TextContent } from '@latitude-data/compiler'
import { HEAD_COMMIT } from '@latitude-data/core/browser'
import { HEAD_COMMIT, ProviderLogDto } from '@latitude-data/core/browser'
import {
Badge,
Button,
Expand Down Expand Up @@ -193,7 +193,9 @@ const ProviderLogMessages = ({
providerLogUuid?: string
}) => {
const { data } = useProviderLogs({ documentUuid })
const providerLog = data?.find((log) => log.uuid === providerLogUuid)
const providerLog = data?.find(
(log) => log.uuid === providerLogUuid,
) as ProviderLogDto
if (!providerLog) {
return (
<div className='flex flex-col items-center justify-center rounded-lg border border-2 bg-secondary p-4 h-[480px]'>
Expand Down Expand Up @@ -225,7 +227,7 @@ const ProviderLogMessages = ({
<Badge variant={roleVariant('assistant')}>Assistant</Badge>
</div>
<div className='pl-4'>
<Text.H6M>{printMessageContent(providerLog.responseText)}</Text.H6M>
<Text.H6M>{printMessageContent(providerLog.response)}</Text.H6M>
</div>
</div>
</div>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { useCallback, useContext, useEffect, useRef, useState } from 'react'

import {
AssistantMessage,
ContentType,
Conversation,
Message as ConversationMessage,
Expand Down Expand Up @@ -95,8 +96,8 @@ export default function Chat({
const { event, data } = serverEvent
const hasMessages = 'messages' in data
if (hasMessages) {
data.messages.forEach(addMessageToConversation)
messagesCount += data.messages.length
data.messages!.forEach(addMessageToConversation)
messagesCount += data.messages!.length
}

switch (event) {
Expand All @@ -109,6 +110,7 @@ export default function Chat({
} else if (data.type === ChainEventTypes.Error) {
setError(new Error(data.error.message))
}

break
}

Expand Down Expand Up @@ -161,17 +163,17 @@ export default function Chat({

const { event, data } = serverEvent

const hasMessages = 'messages' in data

if (hasMessages) {
data.messages.forEach(addMessageToConversation)
}

switch (event) {
case StreamEventTypes.Latitude: {
if (data.type === ChainEventTypes.Error) {
setError(new Error(data.error.message))
} else if (data.type === ChainEventTypes.Complete) {
addMessageToConversation({
role: MessageRole.assistant,
content: data.response.text,
} as AssistantMessage)
}

break
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import { useMemo } from 'react'

import { AssistantMessage, Message, MessageRole } from '@latitude-data/compiler'
import { ProviderLog } from '@latitude-data/core/browser'
import { ProviderLogDto } from '@latitude-data/core/browser'
import { MessageList } from '@latitude-data/web-ui'

export function EvaluationResultMessages({
providerLog,
}: {
providerLog?: ProviderLog
providerLog?: ProviderLogDto
}) {
const messages = useMemo<Message[]>(() => {
if (!providerLog) return [] as Message[]

const responseMessage = {
role: MessageRole.assistant,
content: providerLog.responseText,
content: providerLog.response,
toolCalls: providerLog.toolCalls,
} as AssistantMessage

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ReactNode } from 'react'

import { ProviderLog } from '@latitude-data/core/browser'
import { ProviderLogDto } from '@latitude-data/core/browser'
import { EvaluationResultWithMetadata } from '@latitude-data/core/repositories'
import {
ClickToCopy,
Expand Down Expand Up @@ -48,7 +48,7 @@ export function EvaluationResultMetadata({
providerLog,
}: {
evaluationResult: EvaluationResultWithMetadata
providerLog?: ProviderLog
providerLog?: ProviderLogDto
}) {
const { data: providers, isLoading: providersLoading } = useProviderApiKeys()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import { useState } from 'react'

import { ProviderLog } from '@latitude-data/core/browser'
import { ProviderLogDto } from '@latitude-data/core/browser'
import { EvaluationResultWithMetadata } from '@latitude-data/core/repositories'
import { TabSelector } from '@latitude-data/web-ui'

Expand All @@ -14,7 +14,7 @@ export function EvaluationResultInfo({
providerLog,
}: {
evaluationResult: EvaluationResultWithMetadata
providerLog?: ProviderLog
providerLog?: ProviderLogDto
}) {
const [selectedTab, setSelectedTab] = useState<string>('metadata')
return (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import { useMemo } from 'react'

import { AssistantMessage, Message, MessageRole } from '@latitude-data/compiler'
import { ProviderLog } from '@latitude-data/core/browser'
import { ProviderLogDto } from '@latitude-data/core/browser'
import { MessageList } from '@latitude-data/web-ui'

export function DocumentLogMessages({
providerLogs,
}: {
providerLogs?: ProviderLog[]
providerLogs?: ProviderLogDto[]
}) {
const messages = useMemo<Message[]>(() => {
const lastLog = providerLogs?.[providerLogs.length - 1]
if (!lastLog) return [] as Message[]

const responseMessage = {
role: MessageRole.assistant,
content: lastLog.responseText,
content: lastLog.response,
toolCalls: lastLog.toolCalls,
} as AssistantMessage

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ReactNode, useMemo } from 'react'

import { ProviderLog } from '@latitude-data/core/browser'
import { ProviderLogDto } from '@latitude-data/core/browser'
import { DocumentLogWithMetadata } from '@latitude-data/core/repositories'
import {
ClickToCopy,
Expand Down Expand Up @@ -48,7 +48,7 @@ export function DocumentLogMetadata({
providerLogs,
}: {
documentLog: DocumentLogWithMetadata
providerLogs?: ProviderLog[]
providerLogs?: ProviderLogDto[]
}) {
const { data: providers, isLoading: providersLoading } = useProviderApiKeys()
const lastProviderLog = useMemo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import { useState } from 'react'

import { ProviderLog } from '@latitude-data/core/browser'
import { ProviderLogDto } from '@latitude-data/core/browser'
import { DocumentLogWithMetadata } from '@latitude-data/core/repositories'
import { TabSelector } from '@latitude-data/web-ui'

Expand All @@ -14,7 +14,7 @@ export function DocumentLogInfo({
providerLogs,
}: {
documentLog: DocumentLogWithMetadata
providerLogs?: ProviderLog[]
providerLogs?: ProviderLogDto[]
}) {
const [selectedTab, setSelectedTab] = useState<string>('metadata')
return (
Expand Down
16 changes: 16 additions & 0 deletions apps/web/src/presenters/providerLogPresenter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { omit } from 'lodash-es'

import type { ProviderLog, ProviderLogDto } from '@latitude-data/core/browser'

export default function providerLogPresenter(
providerLog: ProviderLog,
): ProviderLogDto {
return {
...omit(providerLog, 'responseText', 'responseObject'),
response:
providerLog.responseText ||
(providerLog.responseObject
? JSON.stringify(providerLog.responseObject)
: ''),
}
}
15 changes: 8 additions & 7 deletions apps/web/src/stores/providerLogs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import { compact } from 'lodash-es'

import type { ProviderLog } from '@latitude-data/core/browser'
import { ProviderLogDto } from '@latitude-data/core/browser'
import { useToast } from '@latitude-data/web-ui'
import {
getProviderLogAction,
Expand All @@ -19,10 +19,10 @@ export default function useProviderLogs(
) {
const { toast } = useToast()
const {
data = undefined,
data = [],
isLoading,
error: swrError,
} = useSWR<ProviderLog[] | undefined>(
} = useSWR<ProviderLogDto[]>(
compact(['providerLogs', documentUuid, documentLogUuid]),
async () => {
const [data, error] = await getProviderLogsAction({
Expand All @@ -38,10 +38,11 @@ export default function useProviderLogs(
description: error.formErrors?.[0] || error.message,
variant: 'destructive',
})
return undefined

return []
}

return data as ProviderLog[]
return data
},
opts,
)
Expand All @@ -62,7 +63,7 @@ export function useProviderLog(
data = undefined,
isLoading,
error: swrError,
} = useSWR<ProviderLog | undefined>(
} = useSWR<ProviderLogDto | undefined>(
compact(['providerLog', providerLogId]),
async () => {
if (!providerLogId) return undefined
Expand All @@ -82,7 +83,7 @@ export function useProviderLog(
return undefined
}

return data as ProviderLog
return data
},
opts,
)
Expand Down
3 changes: 3 additions & 0 deletions packages/core/drizzle/0052_famous_daimon_hellstrom.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE "latitude"."provider_logs" ALTER COLUMN "response_text" DROP DEFAULT;--> statement-breakpoint
ALTER TABLE "latitude"."provider_logs" ALTER COLUMN "response_text" DROP NOT NULL;--> statement-breakpoint
ALTER TABLE "latitude"."provider_logs" ADD COLUMN "response_object" jsonb;
Loading

0 comments on commit 6240403

Please sign in to comment.