From 6eaa1d0633f20c9c45cff2b758da5881c0922f3e Mon Sep 17 00:00:00 2001 From: Yuliia Naumenko Date: Tue, 17 Dec 2024 21:13:10 -0800 Subject: [PATCH] [AI Connector] Change completion subAction schema to be OpenAI compatible (#200249) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … ## Summary Summarize your PR. If it involves visual changes include a screenshot or gif. ### Checklist Check the PR satisfies following conditions. Reviewers should verify this PR satisfies this list as well. - [ ] Any text added follows [EUI's writing guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses sentence case text and includes [i18n support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md) - [ ] [Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html) was added for features that require explanation or tutorials - [ ] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios - [ ] If a plugin configuration key changed, check if it needs to be allowlisted in the cloud and added to the [docker list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker) - [ ] This was checked for breaking HTTP API changes, and any breaking changes have been approved by the breaking-change committee. The `release_note:breaking` label should be applied in these situations. - [ ] [Flaky Test Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was used on any tests changed - [ ] The PR description includes the appropriate Release Notes section, and the correct `release_node:*` label is applied per the [guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process) ### Identify risks Does this PR introduce any risks? For example, consider risks like hard to test bugs, performance regression, potential of data loss. Describe the risk, its severity, and mitigation for each identified risk. Invite stakeholders and evaluate how to proceed before merging. - [ ] [See some risk examples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx) - [ ] ... --- .../server/lib/gen_ai_token_tracking.ts | 5 +- .../common/inference/constants.ts | 3 + .../common/inference/schema.ts | 179 ++++++++++++++ .../common/inference/types.ts | 10 + .../connector_types/inference/constants.tsx | 16 +- .../inference/inference.test.tsx | 32 ++- .../connector_types/inference/inference.tsx | 22 +- .../connector_types/inference/params.test.tsx | 25 +- .../connector_types/inference/params.tsx | 76 +++++- .../public/connector_types/inference/types.ts | 7 + .../connector_types/inference/helpers.ts | 111 +++++++++ .../server/connector_types/inference/index.ts | 5 +- .../inference/inference.test.ts | 103 +++++--- .../connector_types/inference/inference.ts | 224 ++++++++++++++---- 14 files changed, 721 insertions(+), 97 deletions(-) create mode 100644 x-pack/plugins/stack_connectors/server/connector_types/inference/helpers.ts diff --git a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts index d73610892098d..ff73095ac2427 100644 --- a/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts +++ b/x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts @@ -295,4 +295,7 @@ export const getGenAiTokenTracking = async ({ }; export const shouldTrackGenAiToken = (actionTypeId: string) => - actionTypeId === '.gen-ai' || actionTypeId === '.bedrock' || actionTypeId === '.gemini'; + actionTypeId === '.gen-ai' || + actionTypeId === '.bedrock' || + actionTypeId === '.gemini' || + actionTypeId === '.inference'; diff --git a/x-pack/plugins/stack_connectors/common/inference/constants.ts b/x-pack/plugins/stack_connectors/common/inference/constants.ts index b795e54f5d32a..c2f2e6713270b 100644 --- a/x-pack/plugins/stack_connectors/common/inference/constants.ts +++ b/x-pack/plugins/stack_connectors/common/inference/constants.ts @@ -32,6 +32,9 @@ export enum ServiceProviderKeys { export const INFERENCE_CONNECTOR_ID = '.inference'; export enum SUB_ACTION { + UNIFIED_COMPLETION_ASYNC_ITERATOR = 'unified_completion_async_iterator', + UNIFIED_COMPLETION_STREAM = 'unified_completion_stream', + UNIFIED_COMPLETION = 'unified_completion', COMPLETION = 'completion', RERANK = 'rerank', TEXT_EMBEDDING = 'text_embedding', diff --git a/x-pack/plugins/stack_connectors/common/inference/schema.ts b/x-pack/plugins/stack_connectors/common/inference/schema.ts index 07b51cf9a5aa3..c62e9782bb517 100644 --- a/x-pack/plugins/stack_connectors/common/inference/schema.ts +++ b/x-pack/plugins/stack_connectors/common/inference/schema.ts @@ -23,6 +23,176 @@ export const ChatCompleteParamsSchema = schema.object({ input: schema.string(), }); +// subset of OpenAI.ChatCompletionMessageParam https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts +const AIMessage = schema.object({ + role: schema.string(), + content: schema.maybe(schema.string()), + name: schema.maybe(schema.string()), + tool_calls: schema.maybe( + schema.arrayOf( + schema.object({ + id: schema.string(), + function: schema.object({ + arguments: schema.maybe(schema.string()), + name: schema.maybe(schema.string()), + }), + type: schema.string(), + }) + ) + ), + tool_call_id: schema.maybe(schema.string()), +}); + +const AITool = schema.object({ + type: schema.string(), + function: schema.object({ + name: schema.string(), + description: schema.maybe(schema.string()), + parameters: schema.maybe(schema.recordOf(schema.string(), schema.any())), + }), +}); + +// subset of OpenAI.ChatCompletionCreateParamsBase https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts +export const UnifiedChatCompleteParamsSchema = schema.object({ + body: schema.object({ + messages: schema.arrayOf(AIMessage, { defaultValue: [] }), + model: schema.maybe(schema.string()), + /** + * The maximum number of [tokens](/tokenizer) that can be generated in the chat + * completion. This value can be used to control + * [costs](https://openai.com/api/pricing/) for text generated via API. + * + * This value is now deprecated in favor of `max_completion_tokens`, and is not + * compatible with + * [o1 series models](https://platform.openai.com/docs/guides/reasoning). + */ + max_tokens: schema.maybe(schema.number()), + /** + * Developer-defined tags and values used for filtering completions in the + * [dashboard](https://platform.openai.com/chat-completions). + */ + metadata: schema.maybe(schema.recordOf(schema.string(), schema.string())), + /** + * How many chat completion choices to generate for each input message. Note that + * you will be charged based on the number of generated tokens across all of the + * choices. Keep `n` as `1` to minimize costs. + */ + n: schema.maybe(schema.number()), + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + stop: schema.maybe( + schema.nullable(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])) + ), + /** + * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output more random, while lower values like 0.2 will make it more + * focused and deterministic. + * + * We generally recommend altering this or `top_p` but not both. + */ + temperature: schema.maybe(schema.number()), + /** + * Controls which (if any) tool is called by the model. `none` means the model will + * not call any tool and instead generates a message. `auto` means the model can + * pick between generating a message or calling one or more tools. `required` means + * the model must call one or more tools. Specifying a particular tool via + * `{"type": "function", "function": {"name": "my_function"}}` forces the model to + * call that tool. + * + * `none` is the default when no tools are present. `auto` is the default if tools + * are present. + */ + tool_choice: schema.maybe( + schema.oneOf([ + schema.string(), + schema.object({ + type: schema.string(), + function: schema.object({ + name: schema.string(), + }), + }), + ]) + ), + /** + * A list of tools the model may call. Currently, only functions are supported as a + * tool. Use this to provide a list of functions the model may generate JSON inputs + * for. A max of 128 functions are supported. + */ + tools: schema.maybe(schema.arrayOf(AITool)), + /** + * An alternative to sampling with temperature, called nucleus sampling, where the + * model considers the results of the tokens with top_p probability mass. So 0.1 + * means only the tokens comprising the top 10% probability mass are considered. + * + * We generally recommend altering this or `temperature` but not both. + */ + top_p: schema.maybe(schema.number()), + /** + * A unique identifier representing your end-user, which can help OpenAI to monitor + * and detect abuse. + * [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + */ + user: schema.maybe(schema.string()), + }), + // abort signal from client + signal: schema.maybe(schema.any()), +}); + +export const UnifiedChatCompleteResponseSchema = schema.object({ + id: schema.string(), + choices: schema.arrayOf( + schema.object({ + finish_reason: schema.maybe( + schema.nullable( + schema.oneOf([ + schema.literal('stop'), + schema.literal('length'), + schema.literal('tool_calls'), + schema.literal('content_filter'), + schema.literal('function_call'), + ]) + ) + ), + index: schema.maybe(schema.number()), + message: schema.object({ + content: schema.maybe(schema.nullable(schema.string())), + refusal: schema.maybe(schema.nullable(schema.string())), + role: schema.maybe(schema.string()), + tool_calls: schema.maybe( + schema.arrayOf( + schema.object({ + id: schema.maybe(schema.string()), + index: schema.maybe(schema.number()), + function: schema.maybe( + schema.object({ + arguments: schema.maybe(schema.string()), + name: schema.maybe(schema.string()), + }) + ), + type: schema.maybe(schema.string()), + }), + { defaultValue: [] } + ) + ), + }), + }), + { defaultValue: [] } + ), + created: schema.maybe(schema.number()), + model: schema.maybe(schema.string()), + object: schema.maybe(schema.string()), + usage: schema.maybe( + schema.nullable( + schema.object({ + completion_tokens: schema.maybe(schema.number()), + prompt_tokens: schema.maybe(schema.number()), + total_tokens: schema.maybe(schema.number()), + }) + ) + ), +}); + export const ChatCompleteResponseSchema = schema.arrayOf( schema.object({ result: schema.string(), @@ -66,3 +236,12 @@ export const TextEmbeddingResponseSchema = schema.arrayOf( ); export const StreamingResponseSchema = schema.stream(); + +// Run action schema +export const DashboardActionParamsSchema = schema.object({ + dashboardId: schema.string(), +}); + +export const DashboardActionResponseSchema = schema.object({ + available: schema.boolean(), +}); diff --git a/x-pack/plugins/stack_connectors/common/inference/types.ts b/x-pack/plugins/stack_connectors/common/inference/types.ts index d8b846ce19422..1593429792e07 100644 --- a/x-pack/plugins/stack_connectors/common/inference/types.ts +++ b/x-pack/plugins/stack_connectors/common/inference/types.ts @@ -18,12 +18,19 @@ import { SparseEmbeddingResponseSchema, TextEmbeddingParamsSchema, TextEmbeddingResponseSchema, + UnifiedChatCompleteParamsSchema, + UnifiedChatCompleteResponseSchema, + DashboardActionParamsSchema, + DashboardActionResponseSchema, } from './schema'; import { ConfigProperties } from '../dynamic_config/types'; export type Config = TypeOf; export type Secrets = TypeOf; +export type UnifiedChatCompleteParams = TypeOf; +export type UnifiedChatCompleteResponse = TypeOf; + export type ChatCompleteParams = TypeOf; export type ChatCompleteResponse = TypeOf; @@ -38,6 +45,9 @@ export type TextEmbeddingResponse = TypeOf; export type StreamingResponse = TypeOf; +export type DashboardActionParams = TypeOf; +export type DashboardActionResponse = TypeOf; + export type FieldsConfiguration = Record; export interface InferenceProvider { diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/constants.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/constants.tsx index 8427caaf49ffc..1b635ca8fe887 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/inference/constants.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/constants.tsx @@ -25,13 +25,27 @@ export const DEFAULT_TEXT_EMBEDDING_BODY = { inputType: 'ingest', }; +export const DEFAULT_UNIFIED_CHAT_COMPLETE_BODY = { + body: { + messages: [ + { + role: 'user', + content: 'Hello world', + }, + ], + }, +}; + export const DEFAULTS_BY_TASK_TYPE: Record = { [SUB_ACTION.COMPLETION]: DEFAULT_CHAT_COMPLETE_BODY, + [SUB_ACTION.UNIFIED_COMPLETION]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY, + [SUB_ACTION.UNIFIED_COMPLETION_STREAM]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY, + [SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY, [SUB_ACTION.RERANK]: DEFAULT_RERANK_BODY, [SUB_ACTION.SPARSE_EMBEDDING]: DEFAULT_SPARSE_EMBEDDING_BODY, [SUB_ACTION.TEXT_EMBEDDING]: DEFAULT_TEXT_EMBEDDING_BODY, }; -export const DEFAULT_TASK_TYPE = 'completion'; +export const DEFAULT_TASK_TYPE = 'unified_completion'; export const DEFAULT_PROVIDER = 'elasticsearch'; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.test.tsx index 76dc50a316e65..b67264674aebe 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.test.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.test.tsx @@ -44,8 +44,16 @@ describe('OpenAI action params validation', () => { subActionParams: { input: ['message test'], query: 'foobar' }, }, { - subAction: SUB_ACTION.COMPLETION, - subActionParams: { input: 'message test' }, + subAction: SUB_ACTION.UNIFIED_COMPLETION, + subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } }, + }, + { + subAction: SUB_ACTION.UNIFIED_COMPLETION_STREAM, + subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } }, + }, + { + subAction: SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR, + subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } }, }, { subAction: SUB_ACTION.TEXT_EMBEDDING, @@ -55,6 +63,10 @@ describe('OpenAI action params validation', () => { subAction: SUB_ACTION.SPARSE_EMBEDDING, subActionParams: { input: 'message test' }, }, + { + subAction: SUB_ACTION.COMPLETION, + subActionParams: { input: 'message test' }, + }, ])( 'validation succeeds when params are valid for subAction $subAction', async ({ subAction, subActionParams }) => { @@ -63,19 +75,25 @@ describe('OpenAI action params validation', () => { subActionParams, }; expect(await actionTypeModel.validateParams(actionParams)).toEqual({ - errors: { input: [], subAction: [], inputType: [], query: [] }, + errors: { body: [], input: [], subAction: [], inputType: [], query: [] }, }); } ); test('params validation fails when params is a wrong object', async () => { const actionParams = { - subAction: SUB_ACTION.COMPLETION, + subAction: SUB_ACTION.UNIFIED_COMPLETION, subActionParams: { body: 'message {test}' }, }; expect(await actionTypeModel.validateParams(actionParams)).toEqual({ - errors: { input: ['Input is required.'], inputType: [], query: [], subAction: [] }, + errors: { + body: ['Messages is required.'], + inputType: [], + query: [], + subAction: [], + input: [], + }, }); }); @@ -86,6 +104,7 @@ describe('OpenAI action params validation', () => { expect(await actionTypeModel.validateParams(actionParams)).toEqual({ errors: { + body: [], input: [], inputType: [], query: [], @@ -102,6 +121,7 @@ describe('OpenAI action params validation', () => { expect(await actionTypeModel.validateParams(actionParams)).toEqual({ errors: { + body: [], input: [], inputType: [], query: [], @@ -118,6 +138,7 @@ describe('OpenAI action params validation', () => { expect(await actionTypeModel.validateParams(actionParams)).toEqual({ errors: { + body: [], input: ['Input is required.', 'Input does not have a valid Array format.'], inputType: [], query: ['Query is required.'], @@ -134,6 +155,7 @@ describe('OpenAI action params validation', () => { expect(await actionTypeModel.validateParams(actionParams)).toEqual({ errors: { + body: [], input: [], inputType: ['Input type is required.'], query: [], diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.tsx index e16d03306c166..388da0556801c 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/inference.tsx @@ -19,6 +19,7 @@ import { InferenceActionParams, InferenceConnector } from './types'; interface ValidationErrors { subAction: string[]; input: string[]; + body: string[]; // rerank only query: string[]; // text_embedding only @@ -40,14 +41,28 @@ export function getConnectorType(): InferenceConnector { const translations = await import('./translations'); const errors: ValidationErrors = { input: [], + body: [], subAction: [], inputType: [], query: [], }; if ( - subAction === SUB_ACTION.RERANK || + subAction === SUB_ACTION.UNIFIED_COMPLETION || + subAction === SUB_ACTION.UNIFIED_COMPLETION_STREAM || + subAction === SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR + ) { + if ( + !Array.isArray(subActionParams.body.messages) || + !subActionParams.body.messages.length + ) { + errors.body.push(translations.getRequiredMessage('Messages')); + } + } + + if ( subAction === SUB_ACTION.COMPLETION || + subAction === SUB_ACTION.RERANK || subAction === SUB_ACTION.TEXT_EMBEDDING || subAction === SUB_ACTION.SPARSE_EMBEDDING ) { @@ -76,10 +91,13 @@ export function getConnectorType(): InferenceConnector { errors.subAction.push(translations.getRequiredMessage('Action')); } else if ( ![ - SUB_ACTION.COMPLETION, + SUB_ACTION.UNIFIED_COMPLETION, + SUB_ACTION.UNIFIED_COMPLETION_STREAM, + SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR, SUB_ACTION.SPARSE_EMBEDDING, SUB_ACTION.RERANK, SUB_ACTION.TEXT_EMBEDDING, + SUB_ACTION.COMPLETION, ].includes(subAction) ) { errors.subAction.push(translations.INVALID_ACTION); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/params.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.test.tsx index 49773edc2246a..ba094ec64f6bd 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/inference/params.test.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.test.tsx @@ -15,8 +15,8 @@ describe('Inference Params Fields renders', () => { const { getByTestId } = render( { index={0} /> ); - expect(getByTestId('inferenceInput')).toBeInTheDocument(); - expect(getByTestId('inferenceInput')).toHaveProperty('value', 'What is Elastic?'); + expect(getByTestId('inference-bodyJsonEditor')).toBeInTheDocument(); + expect(getByTestId('bodyJsonEditor')).toHaveProperty( + 'value', + `{\"messages\":[{\"role\":\"user\",\"content\":\"What is Elastic?\"}]}` + ); }); test.each(['openai', 'googleaistudio'])( @@ -76,15 +79,25 @@ describe('Inference Params Fields renders', () => { /> ); expect(editAction).toHaveBeenCalledTimes(2); - expect(editAction).toHaveBeenCalledWith('subAction', SUB_ACTION.COMPLETION, 0); if (provider === 'openai') { + expect(editAction).toHaveBeenCalledWith('subAction', SUB_ACTION.UNIFIED_COMPLETION, 0); expect(editAction).toHaveBeenCalledWith( 'subActionParams', - { input: 'What is Elastic?' }, + { + body: { + messages: [ + { + content: 'Hello world', + role: 'user', + }, + ], + }, + }, 0 ); } if (provider === 'googleaistudio') { + expect(editAction).toHaveBeenCalledWith('subAction', SUB_ACTION.COMPLETION, 0); expect(editAction).toHaveBeenCalledWith( 'subActionParams', { input: 'What is Elastic?' }, diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/params.tsx b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.tsx index c24fff24c33f6..be162e70493bc 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/inference/params.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/params.tsx @@ -12,11 +12,13 @@ import { } from '@kbn/triggers-actions-ui-plugin/public'; import { EuiTextArea, EuiFormRow, EuiSpacer, EuiSelect } from '@elastic/eui'; import type { RuleFormParamsErrors } from '@kbn/response-ops-rule-form'; +import { ActionVariable } from '@kbn/alerting-types'; import { ChatCompleteParams, RerankParams, SparseEmbeddingParams, TextEmbeddingParams, + UnifiedChatCompleteParams, } from '../../../common/inference/types'; import { DEFAULTS_BY_TASK_TYPE } from './constants'; import * as i18n from './translations'; @@ -25,28 +27,38 @@ import { InferenceActionConnector, InferenceActionParams } from './types'; const InferenceServiceParamsFields: React.FunctionComponent< ActionParamsProps -> = ({ actionParams, editAction, index, errors, actionConnector }) => { +> = ({ actionParams, editAction, index, errors, actionConnector, messageVariables }) => { const { subAction, subActionParams } = actionParams; - const { taskType } = (actionConnector as unknown as InferenceActionConnector).config; + const { taskType, provider } = (actionConnector as unknown as InferenceActionConnector).config; useEffect(() => { if (!subAction) { - editAction('subAction', taskType, index); + editAction( + 'subAction', + provider === 'openai' && taskType === 'completion' + ? SUB_ACTION.UNIFIED_COMPLETION + : taskType, + index + ); } - }, [editAction, index, subAction, taskType]); + }, [editAction, index, provider, subAction, taskType]); useEffect(() => { if (!subActionParams) { editAction( 'subActionParams', { - ...(DEFAULTS_BY_TASK_TYPE[taskType] ?? {}), + ...(DEFAULTS_BY_TASK_TYPE[ + provider === 'openai' && taskType === 'completion' + ? SUB_ACTION.UNIFIED_COMPLETION + : taskType + ] ?? {}), }, index ); } - }, [editAction, index, subActionParams, taskType]); + }, [editAction, index, provider, subActionParams, taskType]); const editSubActionParams = useCallback( (params: Partial) => { @@ -55,6 +67,28 @@ const InferenceServiceParamsFields: React.FunctionComponent< [editAction, index, subActionParams] ); + if (subAction === SUB_ACTION.UNIFIED_COMPLETION) { + return ( + + ); + } + + if (subAction === SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR) { + return ( + + ); + } + if (subAction === SUB_ACTION.COMPLETION) { return ( ) => void; + messageVariables: ActionVariable[] | undefined; +}> = ({ subActionParams, editSubActionParams, errors, messageVariables }) => { + const { body } = subActionParams ?? {}; + + return ( + <> + { + editSubActionParams({ body: JSON.parse(json) }); + }} + onBlur={() => { + if (!subActionParams.body) { + editSubActionParams({ body: { messages: [] } }); + } + }} + dataTestSubj="inference-bodyJsonEditor" + /> + + ); +}; + const CompletionParamsFields: React.FunctionComponent<{ subActionParams: ChatCompleteParams; errors: RuleFormParamsErrors; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/inference/types.ts b/x-pack/plugins/stack_connectors/public/connector_types/inference/types.ts index 1bd55793bc463..1756e213a1a7a 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/inference/types.ts +++ b/x-pack/plugins/stack_connectors/public/connector_types/inference/types.ts @@ -13,9 +13,16 @@ import { RerankParams, SparseEmbeddingParams, TextEmbeddingParams, + UnifiedChatCompleteParams, } from '../../../common/inference/types'; export type InferenceActionParams = + | { subAction: SUB_ACTION.UNIFIED_COMPLETION_STREAM; subActionParams: UnifiedChatCompleteParams } + | { subAction: SUB_ACTION.UNIFIED_COMPLETION; subActionParams: UnifiedChatCompleteParams } + | { + subAction: SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR; + subActionParams: UnifiedChatCompleteParams; + } | { subAction: SUB_ACTION.COMPLETION; subActionParams: ChatCompleteParams } | { subAction: SUB_ACTION.RERANK; subActionParams: RerankParams } | { subAction: SUB_ACTION.SPARSE_EMBEDDING; subActionParams: SparseEmbeddingParams } diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/helpers.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/helpers.ts new file mode 100644 index 0000000000000..7c6bfab9c6396 --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/helpers.ts @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { last, lastValueFrom, map, merge, Observable, scan, share } from 'rxjs'; +import type { Readable } from 'node:stream'; +import { createParser } from 'eventsource-parser'; +import { UnifiedChatCompleteResponse } from '../../../common/inference/types'; + +// TODO: Extract to the common package with appex-ai +export function eventSourceStreamIntoObservable(readable: Readable) { + return new Observable((subscriber) => { + const parser = createParser({ + onEvent: (event) => { + subscriber.next(event.data); + }, + }); + + async function processStream() { + for await (const chunk of readable) { + parser.feed(chunk.toString()); + } + } + + processStream().then( + () => { + subscriber.complete(); + }, + (error) => { + subscriber.error(error); + } + ); + }); +} + +export function chunksIntoMessage(obs$: Observable) { + const shared$ = obs$.pipe(share()); + + return lastValueFrom( + merge( + shared$, + shared$.pipe( + scan( + (prev, chunk) => { + if (chunk.choices.length > 0 && !chunk.usage) { + prev.choices[0].message.content += chunk.choices[0].message.content ?? ''; + + chunk.choices[0].message.tool_calls?.forEach((toolCall) => { + if (toolCall.index !== undefined) { + const prevToolCallLength = prev.choices[0].message.tool_calls?.length ?? 0; + if (prevToolCallLength - 1 !== toolCall.index) { + if (!prev.choices[0].message.tool_calls) { + prev.choices[0].message.tool_calls = []; + } + prev.choices[0].message.tool_calls.push({ + function: { + name: '', + arguments: '', + }, + id: '', + }); + } + const prevToolCall = prev.choices[0].message.tool_calls[toolCall.index]; + + if (toolCall.function?.name) { + prevToolCall.function.name += toolCall.function?.name; + } + if (toolCall.function?.arguments) { + prevToolCall.function.arguments += toolCall.function?.arguments; + } + if (toolCall.id) { + prevToolCall.id += toolCall.id; + } + if (toolCall.type) { + prevToolCall.type = toolCall.type; + } + } + }); + } else if (chunk.usage) { + prev.usage = chunk.usage; + } + return { ...prev, id: chunk.id, model: chunk.model }; + }, + { + choices: [ + { + message: { + content: '', + role: 'assistant', + }, + }, + ], + object: 'chat.completion', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any + ), + last(), + map((concatenatedChunk): UnifiedChatCompleteResponse => { + // TODO: const validatedToolCalls = validateToolCalls(concatenatedChunk.choices[0].message.tool_calls); + if (concatenatedChunk.choices[0].message.content === '') { + concatenatedChunk.choices[0].message.content = null; + } + return concatenatedChunk; + }) + ) + ) + ); +} diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/index.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/index.ts index 18af48bc18a51..5af6773d15fe9 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/inference/index.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/index.ts @@ -161,7 +161,10 @@ export const configValidator = (configObject: Config, validatorServices: Validat ); } - if (!Object.keys(SUB_ACTION).includes(taskType.toUpperCase())) { + if ( + !taskType.includes('completion') && + !Object.keys(SUB_ACTION).includes(taskType.toUpperCase()) + ) { throw new Error( `Task type is not supported${ taskType && taskType.length ? `: ${taskType}` : `` diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts index a79bd0360598b..4aa28d2952dba 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts @@ -9,7 +9,7 @@ import { InferenceConnector } from './inference'; import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.mock'; import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; import { actionsMock } from '@kbn/actions-plugin/server/mocks'; -import { PassThrough, Transform } from 'stream'; +import { Readable, Transform } from 'stream'; import {} from '@kbn/actions-plugin/server/types'; import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; import { InferenceInferenceResponse } from '@elastic/elasticsearch/lib/api/types'; @@ -29,7 +29,7 @@ describe('InferenceConnector', () => { ], }; - describe('performApiCompletion', () => { + describe('performApiUnifiedCompletion', () => { const mockEsClient = elasticsearchClientMock.createClusterClient().asScoped().asInternalUser; beforeEach(() => { @@ -60,28 +60,44 @@ describe('InferenceConnector', () => { }); it('uses the completion task_type is supplied', async () => { - const response = await connector.performApiCompletion({ - input: 'What is Elastic?', + const stream = Readable.from([ + `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`, + `data: [DONE]\n\n`, + ]); + mockEsClient.transport.request.mockResolvedValue(stream); + + const response = await connector.performApiUnifiedCompletion({ + body: { messages: [{ content: 'What is Elastic?', role: 'user' }] }, }); - expect(mockEsClient.inference.inference).toBeCalledTimes(1); - expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + expect(mockEsClient.transport.request).toBeCalledTimes(1); + expect(mockEsClient.transport.request).toHaveBeenCalledWith( { - inference_id: 'test', - input: 'What is Elastic?', - task_type: 'completion', + body: { + messages: [ + { + content: 'What is Elastic?', + role: 'user', + }, + ], + n: undefined, + }, + method: 'POST', + path: '_inference/completion/test/_unified', }, - { asStream: false } + { asStream: true } ); - expect(response).toEqual(mockResponse.completion); + expect(response.choices[0].message.content).toEqual(' you'); }); it('errors during API calls are properly handled', async () => { // @ts-ignore - mockEsClient.inference.inference = mockError; + mockEsClient.transport.request = mockError; - await expect(connector.performApiCompletion({ input: 'What is Elastic?' })).rejects.toThrow( - 'API Error' - ); + await expect( + connector.performApiUnifiedCompletion({ + body: { messages: [{ content: 'What is Elastic?', role: 'user' }] }, + }) + ).rejects.toThrow('API Error'); }); }); @@ -223,6 +239,7 @@ describe('InferenceConnector', () => { }; beforeEach(() => { + jest.clearAllMocks(); // @ts-ignore mockStream(); }); @@ -238,7 +255,7 @@ describe('InferenceConnector', () => { }, provider: 'elasticsearch', taskType: 'completion', - inferenceId: '', + inferenceId: 'test', taskTypeConfig: {}, }, secrets: { providerSecrets: {} }, @@ -247,13 +264,23 @@ describe('InferenceConnector', () => { }); it('the API call is successful with correct request parameters', async () => { - await connector.performApiCompletionStream({ input: 'Hello world' }); - expect(mockEsClient.inference.inference).toBeCalledTimes(1); - expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + await connector.performApiUnifiedCompletionStream({ + body: { messages: [{ content: 'Hello world', role: 'user' }] }, + }); + expect(mockEsClient.transport.request).toBeCalledTimes(1); + expect(mockEsClient.transport.request).toHaveBeenCalledWith( { - inference_id: '', - input: 'Hello world', - task_type: 'completion', + body: { + messages: [ + { + content: 'Hello world', + role: 'user', + }, + ], + n: undefined, + }, + method: 'POST', + path: '_inference/completion/test/_unified', }, { asStream: true } ); @@ -261,32 +288,42 @@ describe('InferenceConnector', () => { it('signal is properly passed to streamApi', async () => { const signal = jest.fn() as unknown as AbortSignal; - await connector.performApiCompletionStream({ input: 'Hello world', signal }); + await connector.performApiUnifiedCompletionStream({ + body: { messages: [{ content: 'Hello world', role: 'user' }] }, + signal, + }); - expect(mockEsClient.inference.inference).toHaveBeenCalledWith( + expect(mockEsClient.transport.request).toHaveBeenCalledWith( { - inference_id: '', - input: 'Hello world', - task_type: 'completion', + body: { messages: [{ content: 'Hello world', role: 'user' }], n: undefined }, + method: 'POST', + path: '_inference/completion/test/_unified', }, - { asStream: true, signal } + { asStream: true } ); }); it('errors during API calls are properly handled', async () => { // @ts-ignore - mockEsClient.inference.inference = mockError; + mockEsClient.transport.request = mockError; await expect( - connector.performApiCompletionStream({ input: 'What is Elastic?' }) + connector.performApiUnifiedCompletionStream({ + body: { messages: [{ content: 'What is Elastic?', role: 'user' }] }, + }) ).rejects.toThrow('API Error'); }); it('responds with a readable stream', async () => { - const response = await connector.performApiCompletionStream({ - input: 'What is Elastic?', + const stream = Readable.from([ + `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`, + `data: [DONE]\n\n`, + ]); + mockEsClient.transport.request.mockResolvedValue(stream); + const response = await connector.performApiUnifiedCompletionStream({ + body: { messages: [{ content: 'What is Elastic?', role: 'user' }] }, }); - expect(response instanceof PassThrough).toEqual(true); + expect(response instanceof Readable).toEqual(true); }); }); }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts index d9aa4bf044e1d..d6c9af0e1365e 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts @@ -6,36 +6,44 @@ */ import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; - -import { PassThrough, Stream } from 'stream'; -import { IncomingMessage } from 'http'; +import { Stream } from 'openai/streaming'; +import { Readable } from 'stream'; import { AxiosError } from 'axios'; import { InferenceInferenceRequest, InferenceInferenceResponse, - InferenceTaskType, } from '@elastic/elasticsearch/lib/api/types'; +import { ConnectorUsageCollector } from '@kbn/actions-plugin/server/usage'; +import { filter, from, identity, map, mergeMap, Observable, tap } from 'rxjs'; +import OpenAI from 'openai'; +import { ChatCompletionChunk } from 'openai/resources'; import { ChatCompleteParamsSchema, RerankParamsSchema, SparseEmbeddingParamsSchema, TextEmbeddingParamsSchema, + UnifiedChatCompleteParamsSchema, } from '../../../common/inference/schema'; import { Config, Secrets, - ChatCompleteParams, - ChatCompleteResponse, - StreamingResponse, RerankParams, RerankResponse, SparseEmbeddingParams, SparseEmbeddingResponse, TextEmbeddingParams, TextEmbeddingResponse, + UnifiedChatCompleteParams, + UnifiedChatCompleteResponse, + DashboardActionParams, + DashboardActionResponse, + ChatCompleteParams, + ChatCompleteResponse, } from '../../../common/inference/types'; import { SUB_ACTION } from '../../../common/inference/constants'; +import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard'; +import { chunksIntoMessage, eventSourceStreamIntoObservable } from './helpers'; export class InferenceConnector extends SubActionConnector { // Not using Axios @@ -60,10 +68,25 @@ export class InferenceConnector extends SubActionConnector { } private registerSubActions() { + // non-streaming unified completion task this.registerSubAction({ - name: SUB_ACTION.COMPLETION, - method: 'performApiCompletion', - schema: ChatCompleteParamsSchema, + name: SUB_ACTION.UNIFIED_COMPLETION, + method: 'performApiUnifiedCompletion', + schema: UnifiedChatCompleteParamsSchema, + }); + + // streaming unified completion task + this.registerSubAction({ + name: SUB_ACTION.UNIFIED_COMPLETION_STREAM, + method: 'performApiUnifiedCompletionStream', + schema: UnifiedChatCompleteParamsSchema, + }); + + // streaming unified completion task for langchain + this.registerSubAction({ + name: SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR, + method: 'performApiUnifiedCompletionAsyncIterator', + schema: UnifiedChatCompleteParamsSchema, }); this.registerSubAction({ @@ -85,8 +108,8 @@ export class InferenceConnector extends SubActionConnector { }); this.registerSubAction({ - name: SUB_ACTION.COMPLETION_STREAM, - method: 'performApiCompletionStream', + name: SUB_ACTION.COMPLETION, + method: 'performApiCompletion', schema: ChatCompleteParamsSchema, }); } @@ -96,16 +119,112 @@ export class InferenceConnector extends SubActionConnector { * @param input the text on which you want to perform the inference task. * @signal abort signal */ - public async performApiCompletion({ - input, - signal, - }: ChatCompleteParams & { signal?: AbortSignal }): Promise { - const response = await this.performInferenceApi( - { inference_id: this.inferenceId, input, task_type: 'completion' }, - false, - signal + public async performApiUnifiedCompletion( + params: UnifiedChatCompleteParams + ): Promise { + const res = await this.performApiUnifiedCompletionStream(params); + + const obs$ = from(eventSourceStreamIntoObservable(res as unknown as Readable)).pipe( + filter((line) => !!line && line !== '[DONE]'), + map((line) => { + return JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } }; + }), + tap((line) => { + if ('error' in line) { + throw new Error(line.error.message); + } + if ( + 'choices' in line && + line.choices.length && + line.choices[0].finish_reason === 'length' + ) { + throw new Error('createTokenLimitReachedError()'); + } + }), + filter((line): line is OpenAI.ChatCompletionChunk => { + return 'object' in line && line.object === 'chat.completion.chunk'; + }), + mergeMap((chunk): Observable => { + const events: UnifiedChatCompleteResponse[] = []; + events.push({ + choices: chunk.choices.map((c) => ({ + message: { + tool_calls: c.delta.tool_calls?.map((t) => ({ + index: t.index, + id: t.id, + function: t.function, + type: t.type, + })), + content: c.delta.content, + refusal: c.delta.refusal, + role: c.delta.role, + }, + finish_reason: c.finish_reason, + index: c.index, + })), + id: chunk.id, + model: chunk.model, + object: chunk.object, + usage: chunk.usage, + }); + return from(events); + }), + identity + ); + + return chunksIntoMessage(obs$); + } + + /** + * responsible for making a esClient inference method to perform chat completetion task endpoint and returning the service response data + * @param input the text on which you want to perform the inference task. + * @signal abort signal + */ + public async performApiUnifiedCompletionStream(params: UnifiedChatCompleteParams) { + return await this.esClient.transport.request( + { + method: 'POST', + path: `_inference/completion/${this.inferenceId}/_unified`, + body: { ...params.body, n: undefined }, // exclude n param for now, constant is used on the inference API side + }, + { + asStream: true, + } ); - return response.completion!; + } + + /** + * Streamed requests (langchain) + * @param params - the request body + * @returns { + * consumerStream: Stream; the result to be read/transformed on the server and sent to the client via Server Sent Events + * tokenCountStream: Stream; the result for token counting stream + * } + */ + public async performApiUnifiedCompletionAsyncIterator( + params: UnifiedChatCompleteParams & { signal?: AbortSignal }, + connectorUsageCollector: ConnectorUsageCollector + ): Promise<{ + consumerStream: Stream; + tokenCountStream: Stream; + }> { + try { + connectorUsageCollector.addRequestBodyBytes(undefined, params.body); + const res = await this.performApiUnifiedCompletionStream(params); + const controller = new AbortController(); + // splits the stream in two, one is used for the UI and other for token tracking + + const stream = Stream.fromSSEResponse( + { body: res } as unknown as Response, + controller + ); + const teed = stream.tee(); + return { consumerStream: teed[0], tokenCountStream: teed[1] }; + // since we do not use the sub action connector request method, we need to do our own error handling + } catch (e) { + const errorMessage = this.getResponseErrorMessage(e); + throw new Error(errorMessage); + } } /** @@ -198,35 +317,56 @@ export class InferenceConnector extends SubActionConnector { } } - private async streamAPI({ + /** + * responsible for making a esClient inference method to perform chat completetion task endpoint and returning the service response data + * @param input the text on which you want to perform the inference task. + * @signal abort signal + */ + public async performApiCompletion({ input, signal, - }: ChatCompleteParams & { signal?: AbortSignal }): Promise { + }: ChatCompleteParams & { signal?: AbortSignal }): Promise { const response = await this.performInferenceApi( - { inference_id: this.inferenceId, input, task_type: this.taskType as InferenceTaskType }, - true, + { inference_id: this.inferenceId, input, task_type: 'completion' }, + false, signal ); - - return (response as unknown as Stream).pipe(new PassThrough()); + return response.completion!; } /** - * takes input. It calls the streamApi method to make a - * request to the Inference API with the message. It then returns a Transform stream - * that pipes the response from the API through the transformToString function, - * which parses the proprietary response into a string of the response text alone - * @param input A message to be sent to the API - * @signal abort signal + * retrieves a dashboard from the Kibana server and checks if the + * user has the necessary privileges to access it. + * @param dashboardId The ID of the dashboard to retrieve. */ - public async performApiCompletionStream({ - input, - signal, - }: ChatCompleteParams & { signal?: AbortSignal }): Promise { - const res = (await this.streamAPI({ - input, - signal, - })) as unknown as IncomingMessage; - return res; + public async getDashboard({ + dashboardId, + }: DashboardActionParams): Promise { + const privilege = (await this.esClient.transport.request({ + path: '/_security/user/_has_privileges', + method: 'POST', + body: { + index: [ + { + names: ['.kibana-event-log-*'], + allow_restricted_indices: true, + privileges: ['read'], + }, + ], + }, + })) as { has_all_requested: boolean }; + + if (!privilege?.has_all_requested) { + return { available: false }; + } + + const response = await initDashboard({ + logger: this.logger, + savedObjectsClient: this.savedObjectsClient, + dashboardId, + genAIProvider: 'Inference', + }); + + return { available: response.success }; } }