feature: only add json schema to chain last step #218

Merged · 1 commit · Sep 19, 2024
4 changes: 2 additions & 2 deletions packages/core/src/services/ai/index.ts
@@ -50,8 +50,8 @@ export async function ai({
documentLogUuid,
source,
onFinish,
schema = config.schema,
output = config.schema?.type || 'no-schema',
schema,
output,
transactionalLogs = false,
}: {
workspace: Workspace
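
The two removed defaults are the heart of the PR: ai() previously inferred structured-output settings from the prompt config on every call, so every step of a multi-step chain was pushed through schema validation. A minimal, self-contained sketch of the behavioral difference (oldDefaults, newDefaults, and the PromptConfig shape are illustrative, not the library's API):

type Output = 'object' | 'array' | 'no-schema'
interface PromptConfig {
  schema?: { type: 'object' | 'array' }
}

// Old defaulting (sketch): silence from the caller meant "use the prompt's schema".
function oldDefaults(
  config: PromptConfig,
  schema = config.schema,
  output: Output = config.schema?.type || 'no-schema',
) {
  return { schema, output }
}

// New behavior (sketch): nothing is inferred; undefined stays undefined and the
// chain runner in run.ts below decides per step.
function newDefaults(config: PromptConfig, schema?: PromptConfig['schema'], output?: Output) {
  return { schema, output }
}

const config: PromptConfig = { schema: { type: 'object' } }
console.log(oldDefaults(config)) // { schema: { type: 'object' }, output: 'object' }
console.log(newDefaults(config)) // { schema: undefined, output: undefined }
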
198 changes: 197 additions & 1 deletion packages/core/src/services/chains/run.test.ts
@@ -87,7 +87,7 @@ describe('runChain', () => {
expect.objectContaining({
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
schema: undefined,
output: undefined,
output: 'no-schema',
}),
)
})
@@ -299,4 +299,200 @@ describe('runChain', () => {
}),
)
})

it('runs a chain with object schema and output', async () => {
const mockSchema = {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
},
} as const

const mockAiResponse = {
object: Promise.resolve({ name: 'John', age: 30 }),
usage: Promise.resolve({ totalTokens: 15 }),
fullStream: new ReadableStream({
start(controller) {
controller.enqueue({
type: 'object',
object: { name: 'John', age: 30 },
})
controller.close()
},
}),
}

vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any)

vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
conversation: {
messages: [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: 'Test message' }],
},
],
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
},
})

const result = await runChain({
workspace,
chain: mockChain as Chain,
apikeys,
source: LogSources.API,
configOverrides: {
schema: mockSchema,
output: 'object',
},
})

expect(result.ok).toBe(true)
if (!result.ok) return

const response = await result.value.response
expect(response).toEqual({
documentLogUuid: expect.any(String),
object: { name: 'John', age: 30 },
text: '{"name":"John","age":30}',
usage: { totalTokens: 15 },
})

expect(aiModule.ai).toHaveBeenCalledWith(
expect.objectContaining({
schema: mockSchema,
output: 'object',
}),
)
})

it('runs a chain with array schema and output', async () => {
const mockSchema = {
type: 'array',
items: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
},
},
} as const

const mockAiResponse = {
object: Promise.resolve([
{ name: 'John', age: 30 },
{ name: 'Jane', age: 25 },
]),
usage: Promise.resolve({ totalTokens: 20 }),
fullStream: new ReadableStream({
start(controller) {
controller.enqueue({
type: 'object',
object: [
{ name: 'John', age: 30 },
{ name: 'Jane', age: 25 },
],
})
controller.close()
},
}),
}

vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any)

vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
conversation: {
messages: [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: 'Test message' }],
},
],
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
},
})

const result = await runChain({
workspace,
chain: mockChain as Chain,
apikeys,
source: LogSources.API,
configOverrides: {
schema: mockSchema,
output: 'array',
},
})

expect(result.ok).toBe(true)
if (!result.ok) return

const response = await result.value.response
expect(response).toEqual({
documentLogUuid: expect.any(String),
object: [
{ name: 'John', age: 30 },
{ name: 'Jane', age: 25 },
],
text: '[{"name":"John","age":30},{"name":"Jane","age":25}]',
usage: { totalTokens: 20 },
})

expect(aiModule.ai).toHaveBeenCalledWith(
expect.objectContaining({
schema: mockSchema,
output: 'array',
}),
)
})

it('runs a chain with no-schema output', async () => {
const mockAiResponse = createMockAiResponse(
'AI response without schema',
10,
)
vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any)

vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
conversation: {
messages: [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: 'Test message' }],
},
],
config: { provider: 'openai', model: 'gpt-3.5-turbo' },
},
})

const result = await runChain({
workspace,
chain: mockChain as Chain,
apikeys,
source: LogSources.API,
configOverrides: {
output: 'no-schema',
},
})

expect(result.ok).toBe(true)
if (!result.ok) return

const response = await result.value.response
expect(response).toEqual({
documentLogUuid: expect.any(String),
text: 'AI response without schema',
usage: { totalTokens: 10 },
toolCalls: [],
})

expect(aiModule.ai).toHaveBeenCalledWith(
expect.objectContaining({
output: 'no-schema',
}),
)
})
})
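
These three new tests map onto the three shapes of the ConfigOverrides union introduced in run.ts below: object output with a schema, array output with a schema, and plain 'no-schema'. Roughly, a caller opts into structured output like this; the snippet assumes the same imports and fixtures the test file already sets up:

const result = await runChain({
  workspace,
  chain,
  apikeys,
  source: LogSources.API,
  configOverrides: {
    output: 'object',
    schema: { type: 'object', properties: { name: { type: 'string' } } },
  },
})

if (result.ok) {
  const response = await result.value.response
  // With 'object' or 'array' output the response carries the parsed `object`
  // plus its serialized `text`; with 'no-schema' it carries plain `text` and
  // `toolCalls`, exactly as the assertions above expect.
}
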
59 changes: 38 additions & 21 deletions packages/core/src/services/chains/run.ts
@@ -20,6 +20,12 @@ import { streamToGenerator } from '../../lib/streamToGenerator'
import { ai, Config, validateConfig } from '../ai'

export type CachedApiKeys = Map<string, ProviderApiKey>
type ConfigOverrides =
| {
schema: JSONSchema7
output: 'object' | 'array'
}
| { output: 'no-schema' }

export async function runChain({
workspace,
@@ -34,10 +40,7 @@ export async function runChain({
generateUUID?: () => string
source: LogSources
apikeys: CachedApiKeys
configOverrides?: {
schema: JSONSchema7
output: 'object' | 'array' | 'no-schema'
}
configOverrides?: ConfigOverrides
}) {
const documentLogUuid = generateUUID()

@@ -96,40 +99,37 @@ async function iterate({
previousApiKey?: ProviderApiKey
documentLogUuid: string
previousResponse?: ChainTextResponse
configOverrides?: {
schema: JSONSchema7
output: 'object' | 'array' | 'no-schema'
}
configOverrides?: ConfigOverrides
}) {
try {
const stepResult = await computeStepData({
const step = await computeStepData({
chain,
previousResponse,
apikeys,
apiKey: previousApiKey,
sentCount: previousCount,
})

publishStepStartEvent(controller, stepResult)
publishStepStartEvent(controller, step)

const aiResult = await ai({
workspace,
source,
documentLogUuid,
messages: stepResult.conversation.messages,
config: stepResult.config,
provider: stepResult.apiKey,
schema: configOverrides?.schema,
output: configOverrides?.output,
transactionalLogs: stepResult.completed,
messages: step.conversation.messages,
config: step.config,
provider: step.apiKey,
schema: getSchemaForAI(step, configOverrides),
output: getOutputForAI(step, configOverrides),
transactionalLogs: step.completed,
})

await streamAIResult(controller, aiResult)

const response = await createChainResponse(aiResult, documentLogUuid)

if (stepResult.completed) {
await handleCompletedChain(controller, stepResult, response)
if (step.completed) {
await handleCompletedChain(controller, step, response)
return response
} else {
publishStepCompleteEvent(controller, response)
@@ -141,8 +141,8 @@
documentLogUuid,
apikeys,
controller,
previousApiKey: stepResult.apiKey,
previousCount: stepResult.sentCount,
previousApiKey: step.apiKey,
previousCount: step.sentCount,
previousResponse: response as ChainTextResponse,
configOverrides,
})
@@ -153,7 +153,24 @@
}
}

// Helper functions
function getSchemaForAI(
step: Awaited<ReturnType<typeof computeStepData>>,
configOverrides?: ConfigOverrides,
) {
return step.completed
? // @ts-expect-error - schema does not exist in some types of configOverrides which is fine
configOverrides?.schema || step.config.schema
: undefined
}

function getOutputForAI(
step: Awaited<ReturnType<typeof computeStepData>>,
configOverrides?: ConfigOverrides,
) {
return step.completed
? configOverrides?.output || step.config.schema?.type || 'no-schema'
: undefined
}

function publishStepStartEvent(
controller: ReadableStreamDefaultController,
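
Taken together, the ConfigOverrides union and the two helpers above implement what the PR title promises: a JSON schema and output mode are forwarded to ai() only on the step that completes the chain, with an explicit override taking precedence over the prompt's own schema. A self-contained sketch of that precedence (resolveOutput is an illustrative stand-in for getOutputForAI, not code from the repository):

type Output = 'object' | 'array' | 'no-schema'

function resolveOutput(
  completed: boolean,
  override?: Output,
  configSchemaType?: 'object' | 'array',
): Output | undefined {
  // Intermediate steps never receive structured-output constraints.
  if (!completed) return undefined
  // Final step: explicit override, then the prompt's schema type, then plain text.
  return override || configSchemaType || 'no-schema'
}

console.log(resolveOutput(false, 'object', 'object')) // undefined — not the last step
console.log(resolveOutput(true))                      // 'no-schema'
console.log(resolveOutput(true, undefined, 'array'))  // 'array' — falls back to the prompt config
console.log(resolveOutput(true, 'object', 'array'))   // 'object' — override wins

getSchemaForAI follows the same gate: on intermediate steps it returns undefined, and on the final step it prefers the override's schema over step.config.schema.
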