Skip to content

Commit

Permalink
fix: assign evaluation results to the evaluation provider log
Browse files Browse the repository at this point in the history
  • Loading branch information
geclos committed Sep 19, 2024
1 parent fc7696a commit 7de1eb2
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 141 deletions.
19 changes: 14 additions & 5 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
TextStreamPart,
} from 'ai'

import { ProviderLog } from './browser'
import { Config } from './services/ai'

export const LATITUDE_EVENT = 'latitudeEventsChannel'
Expand Down Expand Up @@ -43,19 +44,27 @@ export type ChainStepTextResponse = {
text: string
usage: CompletionTokenUsage
toolCalls: ToolCall[]
documentLogUuid: string
providerLog: undefined
}
export type ChainStepObjectResponse = {
object: any
text: string
usage: CompletionTokenUsage
documentLogUuid: string
providerLog: undefined
}

export type ChainTextResponse = ChainStepTextResponse & {
documentLogUuid: string
export type ChainTextResponse = Omit<ChainStepTextResponse, 'providerLog'> & {
providerLog: ProviderLog
}
export type ChainObjectResponse = ChainStepObjectResponse & {
documentLogUuid: string
export type ChainObjectResponse = Omit<
ChainStepObjectResponse,
'providerLog'
> & {
providerLog: ProviderLog
}
export type ChainStepResponse = ChainStepTextResponse | ChainStepObjectResponse
export type ChainCallResponse = ChainTextResponse | ChainObjectResponse

export enum LogSources {
Expand Down Expand Up @@ -89,7 +98,7 @@ type LatitudeEventData =
}
| {
type: ChainEventTypes.StepComplete
response: ChainCallResponse
response: ChainStepResponse
}
| {
type: ChainEventTypes.Complete
Expand Down
165 changes: 100 additions & 65 deletions packages/core/src/services/ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ import {
import { JSONSchema7 } from 'json-schema'
import { v4 } from 'uuid'

import { LogSources, ProviderApiKey, Workspace } from '../../browser'
import {
LogSources,
ProviderApiKey,
ProviderLog,
Workspace,
} from '../../browser'
import { cache } from '../../cache'
import { publisher } from '../../events/publisher'
import { createProviderLog } from '../providerLogs/create'
Expand Down Expand Up @@ -49,7 +54,6 @@ export async function ai({
config,
documentLogUuid,
source,
onFinish,
schema,
output,
transactionalLogs = false,
Expand All @@ -64,104 +68,58 @@ export async function ai({
schema?: JSONSchema7
output?: 'object' | 'array' | 'no-schema'
transactionalLogs?: boolean
onFinish?: FinishCallback
}) {
await checkDefaultProviderUsage({ provider: apiProvider, workspace })

const startTime = Date.now()
const {
provider,
token: apiKey,
id: providerId,
provider: providerType,
} = apiProvider
const { provider, token: apiKey } = apiProvider
const model = config.model
const m = createProvider({ provider, apiKey, config })(model)

const commonOptions = {
model: m,
prompt,
messages: messages as CoreMessage[],
}

const createFinishHandler = (isStructured: boolean) => async (event: any) => {
const commonData = {
uuid: v4(),
source,
generatedAt: new Date(),
documentLogUuid,
providerId,
providerType,
model,
config,
messages,
toolCalls: event.toolCalls?.map((t: any) => ({
id: t.toolCallId,
name: t.toolName,
arguments: t.args,
})),
usage: event.usage,
duration: Date.now() - startTime,
}

const payload = {
type: 'aiProviderCallCompleted' as 'aiProviderCallCompleted',
data: {
...commonData,
responseText: event.text,
responseObject: isStructured ? event.object : undefined,
},
}

publisher.publishLater({
type: payload.type,
data: {
...payload.data,
workspaceId: apiProvider.workspaceId,
},
})

let providerLogUuid
if (transactionalLogs) {
const providerLog = await createProviderLog(payload.data).then((r) =>
r.unwrap(),
)
providerLogUuid = providerLog.uuid
} else {
const queues = await setupJobs()
queues.defaultQueue.jobs.enqueueCreateProviderLogJob(payload.data)
}

onFinish?.({ ...event, providerLogUuid })
}
const { onFinish, providerLog } = createFinishHandler({
isStructured: false,
startTime,
apiProvider,
source,
documentLogUuid,
messages,
config,
transactionalLogs,
})

if (schema && output) {
const result = await streamObject({
...commonOptions,
schema: jsonSchema(schema),
// @ts-expect-error - output is vale but depending on the type of schema
// @ts-expect-error - output is valid but depending on the type of schema
// there might be a mismatch (e.g you pass an object schema but the
// output is "array"). Not really an issue we need to defend atm
// output is "array"). Not really an issue we need to defend atm.
output,
onFinish: createFinishHandler(true),
onFinish,
})

return {
fullStream: result.fullStream,
object: result.object,
usage: result.usage,
providerLog,
}
} else {
const result = await streamText({
...commonOptions,
onFinish: createFinishHandler(false),
onFinish,
})

return {
fullStream: result.fullStream,
text: result.text,
usage: result.usage,
toolCalls: result.toolCalls,
providerLog,
}
}
}
Expand All @@ -185,6 +143,83 @@ const checkDefaultProviderUsage = async ({
}
}

const createFinishHandler = ({
isStructured,
startTime,
apiProvider,
source,
messages,
config,
transactionalLogs,
documentLogUuid,
}: {
isStructured: boolean
startTime: number
apiProvider: ProviderApiKey
source: LogSources
messages: Message[]
config: PartialConfig
transactionalLogs: boolean
documentLogUuid?: string
}) => {
let resolveProviderLog: (value?: ProviderLog) => void
const providerLogPromise = new Promise<ProviderLog | undefined>((resolve) => {
resolveProviderLog = resolve
})

return {
providerLog: providerLogPromise,
onFinish: async (event: any) => {
const commonData = {
uuid: v4(),
source,
generatedAt: new Date(),
documentLogUuid,
providerId: apiProvider.id,
providerType: apiProvider.provider,
model: config.model,
config,
messages,
toolCalls: event.toolCalls?.map((t: any) => ({
id: t.toolCallId,
name: t.toolName,
arguments: t.args,
})),
usage: event.usage,
duration: Date.now() - startTime,
}

const payload = {
type: 'aiProviderCallCompleted' as 'aiProviderCallCompleted',
data: {
...commonData,
responseText: event.text,
responseObject: isStructured ? event.object : undefined,
},
}

publisher.publishLater({
type: payload.type,
data: {
...payload.data,
workspaceId: apiProvider.workspaceId,
},
})

if (transactionalLogs) {
const providerLog = await createProviderLog(payload.data).then((r) =>
r.unwrap(),
)
resolveProviderLog(providerLog)
} else {
const queues = await setupJobs()
queues.defaultQueue.jobs.enqueueCreateProviderLogJob(payload.data)
resolveProviderLog()
}
},
}
}

export { estimateCost } from './estimateCost'
export { validateConfig } from './helpers'
export type { PartialConfig, Config } from './helpers'
Loading

0 comments on commit 7de1eb2

Please sign in to comment.