-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
NovacloudBot
committed
Jul 26, 2024
1 parent
f41b030
commit c9732b0
Showing
5 changed files
with
305 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import { describe, expect, it } from 'vitest' | ||
|
||
import OpenAI from 'openai' | ||
import { z } from 'zod' | ||
import { createChatCompletionFunction } from '../function' | ||
import { completionWithFunctions } from './completion-with-functions' | ||
|
||
describe( | ||
'Completion with Functions', | ||
() => { | ||
it('should run completion with functions', async () => { | ||
const client = new OpenAI() | ||
const res = await completionWithFunctions({ | ||
client, | ||
instructions: `Call test function foo and return it's value in fooResponse`, | ||
prompt: `blah is '123'`, | ||
functions: [ | ||
createChatCompletionFunction({ | ||
name: 'test', | ||
description: 'test function', | ||
parameters: z.object({ blah: z.string() }), | ||
handler: async ({ blah }) => { | ||
return `hello ${blah}` | ||
} | ||
}) | ||
] | ||
}) | ||
|
||
expect(res.content).toEqual('hello 123') | ||
}) | ||
}, | ||
{ concurrent: true } | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import type { ChatCompletionFunction } from '../function' | ||
import OpenAI from 'openai' | ||
import type { z } from 'zod' | ||
|
||
import { | ||
ChatCompletionMessageParam, | ||
ChatCompletionTool | ||
} from 'openai/resources' | ||
import zodToJsonSchema from 'zod-to-json-schema' | ||
import { | ||
ChatCompletionCreateParamsBase, | ||
ChatCompletionMessage, | ||
ChatCompletionMessageToolCall | ||
} from 'openai/resources/chat/completions' | ||
|
||
type CompletionOpts = Partial< | ||
Omit<ChatCompletionCreateParamsBase, 'functions' | 'tools'> | ||
> & { | ||
client: OpenAI | ||
// options?: Partial<OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming> | ||
instructions: string | ||
prompt?: string | ||
messages?: ChatCompletionMessageParam[] | ||
} | ||
|
||
export type CompletionOptsWithFunctionOpts = CompletionOpts & { | ||
functions?: ChatCompletionFunction[] | ||
parallelFunctionExecution?: false | ||
} | ||
|
||
export const functionToOpenAIChatCompletionTool = <T extends z.ZodRawShape>( | ||
fn: ChatCompletionFunction<T> | ||
): ChatCompletionTool => { | ||
const params = fn.parameters ? zodToJsonSchema(fn.parameters) : undefined | ||
return { | ||
type: 'function', | ||
function: { | ||
name: fn.name, | ||
description: fn.description, | ||
parameters: params | ||
} | ||
} | ||
} | ||
|
||
export const completionWithFunctions = async ( | ||
opts: CompletionOptsWithFunctionOpts | ||
): Promise<ChatCompletionMessage> => { | ||
const { | ||
client, | ||
instructions, | ||
prompt, | ||
functions, | ||
parallelFunctionExecution: parallelToolCalls, | ||
messages, | ||
model, | ||
...rest | ||
} = opts | ||
|
||
// initialize messages | ||
const _messages: ChatCompletionMessageParam[] = messages ?? [ | ||
{ role: 'system', content: instructions } | ||
] | ||
if (prompt) { | ||
_messages.push({ role: 'user', content: prompt }) | ||
} | ||
|
||
const response = await client.chat.completions.create({ | ||
Check failure on line 67 in src/completions/completion-with-functions.ts GitHub Actions / build (20.x)src/completions/completion-with-functions.spec.ts > Completion with Functions > should run completion with functions
Check failure on line 67 in src/completions/completion-with-functions.ts GitHub Actions / build (20.x)src/completions/completion-with-json.spec.ts > Completion with JSON response > should run completion with JSON response
|
||
model: model, | ||
messages: _messages, | ||
tools: functions?.map(functionToOpenAIChatCompletionTool), | ||
...rest, | ||
stream: false | ||
}) | ||
|
||
let message = response?.choices?.[0]?.message | ||
|
||
const handleToolCall = async (toolCall: ChatCompletionMessageToolCall) => { | ||
try { | ||
const fn = functions?.find((f) => f.name === toolCall.function.name) | ||
if (!fn) { | ||
throw new Error( | ||
`Function ${toolCall.function.name} not found in functions: [${functions?.map((f) => f.name).join(', ')}]` | ||
) | ||
} | ||
const output = await fn.handler(JSON.parse(toolCall.function.arguments)) | ||
return { | ||
tool_call_id: toolCall.id, | ||
output | ||
} | ||
} catch (e) { | ||
return { | ||
tool_call_id: toolCall.id, | ||
output: `Failed with error: ${e}` | ||
} | ||
} | ||
} | ||
|
||
if (message?.tool_calls) { | ||
let toolCallResults: { | ||
tool_call_id: string | ||
output: string | ||
}[] = [] | ||
if (parallelToolCalls === false) { | ||
for (const toolCall of message?.tool_calls) { | ||
const res = await handleToolCall(toolCall) | ||
toolCallResults.push(res) | ||
} | ||
} else { | ||
toolCallResults = await Promise.all( | ||
message?.tool_calls.map(handleToolCall) | ||
) | ||
} | ||
_messages.push(message) | ||
for (const res of toolCallResults) { | ||
_messages.push({ | ||
tool_call_id: res.tool_call_id, | ||
role: 'tool', | ||
content: res.output | ||
}) | ||
} | ||
return completionWithFunctions({ | ||
...opts, | ||
messages: _messages, | ||
prompt: undefined | ||
}) | ||
} | ||
|
||
if (message) { | ||
return message | ||
} else { | ||
throw new Error('Invalid response (empty message)') | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import { describe, expect, it } from 'vitest' | ||
|
||
import OpenAI from 'openai' | ||
import { z } from 'zod' | ||
import { completionWithJsonResponse } from './completion-with-json' | ||
import { createChatCompletionFunction } from '../function' | ||
|
||
describe( | ||
'Completion with JSON response', | ||
() => { | ||
it('should run completion with JSON response', async () => { | ||
const client = new OpenAI() | ||
const res = await completionWithJsonResponse({ | ||
client, | ||
instructions: `Call test function foo and return it's value in fooResponse`, | ||
prompt: `blah is '123'`, | ||
responseObject: z.object({ | ||
fooResponse: z.string() | ||
}), | ||
functions: [ | ||
createChatCompletionFunction({ | ||
name: 'test', | ||
description: 'test function', | ||
parameters: z.object({ blah: z.string() }), | ||
handler: async ({ blah }) => { | ||
return `hello ${blah}` | ||
} | ||
}) | ||
] | ||
}) | ||
expect(res.fooResponse).toEqual('hello 123') | ||
}) | ||
}, | ||
{ concurrent: true } | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import type { z } from 'zod' | ||
import type { ChatCompletionFunction } from '../function' | ||
|
||
import { ChatCompletionTool } from 'openai/resources' | ||
import zodToJsonSchema from 'zod-to-json-schema' | ||
import { | ||
CompletionOptsWithFunctionOpts, | ||
completionWithFunctions | ||
} from './completion-with-functions' | ||
|
||
export type CompletionOptsWithJsonResponse<T extends z.ZodRawShape> = | ||
CompletionOptsWithFunctionOpts & { | ||
responseObject: z.ZodObject<T> | ||
} | ||
|
||
export const functionToOpenAIChatCompletionTool = <T extends z.ZodRawShape>( | ||
fn: ChatCompletionFunction<T> | ||
): ChatCompletionTool => { | ||
const params = fn.parameters ? zodToJsonSchema(fn.parameters) : undefined | ||
return { | ||
type: 'function', | ||
function: { | ||
name: fn.name, | ||
description: fn.description, | ||
parameters: params | ||
} | ||
} | ||
} | ||
|
||
export const completionWithJsonResponse = async <T extends z.ZodRawShape>( | ||
opts: CompletionOptsWithJsonResponse<T> | ||
): Promise<z.infer<z.ZodObject<T>>> => { | ||
const { responseObject, prompt, ...rest } = opts | ||
const responseObjectSchema = JSON.stringify( | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
zodToJsonSchema(responseObject) | ||
) | ||
|
||
const _prompt = `Output JSON must be single object (only one JSON object) conforming to the following JsonSchema7:\n${responseObjectSchema}\n\n${prompt ? `${prompt}\n\n` : ''}\n` | ||
const res = await completionWithFunctions({ | ||
...rest, | ||
response_format: { type: 'json_object' }, | ||
prompt: _prompt | ||
}) | ||
|
||
if (!res.content) { | ||
throw new Error('Invalid response (null)') | ||
} | ||
|
||
try { | ||
const content = res.content.replace(/^```json\n/, '').replace(/```$/, '') | ||
let parsedContent = JSON.parse(content) | ||
if (parsedContent.$schema && parsedContent.properties) { | ||
parsedContent = parsedContent.properties | ||
} | ||
const parsed = responseObject.parse(parsedContent) | ||
return parsed | ||
} catch (err) { | ||
throw new Error(`Failed to parse response: ${err}, json: '${res.content}'`) | ||
} | ||
} | ||
|
||
export const completionWithJsonResponseWithRetry = async < | ||
T extends z.ZodRawShape | ||
>( | ||
props: CompletionOptsWithJsonResponse<T>, | ||
retryCount = 2 | ||
): Promise<z.infer<z.ZodObject<T>>> => { | ||
let latestErr: Error | undefined | ||
try { | ||
return await completionWithJsonResponse(props) | ||
} catch (err) { | ||
latestErr = err as Error | ||
if (retryCount <= 0) { | ||
return await completionWithJsonResponseWithRetry( | ||
{ | ||
...props, | ||
response_format: { type: 'json_object' }, | ||
messages: [ | ||
...(props.messages ?? []), | ||
{ | ||
role: 'user', | ||
content: [ | ||
{ | ||
type: 'text', | ||
text: `Your latest reply contains following error:\n\`${err}\`` | ||
} | ||
] | ||
} | ||
] | ||
}, | ||
retryCount - 1 | ||
) | ||
} | ||
} | ||
throw new Error(`Max retries reached. Last error: ${latestErr}`) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,13 @@ | ||
export { completionWithFunctions } from './completions/completion-with-functions' | ||
export { | ||
completionWithJsonResponse, | ||
completionWithJsonResponseWithRetry | ||
} from './completions/completion-with-json' | ||
|
||
export { Assistant, AssistantOpts } from './assistant' | ||
export { promptWithPick } from './chains/prompt-with-pick' | ||
export { promptWithRetry } from './chains/prompt-with-retry' | ||
export { createChatCompletionFunction } from './function' | ||
export { createOpenAIClient, getDefaultOpenAIClient } from './openai-client' | ||
|
||
export { Thread, ThreadPromptWithFunctionOpts } from './thread' |