Skip to content

Commit

Permalink
[8.x] [inference] Add simulated function calling (#192544) (#193275)
Browse files Browse the repository at this point in the history
# Backport

This will backport the following commits from `main` to `8.x`:
- [[inference] Add simulated function calling
(#192544)](#192544)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Pierre
Gayvallet","email":"[email protected]"},"sourceCommit":{"committedDate":"2024-09-18T10:42:28Z","message":"[inference]
Add simulated function calling (#192544)\n\n## Summary\r\n\r\nAdd
simulated function calling to the inference plugin. For now, only\r\nthe
openAI adapter is supported. This is done by adding a new,
optional\r\n`functionCalling` parameter to the chat and task
APIs\r\n\r\nImplementation was adapted from the equivalent feature in
the o11y\r\nassistant.\r\n\r\n---------\r\n\r\nCo-authored-by: Dario
Gieselaar
<[email protected]>","sha":"181d61723136084ec57801fab1cc99457c047977","branchLabelMapping":{"^v9.0.0$":"main","^v8.16.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","backport:prev-minor","Team:Obs
AI Assistant","ci:project-deploy-observability","v8.16.0","Team:AI
Infra"],"title":"[inference] Add simulated function calling
","number":192544,"url":"https://github.com/elastic/kibana/pull/192544","mergeCommit":{"message":"[inference]
Add simulated function calling (#192544)\n\n## Summary\r\n\r\nAdd
simulated function calling to the inference plugin. For now, only\r\nthe
openAI adapter is supported. This is done by adding a new,
optional\r\n`functionCalling` parameter to the chat and task
APIs\r\n\r\nImplementation was adapted from the equivalent feature in
the o11y\r\nassistant.\r\n\r\n---------\r\n\r\nCo-authored-by: Dario
Gieselaar
<[email protected]>","sha":"181d61723136084ec57801fab1cc99457c047977"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/192544","number":192544,"mergeCommit":{"message":"[inference]
Add simulated function calling (#192544)\n\n## Summary\r\n\r\nAdd
simulated function calling to the inference plugin. For now, only\r\nthe
openAI adapter is supported. This is done by adding a new,
optional\r\n`functionCalling` parameter to the chat and task
APIs\r\n\r\nImplementation was adapted from the equivalent feature in
the o11y\r\nassistant.\r\n\r\n---------\r\n\r\nCo-authored-by: Dario
Gieselaar
<[email protected]>","sha":"181d61723136084ec57801fab1cc99457c047977"}},{"branch":"8.x","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Pierre Gayvallet <[email protected]>
  • Loading branch information
kibanamachine and pgayvallet authored Sep 18, 2024
1 parent 44abd1a commit c8043c9
Show file tree
Hide file tree
Showing 32 changed files with 472 additions and 33 deletions.
3 changes: 3 additions & 0 deletions x-pack/plugins/inference/common/chat_complete/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ export type ChatCompletionEvent<TToolOptions extends ToolOptions = ToolOptions>
| ChatCompletionTokenCountEvent
| ChatCompletionMessageEvent<TToolOptions>;

export type FunctionCallingMode = 'native' | 'simulated';

/**
* Request a completion from the LLM based on a prompt or conversation.
*
Expand All @@ -92,5 +94,6 @@ export type ChatCompleteAPI = <TToolOptions extends ToolOptions = ToolOptions>(
connectorId: string;
system?: string;
messages: Message[];
functionCalling?: FunctionCallingMode;
} & TToolOptions
) => ChatCompletionResponse<TToolOptions>;
5 changes: 3 additions & 2 deletions x-pack/plugins/inference/common/chat_complete/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
* 2.0.
*/

import type { Message } from '.';
import { ToolOptions } from './tools';
import type { Message, FunctionCallingMode } from '.';
import type { ToolOptions } from './tools';

export type ChatCompleteRequestBody = {
connectorId: string;
stream?: boolean;
system?: string;
messages: Message[];
functionCalling?: FunctionCallingMode;
} & ToolOptions;
7 changes: 4 additions & 3 deletions x-pack/plugins/inference/common/output/create_output_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import { OutputAPI, OutputEvent, OutputEventType } from '.';
import { ensureMultiTurn } from '../ensure_multi_turn';

export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
return (id, { connectorId, input, schema, system, previousMessages }) => {
return (id, { connectorId, input, schema, system, previousMessages, functionCalling }) => {
return chatCompleteApi({
connectorId,
system,
functionCalling,
messages: ensureMultiTurn([
...(previousMessages || []),
{
Expand All @@ -26,12 +27,12 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI {
...(schema
? {
tools: {
output: {
structuredOutput: {
description: `Use the following schema to respond to the user's request in structured data, so it can be parsed and handled.`,
schema,
},
},
toolChoice: { function: 'output' as const },
toolChoice: { function: 'structuredOutput' as const },
}
: {}),
}).pipe(
Expand Down
3 changes: 2 additions & 1 deletion x-pack/plugins/inference/common/output/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import { Observable } from 'rxjs';
import { ServerSentEventBase } from '@kbn/sse-utils';
import { FromToolSchema, ToolSchema } from '../chat_complete/tool_schema';
import { Message } from '../chat_complete';
import type { Message, FunctionCallingMode } from '../chat_complete';

export enum OutputEventType {
OutputUpdate = 'output',
Expand Down Expand Up @@ -61,6 +61,7 @@ export type OutputAPI = <
input: string;
schema?: TOutputSchema;
previousMessages?: Message[];
functionCalling?: FunctionCallingMode;
}
) => Observable<
OutputEvent<TId, TOutputSchema extends ToolSchema ? FromToolSchema<TOutputSchema> : undefined>
Expand Down
3 changes: 2 additions & 1 deletion x-pack/plugins/inference/public/chat_complete/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ import type { ChatCompleteRequestBody } from '../../common/chat_complete/request
import { httpResponseIntoObservable } from '../util/http_response_into_observable';

export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI {
return ({ connectorId, messages, system, toolChoice, tools }) => {
return ({ connectorId, messages, system, toolChoice, tools, functionCalling }) => {
const body: ChatCompleteRequestBody = {
connectorId,
system,
messages,
toolChoice,
tools,
functionCalling,
};

return from(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import type {
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
} from 'openai/resources';
import { filter, from, map, switchMap, tap, throwError } from 'rxjs';
import { filter, from, map, switchMap, tap, throwError, identity } from 'rxjs';
import { Readable, isReadable } from 'stream';
import {
ChatCompletionChunkEvent,
Expand All @@ -26,18 +26,38 @@ import { createTokenLimitReachedError } from '../../../../common/chat_complete/e
import { createInferenceInternalError } from '../../../../common/errors';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import type { InferenceConnectorAdapter } from '../../types';
import {
wrapWithSimulatedFunctionCalling,
parseInlineFunctionCalls,
} from '../../simulated_function_calling';

export const openAIAdapter: InferenceConnectorAdapter = {
chatComplete: ({ executor, system, messages, toolChoice, tools }) => {
chatComplete: ({ executor, system, messages, toolChoice, tools, functionCalling, logger }) => {
const stream = true;
const simulatedFunctionCalling = functionCalling === 'simulated';

const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
stream,
messages: messagesToOpenAI({ system, messages }),
tool_choice: toolChoiceToOpenAI(toolChoice),
tools: toolsToOpenAI(tools),
temperature: 0,
};
let request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string };
if (simulatedFunctionCalling) {
const wrapped = wrapWithSimulatedFunctionCalling({
system,
messages,
toolChoice,
tools,
});
request = {
stream,
messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }),
temperature: 0,
};
} else {
request = {
stream,
messages: messagesToOpenAI({ system, messages }),
tool_choice: toolChoiceToOpenAI(toolChoice),
tools: toolsToOpenAI(tools),
temperature: 0,
};
}

return from(
executor.invoke({
Expand Down Expand Up @@ -94,7 +114,8 @@ export const openAIAdapter: InferenceConnectorAdapter = {
};
}) ?? [],
};
})
}),
simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
);
},
};
Expand Down
2 changes: 2 additions & 0 deletions x-pack/plugins/inference/server/chat_complete/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export function createChatCompleteApi({
toolChoice,
tools,
system,
functionCalling,
}): ChatCompletionResponse => {
return defer(async () => {
const actionsClient = await actions.getActionsClientWithRequest(request);
Expand Down Expand Up @@ -58,6 +59,7 @@ export function createChatCompleteApi({
toolChoice,
tools,
logger,
functionCalling,
});
}),
chunksIntoMessage({
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export const TOOL_USE_START = '<|tool_use_start|>';
export const TOOL_USE_END = '<|tool_use_end|>';
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { TOOL_USE_END, TOOL_USE_START } from './constants';
import { ToolDefinition } from '../../../common/chat_complete/tools';

export function getSystemMessageInstructions({
tools,
}: {
tools?: Record<string, ToolDefinition>;
}) {
const formattedTools = Object.entries(tools ?? {}).map(([name, tool]) => {
return {
name,
...tool,
};
});

if (formattedTools.length) {
return `In this environment, you have access to a set of tools you can use to answer the user's question.
DO NOT call a tool when it is not listed.
ONLY define input that is defined in the tool properties.
If a tool does not have properties, leave them out.
It is EXTREMELY important that you generate valid JSON between the \`\`\`json and \`\`\` delimiters.
You may call them like this.
Given the following tool:
${JSON.stringify({
name: 'my_tool',
description: 'A tool to call',
schema: {
type: 'object',
properties: {
myProperty: {
type: 'string',
},
},
},
})}
Use it the following way:
${TOOL_USE_START}
\`\`\`json
${JSON.stringify({ name: 'my_tool', input: { myProperty: 'myValue' } })}
\`\`\`\
${TOOL_USE_END}
Given the following tool:
${JSON.stringify({
name: 'my_tool_without_parameters',
description: 'A tool to call without parameters',
})}
Use it the following way:
${TOOL_USE_START}
\`\`\`json
${JSON.stringify({ name: 'my_tool_without_parameters', input: {} })}
\`\`\`\
${TOOL_USE_END}
Here are the tools available:
${JSON.stringify(
formattedTools.map((tool) => ({
name: tool.name,
description: tool.description,
...(tool.schema ? { schema: tool.schema } : {}),
}))
)}
`;
}

return `No tools are available anymore. DO NOT UNDER ANY CIRCUMSTANCES call any tool, regardless of whether it was previously called.`;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { wrapWithSimulatedFunctionCalling } from './wrap_with_simulated_function_calling';
export { parseInlineFunctionCalls } from './parse_inline_function_calls';
Loading

0 comments on commit c8043c9

Please sign in to comment.