diff --git a/packages/core/src/services/ai/index.ts b/packages/core/src/services/ai/index.ts index 1388ef470..319fd6f0c 100644 --- a/packages/core/src/services/ai/index.ts +++ b/packages/core/src/services/ai/index.ts @@ -50,8 +50,8 @@ export async function ai({ documentLogUuid, source, onFinish, - schema = config.schema, - output = config.schema?.type || 'no-schema', + schema, + output, transactionalLogs = false, }: { workspace: Workspace diff --git a/packages/core/src/services/chains/run.test.ts b/packages/core/src/services/chains/run.test.ts index 2ede8bfc4..c22de6a25 100644 --- a/packages/core/src/services/chains/run.test.ts +++ b/packages/core/src/services/chains/run.test.ts @@ -87,7 +87,7 @@ describe('runChain', () => { expect.objectContaining({ config: { provider: 'openai', model: 'gpt-3.5-turbo' }, schema: undefined, - output: undefined, + output: 'no-schema', }), ) }) @@ -299,4 +299,200 @@ describe('runChain', () => { }), ) }) + + it('runs a chain with object schema and output', async () => { + const mockSchema = { + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' }, + }, + } as const + + const mockAiResponse = { + object: Promise.resolve({ name: 'John', age: 30 }), + usage: Promise.resolve({ totalTokens: 15 }), + fullStream: new ReadableStream({ + start(controller) { + controller.enqueue({ + type: 'object', + object: { name: 'John', age: 30 }, + }) + controller.close() + }, + }), + } + + vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any) + + vi.mocked(mockChain.step!).mockResolvedValue({ + completed: true, + conversation: { + messages: [ + { + role: MessageRole.user, + content: [{ type: ContentType.text, text: 'Test message' }], + }, + ], + config: { provider: 'openai', model: 'gpt-3.5-turbo' }, + }, + }) + + const result = await runChain({ + workspace, + chain: mockChain as Chain, + apikeys, + source: LogSources.API, + configOverrides: { + schema: mockSchema, + output: 'object', + }, + }) + + 
expect(result.ok).toBe(true) + if (!result.ok) return + + const response = await result.value.response + expect(response).toEqual({ + documentLogUuid: expect.any(String), + object: { name: 'John', age: 30 }, + text: '{"name":"John","age":30}', + usage: { totalTokens: 15 }, + }) + + expect(aiModule.ai).toHaveBeenCalledWith( + expect.objectContaining({ + schema: mockSchema, + output: 'object', + }), + ) + }) + + it('runs a chain with array schema and output', async () => { + const mockSchema = { + type: 'array', + items: { + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' }, + }, + }, + } as const + + const mockAiResponse = { + object: Promise.resolve([ + { name: 'John', age: 30 }, + { name: 'Jane', age: 25 }, + ]), + usage: Promise.resolve({ totalTokens: 20 }), + fullStream: new ReadableStream({ + start(controller) { + controller.enqueue({ + type: 'object', + object: [ + { name: 'John', age: 30 }, + { name: 'Jane', age: 25 }, + ], + }) + controller.close() + }, + }), + } + + vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any) + + vi.mocked(mockChain.step!).mockResolvedValue({ + completed: true, + conversation: { + messages: [ + { + role: MessageRole.user, + content: [{ type: ContentType.text, text: 'Test message' }], + }, + ], + config: { provider: 'openai', model: 'gpt-3.5-turbo' }, + }, + }) + + const result = await runChain({ + workspace, + chain: mockChain as Chain, + apikeys, + source: LogSources.API, + configOverrides: { + schema: mockSchema, + output: 'array', + }, + }) + + expect(result.ok).toBe(true) + if (!result.ok) return + + const response = await result.value.response + expect(response).toEqual({ + documentLogUuid: expect.any(String), + object: [ + { name: 'John', age: 30 }, + { name: 'Jane', age: 25 }, + ], + text: '[{"name":"John","age":30},{"name":"Jane","age":25}]', + usage: { totalTokens: 20 }, + }) + + expect(aiModule.ai).toHaveBeenCalledWith( + expect.objectContaining({ + schema: mockSchema, + 
output: 'array', + }), + ) + }) + + it('runs a chain with no-schema output', async () => { + const mockAiResponse = createMockAiResponse( + 'AI response without schema', + 10, + ) + vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any) + + vi.mocked(mockChain.step!).mockResolvedValue({ + completed: true, + conversation: { + messages: [ + { + role: MessageRole.user, + content: [{ type: ContentType.text, text: 'Test message' }], + }, + ], + config: { provider: 'openai', model: 'gpt-3.5-turbo' }, + }, + }) + + const result = await runChain({ + workspace, + chain: mockChain as Chain, + apikeys, + source: LogSources.API, + configOverrides: { + output: 'no-schema', + }, + }) + + expect(result.ok).toBe(true) + if (!result.ok) return + + const response = await result.value.response + expect(response).toEqual({ + documentLogUuid: expect.any(String), + text: 'AI response without schema', + usage: { totalTokens: 10 }, + toolCalls: [], + }) + + expect(aiModule.ai).toHaveBeenCalledWith( + expect.objectContaining({ + output: 'no-schema', + }), + ) + }) }) diff --git a/packages/core/src/services/chains/run.ts b/packages/core/src/services/chains/run.ts index 3ec3a9fa0..ff9577438 100644 --- a/packages/core/src/services/chains/run.ts +++ b/packages/core/src/services/chains/run.ts @@ -20,6 +20,12 @@ import { streamToGenerator } from '../../lib/streamToGenerator' import { ai, Config, validateConfig } from '../ai' export type CachedApiKeys = Map<string, ProviderApiKey> +type ConfigOverrides = + | { + schema: JSONSchema7 + output: 'object' | 'array' + } + | { output: 'no-schema' } export async function runChain({ workspace, @@ -34,10 +40,7 @@ export async function runChain({ generateUUID?: () => string source: LogSources apikeys: CachedApiKeys - configOverrides?: { - schema: JSONSchema7 - output: 'object' | 'array' | 'no-schema' - } + configOverrides?: ConfigOverrides }) { const documentLogUuid = generateUUID() @@ -96,13 +99,10 @@ async function iterate({ previousApiKey?: ProviderApiKey
documentLogUuid: string previousResponse?: ChainTextResponse - configOverrides?: { - schema: JSONSchema7 - output: 'object' | 'array' | 'no-schema' - } + configOverrides?: ConfigOverrides }) { try { - const stepResult = await computeStepData({ + const step = await computeStepData({ chain, previousResponse, apikeys, @@ -110,26 +110,26 @@ async function iterate({ sentCount: previousCount, }) - publishStepStartEvent(controller, stepResult) + publishStepStartEvent(controller, step) const aiResult = await ai({ workspace, source, documentLogUuid, - messages: stepResult.conversation.messages, - config: stepResult.config, - provider: stepResult.apiKey, - schema: configOverrides?.schema, - output: configOverrides?.output, - transactionalLogs: stepResult.completed, + messages: step.conversation.messages, + config: step.config, + provider: step.apiKey, + schema: getSchemaForAI(step, configOverrides), + output: getOutputForAI(step, configOverrides), + transactionalLogs: step.completed, }) await streamAIResult(controller, aiResult) const response = await createChainResponse(aiResult, documentLogUuid) - if (stepResult.completed) { - await handleCompletedChain(controller, stepResult, response) + if (step.completed) { + await handleCompletedChain(controller, step, response) return response } else { publishStepCompleteEvent(controller, response) @@ -141,8 +141,8 @@ async function iterate({ documentLogUuid, apikeys, controller, - previousApiKey: stepResult.apiKey, - previousCount: stepResult.sentCount, + previousApiKey: step.apiKey, + previousCount: step.sentCount, previousResponse: response as ChainTextResponse, configOverrides, }) @@ -153,7 +153,24 @@ async function iterate({ } } -// Helper functions +function getSchemaForAI( + step: Awaited<ReturnType<typeof computeStepData>>, + configOverrides?: ConfigOverrides, +) { + return step.completed + ? 
// @ts-expect-error - schema does not exist in some types of configOverrides which is fine + configOverrides?.schema || step.config.schema + : undefined +} + +function getOutputForAI( + step: Awaited<ReturnType<typeof computeStepData>>, + configOverrides?: ConfigOverrides, +) { + return step.completed + ? configOverrides?.output || step.config.schema?.type || 'no-schema' + : undefined +} function publishStepStartEvent( controller: ReadableStreamDefaultController,