Skip to content

Commit

Permalink
Support GPT parallel function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Jan 24, 2024
1 parent 4055482 commit f1b90b1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 22 deletions.
12 changes: 4 additions & 8 deletions packages/app/src/components/nodes/ChatNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const ChatNodeOutput: FC<{
const costAll = coerceTypeOptional(outputs['cost' as PortId], 'number[]') ?? [];
const durationAll = coerceTypeOptional(outputs['duration' as PortId], 'number[]') ?? [];

const functionCallOutput = outputs['function-call' as PortId];
const functionCallOutput = outputs['function-call' as PortId] ?? outputs['function-calls' as PortId];
const functionCallAll =
functionCallOutput?.type === 'object[]'
? functionCallOutput.value
Expand Down Expand Up @@ -79,19 +79,15 @@ export const ChatNodeOutput: FC<{
const cost = coerceTypeOptional(outputs['cost' as PortId], 'number');
const duration = coerceTypeOptional(outputs['duration' as PortId], 'number');

const functionCallOutput = outputs['function-call' as PortId];
const functionCall =
functionCallOutput?.type === 'object'
? functionCallOutput.value
: coerceTypeOptional(functionCallOutput, 'string');
const functionCallOutput = outputs['function-call' as PortId] ?? outputs['function-calls' as PortId];

return (
<ChatNodeOutputSingle
outputText={outputText}
requestTokens={requestTokens}
responseTokens={responseTokens}
cost={cost}
functionCall={functionCall}
functionCall={functionCallOutput?.value as object}
duration={duration}
fullscreen={fullscreen}
renderMarkdown={renderMarkdown}
Expand Down Expand Up @@ -182,7 +178,7 @@ export const ChatNodeOutputSingle: FC<{
</div>
{functionCall && (
<div className="function-call">
<h4>Function Call:</h4>
<h4>{Array.isArray(functionCall) ? 'Function Calls' : 'Function Call'}:</h4>
<div className="pre-wrap">
<RenderDataValue value={inferType(functionCall)} />
</div>
Expand Down
58 changes: 44 additions & 14 deletions packages/core/src/model/nodes/ChatNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export type ChatNodeConfigData = {
toolChoice?: 'none' | 'auto' | 'function';
toolChoiceFunction?: string;
responseFormat?: 'text' | 'json';
parallelFunctionCalling?: boolean;
};

export type ChatNodeData = ChatNodeConfigData & {
Expand Down Expand Up @@ -128,6 +129,8 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {

cache: false,
useAsGraphPartialOutput: true,

parallelFunctionCalling: true,
},
};

Expand Down Expand Up @@ -342,12 +345,21 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
}

if (this.data.enableFunctionUse) {
outputs.push({
dataType: 'object',
id: 'function-call' as PortId,
title: 'Function Call',
description: 'The function call that was made, if any.',
});
if (this.data.parallelFunctionCalling) {
outputs.push({
dataType: 'object[]',
id: 'function-calls' as PortId,
title: 'Function Calls',
description: 'The function calls that were made, if any.',
});
} else {
outputs.push({
dataType: 'object',
id: 'function-call' as PortId,
title: 'Function Call',
description: 'The function call that was made, if any.',
});
}
}

outputs.push({
Expand Down Expand Up @@ -515,6 +527,12 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
label: 'Enable Function Use',
dataKey: 'enableFunctionUse',
},
{
type: 'toggle',
label: 'Enable Parallel Function Calling',
dataKey: 'parallelFunctionCalling',
hideIf: (data) => !data.enableFunctionUse,
},
{
type: 'dropdown',
label: 'Tool Choice',
Expand Down Expand Up @@ -897,14 +915,26 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
})),
};
} else {
output['function-call' as PortId] = {
type: 'object',
value: {
name: functionCalls[0]![0]?.name,
arguments: functionCalls[0]![0]?.lastParsedArguments,
id: functionCalls[0]![0]?.id,
} as Record<string, unknown>,
};
if (this.data.parallelFunctionCalling) {
console.dir({ functionCalls });
output['function-calls' as PortId] = {
type: 'object[]',
value: functionCalls[0]!.map((functionCall) => ({
name: functionCall.name,
arguments: functionCall.lastParsedArguments,
id: functionCall.id,
})),
};
} else {
output['function-call' as PortId] = {
type: 'object',
value: {
name: functionCalls[0]![0]?.name,
arguments: functionCalls[0]![0]?.lastParsedArguments,
id: functionCalls[0]![0]?.id,
} as Record<string, unknown>,
};
}
}
}

Expand Down

0 comments on commit f1b90b1

Please sign in to comment.