Skip to content

Commit

Permalink
feature: only add json schema to chain last step
Browse files Browse the repository at this point in the history
  • Loading branch information
geclos committed Sep 19, 2024
1 parent 97bc51b commit 8516048
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 24 deletions.
4 changes: 2 additions & 2 deletions packages/core/src/services/ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ export async function ai({
documentLogUuid,
source,
onFinish,
schema = config.schema,
output = config.schema?.type || 'no-schema',
schema,
output,
transactionalLogs = false,
}: {
workspace: Workspace
Expand Down
198 changes: 197 additions & 1 deletion packages/core/src/services/chains/run.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ describe('runChain', () => {
expect.objectContaining({
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
schema: undefined,
output: undefined,
output: 'no-schema',
}),
)
})
Expand Down Expand Up @@ -299,4 +299,200 @@ describe('runChain', () => {
}),
)
})

it('runs a chain with object schema and output', async () => {
const mockSchema = {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
},
} as const

const mockAiResponse = {
object: Promise.resolve({ name: 'John', age: 30 }),
usage: Promise.resolve({ totalTokens: 15 }),
fullStream: new ReadableStream({
start(controller) {
controller.enqueue({
type: 'object',
object: { name: 'John', age: 30 },
})
controller.close()
},
}),
}

vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any)

vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
conversation: {
messages: [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: 'Test message' }],
},
],
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
},
})

const result = await runChain({
workspace,
chain: mockChain as Chain,
apikeys,
source: LogSources.API,
configOverrides: {
schema: mockSchema,
output: 'object',
},
})

expect(result.ok).toBe(true)
if (!result.ok) return

const response = await result.value.response
expect(response).toEqual({
documentLogUuid: expect.any(String),
object: { name: 'John', age: 30 },
text: '{"name":"John","age":30}',
usage: { totalTokens: 15 },
})

expect(aiModule.ai).toHaveBeenCalledWith(
expect.objectContaining({
schema: mockSchema,
output: 'object',
}),
)
})

it('runs a chain with array schema and output', async () => {
const mockSchema = {
type: 'array',
items: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
},
},
} as const

const mockAiResponse = {
object: Promise.resolve([
{ name: 'John', age: 30 },
{ name: 'Jane', age: 25 },
]),
usage: Promise.resolve({ totalTokens: 20 }),
fullStream: new ReadableStream({
start(controller) {
controller.enqueue({
type: 'object',
object: [
{ name: 'John', age: 30 },
{ name: 'Jane', age: 25 },
],
})
controller.close()
},
}),
}

vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any)

vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
conversation: {
messages: [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: 'Test message' }],
},
],
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
},
})

const result = await runChain({
workspace,
chain: mockChain as Chain,
apikeys,
source: LogSources.API,
configOverrides: {
schema: mockSchema,
output: 'array',
},
})

expect(result.ok).toBe(true)
if (!result.ok) return

const response = await result.value.response
expect(response).toEqual({
documentLogUuid: expect.any(String),
object: [
{ name: 'John', age: 30 },
{ name: 'Jane', age: 25 },
],
text: '[{"name":"John","age":30},{"name":"Jane","age":25}]',
usage: { totalTokens: 20 },
})

expect(aiModule.ai).toHaveBeenCalledWith(
expect.objectContaining({
schema: mockSchema,
output: 'array',
}),
)
})

it('runs a chain with no-schema output', async () => {
const mockAiResponse = createMockAiResponse(
'AI response without schema',
10,
)
vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any)

vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
conversation: {
messages: [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: 'Test message' }],
},
],
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
},
})

const result = await runChain({
workspace,
chain: mockChain as Chain,
apikeys,
source: LogSources.API,
configOverrides: {
output: 'no-schema',
},
})

expect(result.ok).toBe(true)
if (!result.ok) return

const response = await result.value.response
expect(response).toEqual({
documentLogUuid: expect.any(String),
text: 'AI response without schema',
usage: { totalTokens: 10 },
toolCalls: [],
})

expect(aiModule.ai).toHaveBeenCalledWith(
expect.objectContaining({
output: 'no-schema',
}),
)
})
})
59 changes: 38 additions & 21 deletions packages/core/src/services/chains/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ import { streamToGenerator } from '../../lib/streamToGenerator'
import { ai, Config, validateConfig } from '../ai'

export type CachedApiKeys = Map<string, ProviderApiKey>
type ConfigOverrides =
| {
schema: JSONSchema7
output: 'object' | 'array'
}
| { output: 'no-schema' }

export async function runChain({
workspace,
Expand All @@ -34,10 +40,7 @@ export async function runChain({
generateUUID?: () => string
source: LogSources
apikeys: CachedApiKeys
configOverrides?: {
schema: JSONSchema7
output: 'object' | 'array' | 'no-schema'
}
configOverrides?: ConfigOverrides
}) {
const documentLogUuid = generateUUID()

Expand Down Expand Up @@ -96,40 +99,37 @@ async function iterate({
previousApiKey?: ProviderApiKey
documentLogUuid: string
previousResponse?: ChainTextResponse
configOverrides?: {
schema: JSONSchema7
output: 'object' | 'array' | 'no-schema'
}
configOverrides?: ConfigOverrides
}) {
try {
const stepResult = await computeStepData({
const step = await computeStepData({
chain,
previousResponse,
apikeys,
apiKey: previousApiKey,
sentCount: previousCount,
})

publishStepStartEvent(controller, stepResult)
publishStepStartEvent(controller, step)

const aiResult = await ai({
workspace,
source,
documentLogUuid,
messages: stepResult.conversation.messages,
config: stepResult.config,
provider: stepResult.apiKey,
schema: configOverrides?.schema,
output: configOverrides?.output,
transactionalLogs: stepResult.completed,
messages: step.conversation.messages,
config: step.config,
provider: step.apiKey,
schema: getSchemaForAI(step, configOverrides),
output: getOutputForAI(step, configOverrides),
transactionalLogs: step.completed,
})

await streamAIResult(controller, aiResult)

const response = await createChainResponse(aiResult, documentLogUuid)

if (stepResult.completed) {
await handleCompletedChain(controller, stepResult, response)
if (step.completed) {
await handleCompletedChain(controller, step, response)
return response
} else {
publishStepCompleteEvent(controller, response)
Expand All @@ -141,8 +141,8 @@ async function iterate({
documentLogUuid,
apikeys,
controller,
previousApiKey: stepResult.apiKey,
previousCount: stepResult.sentCount,
previousApiKey: step.apiKey,
previousCount: step.sentCount,
previousResponse: response as ChainTextResponse,
configOverrides,
})
Expand All @@ -153,7 +153,24 @@ async function iterate({
}
}

// Helper functions
function getSchemaForAI(
step: Awaited<ReturnType<typeof computeStepData>>,
configOverrides?: ConfigOverrides,
) {
return step.completed
? // @ts-expect-error - schema does not exist in some types of configOverrides which is fine
configOverrides?.schema || step.config.schema
: undefined
}

function getOutputForAI(
step: Awaited<ReturnType<typeof computeStepData>>,
configOverrides?: ConfigOverrides,
) {
return step.completed
? configOverrides?.output || step.config.schema?.type || 'no-schema'
: undefined
}

function publishStepStartEvent(
controller: ReadableStreamDefaultController,
Expand Down

0 comments on commit 8516048

Please sign in to comment.