Skip to content

Commit

Permalink
Response schema
Browse files Browse the repository at this point in the history
  • Loading branch information
csansoon committed Sep 11, 2024
1 parent 50707ee commit 311b280
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
2 changes: 1 addition & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"@latitude-data/mailers": "workspace:^",
"@sindresorhus/slugify": "^2.2.1",
"@t3-oss/env-core": "^0.11.1",
"ai": "^3.2.42",
"ai": "^3.3.3",
"argon2": "^0.41.0",
"csv-parse": "^5.5.6",
"drizzle-orm": "^0.33.0",
Expand Down
42 changes: 30 additions & 12 deletions packages/core/src/services/ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import {
CompletionTokenUsage,
CoreMessage,
FinishReason,
jsonSchema,
streamObject,
streamText,
} from 'ai'
import { v4 } from 'uuid'
Expand All @@ -17,6 +19,7 @@ import { z } from 'zod'
import { LogSources, ProviderApiKey, Providers } from '../../browser'
import { publisher } from '../../events/publisher'
import { CreateProviderLogProps } from '../providerLogs/create'
import { JSONSchema7 } from 'json-schema'

export type FinishCallbackEvent = {
finishReason: FinishReason
Expand All @@ -38,14 +41,17 @@ export type FinishCallbackEvent = {
}
export type FinishCallback = (event: FinishCallbackEvent) => void

export type Config = {
export type GenerationConfig = {
[key: string]: any
provider: string
model: string
schema?: JSONSchema7
azure?: { resourceName: string }
}

export type PartialConfig = Omit<Config, 'provider'>

export type Config = GenerationConfig & {
provider: string
}

const GROQ_API_URL = 'https://api.groq.com/openai/v1'

Expand All @@ -56,7 +62,7 @@ function createProvider({
}: {
provider: Providers
apiKey: string
config?: PartialConfig
config?: GenerationConfig
}) {
switch (provider) {
case Providers.OpenAI:
Expand Down Expand Up @@ -100,12 +106,12 @@ export async function ai(
provider: apiProvider,
prompt,
messages,
config,
config: _config,
documentLogUuid,
source,
}: {
provider: ProviderApiKey
config: PartialConfig
config: GenerationConfig
messages: Message[]
documentLogUuid?: string
prompt?: string
Expand All @@ -126,12 +132,22 @@ export async function ai(
} = apiProvider
const model = config.model
const m = createProvider({ provider, apiKey, config })(model)
const { provider, token: apiKey, id: providerId } = apiProvider

const config = {
..._config,
schema: _config.schema ? jsonSchema(_config.schema) : undefined,
structuredOutputs: _config.schema !== undefined,
} as GenerationConfig

const model = createProvider({ provider, apiKey, config })(config.model)

const result = await streamText({
model: m,
const props = {
...config,
model,
prompt,
messages: messages as CoreMessage[],
onFinish: (event) => {
onFinish: (event: FinishCallbackEvent) => {
publisher.publish({
type: 'aiProviderCallCompleted',
data: {
Expand All @@ -141,7 +157,7 @@ export async function ai(
documentLogUuid,
providerId,
providerType,
model,
model: config.model,
config,
messages,
responseText: event.text,
Expand All @@ -159,11 +175,13 @@ export async function ai(
},
})

const result = config.structuredOutputs ? await streamObject(props) : await streamText(props)

return {
fullStream: result.fullStream,
text: result.text,
text: config.structuredOutputs ? result.object : result.text,
usage: result.usage,
toolCalls: result.toolCalls,
toolCalls: config.structuredOutputs ? [] : result.toolCalls,
}
}

Expand Down

0 comments on commit 311b280

Please sign in to comment.