diff --git a/packages/core/package.json b/packages/core/package.json index be2b5e3f3..496a2b775 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -85,7 +85,7 @@ "@latitude-data/mailers": "workspace:^", "@sindresorhus/slugify": "^2.2.1", "@t3-oss/env-core": "^0.11.1", - "ai": "^3.2.42", + "ai": "^3.3.3", "argon2": "^0.41.0", "csv-parse": "^5.5.6", "drizzle-orm": "^0.33.0", diff --git a/packages/core/src/services/ai/index.ts b/packages/core/src/services/ai/index.ts index ed2627ddb..68bfd3546 100644 --- a/packages/core/src/services/ai/index.ts +++ b/packages/core/src/services/ai/index.ts @@ -9,6 +9,8 @@ import { CompletionTokenUsage, CoreMessage, FinishReason, + jsonSchema, + streamObject, streamText, } from 'ai' import { v4 } from 'uuid' @@ -17,6 +19,7 @@ import { z } from 'zod' import { LogSources, ProviderApiKey, Providers } from '../../browser' import { publisher } from '../../events/publisher' import { CreateProviderLogProps } from '../providerLogs/create' +import { JSONSchema7 } from 'json-schema' export type FinishCallbackEvent = { finishReason: FinishReason @@ -38,14 +41,17 @@ export type FinishCallbackEvent = { } export type FinishCallback = (event: FinishCallbackEvent) => void -export type Config = { +export type GenerationConfig = { [key: string]: any - provider: string model: string + schema?: JSONSchema7 azure?: { resourceName: string } } -export type PartialConfig = Omit + +export type Config = GenerationConfig & { + provider: string +} const GROQ_API_URL = 'https://api.groq.com/openai/v1' @@ -56,7 +62,7 @@ function createProvider({ }: { provider: Providers apiKey: string - config?: PartialConfig + config?: GenerationConfig }) { switch (provider) { case Providers.OpenAI: @@ -100,12 +106,12 @@ export async function ai( provider: apiProvider, prompt, messages, - config, + config: _config, documentLogUuid, source, }: { provider: ProviderApiKey - config: PartialConfig + config: GenerationConfig messages: Message[] documentLogUuid?: string prompt?: string @@ -126,12 +132,22 @@ export async function ai( } = apiProvider const model = config.model const m = createProvider({ provider, apiKey, config })(model) + const { provider, token: apiKey, id: providerId } = apiProvider + + const config = { + ..._config, + schema: _config.schema ? jsonSchema(_config.schema) : undefined, + structuredOutputs: _config.schema !== undefined, + } as GenerationConfig + + const model = createProvider({ provider, apiKey, config })(config.model) - const result = await streamText({ - model: m, + const props = { + ...config, + model, prompt, messages: messages as CoreMessage[], - onFinish: (event) => { + onFinish: (event: FinishCallbackEvent) => { publisher.publish({ type: 'aiProviderCallCompleted', data: { @@ -141,7 +157,7 @@ export async function ai( documentLogUuid, providerId, providerType, - model, + model: config.model, config, messages, responseText: event.text, @@ -159,11 +175,13 @@ export async function ai( }, }) + const result = config.structuredOutputs ? await streamObject(props) : await streamText(props) + return { fullStream: result.fullStream, - text: result.text, + text: config.structuredOutputs ? result.object : result.text, usage: result.usage, - toolCalls: result.toolCalls, + toolCalls: config.structuredOutputs ? [] : result.toolCalls, } }