diff --git a/x-pack/packages/kbn-ai-assistant/src/chat/welcome_message.tsx b/x-pack/packages/kbn-ai-assistant/src/chat/welcome_message.tsx index 0783c7f64620..6133df55c57e 100644 --- a/x-pack/packages/kbn-ai-assistant/src/chat/welcome_message.tsx +++ b/x-pack/packages/kbn-ai-assistant/src/chat/welcome_message.tsx @@ -10,7 +10,7 @@ import { css } from '@emotion/css'; import { EuiFlexGroup, EuiFlexItem, EuiSpacer, useCurrentEuiBreakpoint } from '@elastic/eui'; import type { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public'; import { GenerativeAIForObservabilityConnectorFeatureId } from '@kbn/actions-plugin/common'; -import { isSupportedConnectorType } from '@kbn/observability-ai-assistant-plugin/public'; +import { isSupportedConnectorType } from '@kbn/inference-common'; import { AssistantBeacon } from '@kbn/ai-assistant-icon'; import type { UseKnowledgeBaseResult } from '../hooks/use_knowledge_base'; import type { UseGenAIConnectorsResult } from '../hooks/use_genai_connectors'; diff --git a/x-pack/packages/kbn-ai-assistant/tsconfig.json b/x-pack/packages/kbn-ai-assistant/tsconfig.json index c23f92085c28..d33b8642561e 100644 --- a/x-pack/packages/kbn-ai-assistant/tsconfig.json +++ b/x-pack/packages/kbn-ai-assistant/tsconfig.json @@ -37,6 +37,7 @@ "@kbn/ml-plugin", "@kbn/share-plugin", "@kbn/ai-assistant-common", + "@kbn/inference-common", "@kbn/storybook", "@kbn/ai-assistant-icon", ] diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts index 134b0f02811f..0c6d254c0f52 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts @@ -95,3 +95,9 @@ export { } from './src/errors'; export { truncateList } from './src/truncate_list'; +export { + InferenceConnectorType, + isSupportedConnectorType, + isSupportedConnector, + type InferenceConnector, +} from './src/connectors'; diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors.test.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors.test.ts new file mode 100644 index 000000000000..a4729aa8a857 --- /dev/null +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors.test.ts @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { + InferenceConnectorType, + isSupportedConnectorType, + isSupportedConnector, + RawConnector, + COMPLETION_TASK_TYPE, +} from './connectors'; + +const createRawConnector = (parts: Partial<RawConnector>): RawConnector => { + return { + id: 'id', + actionTypeId: 'connector-type', + name: 'some connector', + config: {}, + ...parts, + }; +}; + +describe('isSupportedConnectorType', () => { + it('returns true for supported connector types', () => { + expect(isSupportedConnectorType(InferenceConnectorType.OpenAI)).toBe(true); + expect(isSupportedConnectorType(InferenceConnectorType.Bedrock)).toBe(true); + expect(isSupportedConnectorType(InferenceConnectorType.Gemini)).toBe(true); + expect(isSupportedConnectorType(InferenceConnectorType.Inference)).toBe(true); + }); + it('returns false for unsupported connector types', () => { + expect(isSupportedConnectorType('anything-else')).toBe(false); + }); +}); + +describe('isSupportedConnector', () => { + it('returns true for OpenAI connectors', () => { + expect( + isSupportedConnector(createRawConnector({ actionTypeId: InferenceConnectorType.OpenAI })) + ).toBe(true); + }); + + it('returns true for Bedrock connectors', () => { + expect( + isSupportedConnector(createRawConnector({ actionTypeId: InferenceConnectorType.Bedrock })) + ).toBe(true); + }); + + it('returns true for Gemini connectors', () => { + expect( + isSupportedConnector(createRawConnector({ actionTypeId: InferenceConnectorType.Gemini })) + ).toBe(true); + }); + + it('returns true for inference connectors with the right taskType', () => { + expect( + isSupportedConnector( + createRawConnector({ + actionTypeId: InferenceConnectorType.Inference, + config: { taskType: COMPLETION_TASK_TYPE }, + }) + ) + ).toBe(true); + }); + + it('returns false for inference connectors with a bad taskType', () => { + expect( + isSupportedConnector( + createRawConnector({ + actionTypeId: InferenceConnectorType.Inference, + config: { taskType: 'embeddings' }, + }) + ) + ).toBe(false); + }); + + it('returns false for inference connectors without a taskType', () => { + expect( + isSupportedConnector( + createRawConnector({ + actionTypeId: InferenceConnectorType.Inference, + config: {}, + }) + ) + ).toBe(false); + }); +});
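For review context, a minimal usage sketch of the two guards exercised above (illustrative, not part of the diff; the connector values are invented). `isSupportedConnectorType` inspects only the action type id, while `isSupportedConnector` additionally requires `config.taskType === 'completion'` for `.inference` connectors:

import {
  InferenceConnectorType,
  isSupportedConnector,
  isSupportedConnectorType,
} from '@kbn/inference-common';

// Type-level check only: '.gen-ai', '.bedrock', '.gemini' and '.inference' all pass.
isSupportedConnectorType('.gen-ai'); // true
isSupportedConnectorType('.tcp-pigeon'); // false

// Full check: '.inference' connectors must also target the completion task.
isSupportedConnector({
  id: 'my-inference-connector', // hypothetical connector id
  actionTypeId: InferenceConnectorType.Inference,
  name: 'Elastic Inference Service',
  config: { taskType: 'completion' },
}); // true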
diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors.ts new file mode 100644 index 000000000000..da77d973614b --- /dev/null +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors.ts @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/** + * The list of connector types that can be used with the inference APIs + */ +export enum InferenceConnectorType { + OpenAI = '.gen-ai', + Bedrock = '.bedrock', + Gemini = '.gemini', + Inference = '.inference', +} + +export const COMPLETION_TASK_TYPE = 'completion'; + +const allSupportedConnectorTypes = Object.values(InferenceConnectorType); + +export interface InferenceConnector { + type: InferenceConnectorType; + name: string; + connectorId: string; +} + +/** + * Checks if a given connector type is compatible for inference. + * + * Note: this check is not sufficient to assert if a given connector can be + * used for inference, as `.inference` connectors need additional check logic. + * Please use `isSupportedConnector` instead when possible. + */ +export function isSupportedConnectorType(id: string): id is InferenceConnectorType { + return allSupportedConnectorTypes.includes(id as InferenceConnectorType); +} + +/** + * Checks if a given connector is compatible for inference. + * + * A connector is compatible if: + * 1. its type is in the list of allowed types + * 2. for inference connectors, if its taskType is "completion" + */ +export function isSupportedConnector(connector: RawConnector): connector is RawInferenceConnector { + if (!isSupportedConnectorType(connector.actionTypeId)) { + return false; + } + if (connector.actionTypeId === InferenceConnectorType.Inference) { + const config = connector.config ?? {}; + if (config.taskType !== COMPLETION_TASK_TYPE) { + return false; + } + } + return true; +} + +/** + * Connector types live in the actions plugin, and we can't afford + * having dependencies from this package to some mid-level plugin, + * so we're just using our own connector mixin type. + */ +export interface RawConnector { + id: string; + actionTypeId: string; + name: string; + config?: Record<string, any>; +} + +interface RawInferenceConnector { + id: string; + actionTypeId: InferenceConnectorType; + name: string; + config?: Record<string, any>; +} diff --git a/x-pack/platform/plugins/shared/inference/common/connectors.ts b/x-pack/platform/plugins/shared/inference/common/connectors.ts deleted file mode 100644 index ee628f520fef..000000000000 --- a/x-pack/platform/plugins/shared/inference/common/connectors.ts +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -export enum InferenceConnectorType { - OpenAI = '.gen-ai', - Bedrock = '.bedrock', - Gemini = '.gemini', -} - -const allSupportedConnectorTypes = Object.values(InferenceConnectorType); - -export interface InferenceConnector { - type: InferenceConnectorType; - name: string; - connectorId: string; -} - -export function isSupportedConnectorType(id: string): id is InferenceConnectorType { - return allSupportedConnectorTypes.includes(id as InferenceConnectorType); -} diff --git a/x-pack/platform/plugins/shared/inference/common/http_apis.ts b/x-pack/platform/plugins/shared/inference/common/http_apis.ts index c07fcd29b221..f6a60051e84f 100644 --- a/x-pack/platform/plugins/shared/inference/common/http_apis.ts +++ b/x-pack/platform/plugins/shared/inference/common/http_apis.ts @@ -5,8 +5,12 @@ * 2.0. */ -import type { FunctionCallingMode, Message, ToolOptions } from '@kbn/inference-common'; -import { InferenceConnector } from './connectors'; +import type { + FunctionCallingMode, + Message, + ToolOptions, + InferenceConnector, +} from '@kbn/inference-common'; export type ChatCompleteRequestBody = { connectorId: string; diff --git a/x-pack/platform/plugins/shared/inference/public/types.ts b/x-pack/platform/plugins/shared/inference/public/types.ts index 735abfb5459a..f07fe1e63683 100644 --- a/x-pack/platform/plugins/shared/inference/public/types.ts +++ b/x-pack/platform/plugins/shared/inference/public/types.ts @@ -5,8 +5,7 @@ * 2.0. 
*/ -import type { ChatCompleteAPI, OutputAPI } from '@kbn/inference-common'; -import type { InferenceConnector } from '../common/connectors'; +import type { ChatCompleteAPI, OutputAPI, InferenceConnector } from '@kbn/inference-common'; /* eslint-disable @typescript-eslint/no-empty-interface*/ diff --git a/x-pack/platform/plugins/shared/inference/scripts/util/kibana_client.ts b/x-pack/platform/plugins/shared/inference/scripts/util/kibana_client.ts index ef6f1c4fdcdc..a3a75ea98052 100644 --- a/x-pack/platform/plugins/shared/inference/scripts/util/kibana_client.ts +++ b/x-pack/platform/plugins/shared/inference/scripts/util/kibana_client.ts @@ -25,9 +25,9 @@ import { withoutOutputUpdateEvents, type ToolOptions, ChatCompleteOptions, + type InferenceConnector, } from '@kbn/inference-common'; import type { ChatCompleteRequestBody } from '../../common/http_apis'; -import type { InferenceConnector } from '../../common/connectors'; import { createOutputApi } from '../../common/output/create_output_api'; import { eventSourceStreamIntoObservable } from '../../server/util/event_source_stream_into_observable'; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.test.ts index 558e0cd06ef9..f6613152f9f0 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.test.ts @@ -5,11 +5,12 @@ * 2.0. */ -import { InferenceConnectorType } from '../../../common/connectors'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { getInferenceAdapter } from './get_inference_adapter'; import { openAIAdapter } from './openai'; import { geminiAdapter } from './gemini'; import { bedrockClaudeAdapter } from './bedrock'; +import { inferenceAdapter } from './inference'; describe('getInferenceAdapter', () => { it('returns the openAI adapter for OpenAI type', () => { @@ -23,4 +24,8 @@ describe('getInferenceAdapter', () => { it('returns the bedrock adapter for Bedrock type', () => { expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(bedrockClaudeAdapter); }); + + it('returns the inference adapter for Inference type', () => { + expect(getInferenceAdapter(InferenceConnectorType.Inference)).toBe(inferenceAdapter); + }); }); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.ts index f34b0c27a339..ec5e6803ab86 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/get_inference_adapter.ts @@ -5,11 +5,12 @@ * 2.0. 
*/ -import { InferenceConnectorType } from '../../../common/connectors'; +import { InferenceConnectorType } from '@kbn/inference-common'; import type { InferenceConnectorAdapter } from '../types'; import { openAIAdapter } from './openai'; import { geminiAdapter } from './gemini'; import { bedrockClaudeAdapter } from './bedrock'; +import { inferenceAdapter } from './inference'; export const getInferenceAdapter = ( connectorType: InferenceConnectorType @@ -23,6 +24,9 @@ export const getInferenceAdapter = ( case InferenceConnectorType.Bedrock: return bedrockClaudeAdapter; + + case InferenceConnectorType.Inference: + return inferenceAdapter; } return undefined; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/index.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/index.ts new file mode 100644 index 000000000000..040b4103dae8 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { inferenceAdapter } from './inference_adapter'; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts new file mode 100644 index 000000000000..7cf5fc7bdfb8 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import OpenAI from 'openai'; +import { v4 } from 'uuid'; +import { PassThrough } from 'stream'; +import { lastValueFrom, Subject, toArray } from 'rxjs'; +import type { Logger } from '@kbn/logging'; +import { loggerMock } from '@kbn/logging-mocks'; +import { ChatCompletionEventType, MessageRole } from '@kbn/inference-common'; +import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream'; +import { InferenceExecutor } from '../../utils/inference_executor'; +import { inferenceAdapter } from './inference_adapter'; + +function createOpenAIChunk({ + delta, + usage, +}: { + delta?: OpenAI.ChatCompletionChunk['choices'][number]['delta']; + usage?: OpenAI.ChatCompletionChunk['usage']; +}): OpenAI.ChatCompletionChunk { + return { + choices: delta + ? 
[ + { + finish_reason: null, + index: 0, + delta, + }, + ] + : [], + created: new Date().getTime(), + id: v4(), + model: 'gpt-4o', + object: 'chat.completion.chunk', + usage, + }; +} + +describe('inferenceAdapter', () => { + const executorMock = { + invoke: jest.fn(), + } as InferenceExecutor & { invoke: jest.MockedFn<InferenceExecutor['invoke']> }; + + const logger = { + debug: jest.fn(), + error: jest.fn(), + } as unknown as Logger; + + beforeEach(() => { + executorMock.invoke.mockReset(); + }); + + const defaultArgs = { + executor: executorMock, + logger: loggerMock.create(), + }; + + describe('when creating the request', () => { + beforeEach(() => { + executorMock.invoke.mockImplementation(async () => { + return { + actionId: '', + status: 'ok', + data: new PassThrough(), + }; + }); + }); + + it('emits chunk events', async () => { + const source$ = new Subject<Record<string, any>>(); + + executorMock.invoke.mockImplementation(async () => { + return { + actionId: '', + status: 'ok', + data: observableIntoEventSourceStream(source$, logger), + }; + }); + + const response$ = inferenceAdapter.chatComplete({ + ...defaultArgs, + messages: [ + { + role: MessageRole.User, + content: 'Hello', + }, + ], + }); + + source$.next( + createOpenAIChunk({ + delta: { + content: 'First', + }, + }) + ); + + source$.next( + createOpenAIChunk({ + delta: { + content: ', second', + }, + }) + ); + + source$.complete(); + + const allChunks = await lastValueFrom(response$.pipe(toArray())); + + expect(allChunks).toEqual([ + { + content: 'First', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + { + content: ', second', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + ]); + }); + + it('propagates the abort signal when provided', () => { + const abortController = new AbortController(); + + inferenceAdapter.chatComplete({ + logger, + executor: executorMock, + messages: [{ role: MessageRole.User, content: 'question' }], + abortSignal: abortController.signal, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(executorMock.invoke).toHaveBeenCalledWith({ + subAction: 'unified_completion_stream', + subActionParams: expect.objectContaining({ + signal: abortController.signal, + }), + }); + }); + }); +}); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.ts new file mode 100644 index 000000000000..323dec4f5789 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.ts @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import type OpenAI from 'openai'; +import { from, identity, switchMap, throwError } from 'rxjs'; +import { isReadable, Readable } from 'stream'; +import { createInferenceInternalError } from '@kbn/inference-common'; +import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable'; +import type { InferenceConnectorAdapter } from '../../types'; +import { + parseInlineFunctionCalls, + wrapWithSimulatedFunctionCalling, +} from '../../simulated_function_calling'; +import { + toolsToOpenAI, + toolChoiceToOpenAI, + messagesToOpenAI, + processOpenAIStream, +} from '../openai'; + +export const inferenceAdapter: InferenceConnectorAdapter = { + chatComplete: ({ + executor, + system, + messages, + toolChoice, + tools, + functionCalling, + logger, + abortSignal, + }) => { + const simulatedFunctionCalling = functionCalling === 'simulated'; + + let request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string }; + if (simulatedFunctionCalling) { + const wrapped = wrapWithSimulatedFunctionCalling({ + system, + messages, + toolChoice, + tools, + }); + request = { + messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }), + }; + } else { + request = { + messages: messagesToOpenAI({ system, messages }), + tool_choice: toolChoiceToOpenAI(toolChoice), + tools: toolsToOpenAI(tools), + }; + } + + return from( + executor.invoke({ + subAction: 'unified_completion_stream', + subActionParams: { + body: request, + signal: abortSignal, + }, + }) + ).pipe( + switchMap((response) => { + if (response.status === 'error') { + return throwError(() => + createInferenceInternalError('Error calling the inference API', { + rootError: response.serviceMessage, + }) + ); + } + if (isReadable(response.data as any)) { + return eventSourceStreamIntoObservable(response.data as Readable); + } + return throwError(() => + createInferenceInternalError('Unexpected error', response.data as Record<string, any>) + ); + }), + processOpenAIStream(), + simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity + ); + }, +};
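To make the branch above concrete, here is a sketch of the request the non-simulated path assembles (illustrative, not part of the diff; the system prompt and message are invented). It is a standard OpenAI-shaped body minus `model`, which the `.inference` connector resolves on the Elasticsearch side:

import type OpenAI from 'openai';
import { MessageRole, type Message } from '@kbn/inference-common';
import { messagesToOpenAI, toolChoiceToOpenAI, toolsToOpenAI } from '../openai';

const messages: Message[] = [{ role: MessageRole.User, content: 'What is Elastic?' }];

// Same shape the adapter passes as `body` to the 'unified_completion_stream' subaction.
const request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string } = {
  messages: messagesToOpenAI({ system: 'You are a helpful assistant', messages }),
  tool_choice: toolChoiceToOpenAI(undefined), // no tool choice -> undefined
  tools: toolsToOpenAI(undefined), // no tools -> undefined
};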
diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/from_openai.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/from_openai.ts new file mode 100644 index 000000000000..750ae4710104 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/from_openai.ts @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type OpenAI from 'openai'; +import { + ChatCompletionChunkEvent, + ChatCompletionEventType, + ChatCompletionTokenCountEvent, +} from '@kbn/inference-common'; + +export function chunkFromOpenAI(chunk: OpenAI.ChatCompletionChunk): ChatCompletionChunkEvent { + const delta = chunk.choices[0].delta; + + return { + type: ChatCompletionEventType.ChatCompletionChunk, + content: delta.content ?? '', + tool_calls: + delta.tool_calls?.map((toolCall) => { + return { + function: { + name: toolCall.function?.name ?? '', + arguments: toolCall.function?.arguments ?? '', + }, + toolCallId: toolCall.id ?? '', + index: toolCall.index, + }; + }) ?? [], + }; +} + +export function tokenCountFromOpenAI( + completionUsage: OpenAI.CompletionUsage +): ChatCompletionTokenCountEvent { + return { + type: ChatCompletionEventType.ChatCompletionTokenCount, + tokens: { + completion: completionUsage.completion_tokens, + prompt: completionUsage.prompt_tokens, + total: completionUsage.total_tokens, + }, + }; +} diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/index.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/index.ts index 9aa1d94e01a5..ddf8441756cb 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/index.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/index.ts @@ -6,3 +6,5 @@ */ export { openAIAdapter } from './openai_adapter'; +export { toolChoiceToOpenAI, messagesToOpenAI, toolsToOpenAI } from './to_openai'; +export { processOpenAIStream } from './process_openai_stream'; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts index 9b7fbc388024..d93dee627ec1 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts @@ -15,7 +15,7 @@ import { loggerMock } from '@kbn/logging-mocks'; import { ChatCompletionEventType, MessageRole } from '@kbn/inference-common'; import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream'; import { InferenceExecutor } from '../../utils/inference_executor'; -import { openAIAdapter } from '.'; +import { openAIAdapter } from './openai_adapter'; function createOpenAIChunk({ delta, diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts index 0529820b1bfb..8806429882e3 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts @@ -6,41 +6,17 @@ */ import type OpenAI from 'openai'; -import type { - ChatCompletionAssistantMessageParam, - ChatCompletionMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionToolMessageParam, - ChatCompletionUserMessageParam, -} from 'openai/resources'; -import { - filter, - from, - identity, - map, - mergeMap, - Observable, - switchMap, - tap, - throwError, -} from 'rxjs'; +import { from, identity, switchMap, throwError } from 'rxjs'; import { isReadable, Readable } from 'stream'; -import { - ChatCompletionChunkEvent, - ChatCompletionEventType, - ChatCompletionTokenCountEvent, - createInferenceInternalError, - Message, - MessageRole, - ToolOptions, -} from '@kbn/inference-common'; -import { createTokenLimitReachedError } from '../../errors'; +import { createInferenceInternalError } from '@kbn/inference-common'; import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable'; import type { InferenceConnectorAdapter } from '../../types'; import { parseInlineFunctionCalls, wrapWithSimulatedFunctionCalling, } from '../../simulated_function_calling'; +import { messagesToOpenAI, toolsToOpenAI, toolChoiceToOpenAI } from './to_openai'; +import { 
processOpenAIStream } from './process_openai_stream'; export const openAIAdapter: InferenceConnectorAdapter = { chatComplete: ({ @@ -95,158 +71,8 @@ export const openAIAdapter: InferenceConnectorAdapter = { createInferenceInternalError('Unexpected error', response.data as Record<string, any>) ); }), - filter((line) => !!line && line !== '[DONE]'), - map( - (line) => JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } } - ), - tap((line) => { - if ('error' in line) { - throw createInferenceInternalError(line.error.message); - } - if ( - 'choices' in line && - line.choices.length && - line.choices[0].finish_reason === 'length' - ) { - throw createTokenLimitReachedError(); - } - }), - filter((line): line is OpenAI.ChatCompletionChunk => { - return 'object' in line && line.object === 'chat.completion.chunk'; - }), - mergeMap((chunk): Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> => { - const events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> = []; - if (chunk.usage) { - events.push(tokenCountFromOpenAI(chunk.usage)); - } - if (chunk.choices?.length) { - events.push(chunkFromOpenAI(chunk)); - } - return from(events); - }), + processOpenAIStream(), simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity ); }, }; - -function chunkFromOpenAI(chunk: OpenAI.ChatCompletionChunk): ChatCompletionChunkEvent { - const delta = chunk.choices[0].delta; - - return { - type: ChatCompletionEventType.ChatCompletionChunk, - content: delta.content ?? '', - tool_calls: - delta.tool_calls?.map((toolCall) => { - return { - function: { - name: toolCall.function?.name ?? '', - arguments: toolCall.function?.arguments ?? '', - }, - toolCallId: toolCall.id ?? '', - index: toolCall.index, - }; - }) ?? [], - }; -} - -function tokenCountFromOpenAI( - completionUsage: OpenAI.CompletionUsage -): ChatCompletionTokenCountEvent { - return { - type: ChatCompletionEventType.ChatCompletionTokenCount, - tokens: { - completion: completionUsage.completion_tokens, - prompt: completionUsage.prompt_tokens, - total: completionUsage.total_tokens, - }, - }; -} - -function toolsToOpenAI(tools: ToolOptions['tools']): OpenAI.ChatCompletionCreateParams['tools'] { - return tools - ? Object.entries(tools).map(([toolName, { description, schema }]) => { - return { - type: 'function', - function: { - name: toolName, - description, - parameters: (schema ?? { - type: 'object' as const, - properties: {}, - }) as unknown as Record<string, unknown>, - }, - }; - }) - : undefined; -} - -function toolChoiceToOpenAI( - toolChoice: ToolOptions['toolChoice'] -): OpenAI.ChatCompletionCreateParams['tool_choice'] { - return typeof toolChoice === 'string' - ? toolChoice - : toolChoice - ? { - function: { - name: toolChoice.function, - }, - type: 'function' as const, - } - : undefined; -} - -function messagesToOpenAI({ - system, - messages, -}: { - system?: string; - messages: Message[]; -}): OpenAI.ChatCompletionMessageParam[] { - const systemMessage: ChatCompletionSystemMessageParam | undefined = system - ? { role: 'system', content: system } - : undefined; - - return [ - ...(systemMessage ? [systemMessage] : []), - ...messages.map((message): ChatCompletionMessageParam => { - const role = message.role; - - switch (role) { - case MessageRole.Assistant: - const assistantMessage: ChatCompletionAssistantMessageParam = { - role: 'assistant', - content: message.content, - tool_calls: message.toolCalls?.map((toolCall) => { - return { - function: { - name: toolCall.function.name, - arguments: - 'arguments' in toolCall.function - ? 
JSON.stringify(toolCall.function.arguments) - : '{}', - }, - id: toolCall.toolCallId, - type: 'function', - }; - }), - }; - return assistantMessage; - - case MessageRole.User: - const userMessage: ChatCompletionUserMessageParam = { - role: 'user', - content: message.content, - }; - return userMessage; - - case MessageRole.Tool: - const toolMessage: ChatCompletionToolMessageParam = { - role: 'tool', - content: JSON.stringify(message.response), - tool_call_id: message.toolCallId, - }; - return toolMessage; - } - }), - ]; -} diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/process_openai_stream.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/process_openai_stream.ts new file mode 100644 index 000000000000..65384ed52e5f --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/process_openai_stream.ts @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type OpenAI from 'openai'; +import { filter, from, map, mergeMap, Observable, tap } from 'rxjs'; +import { + ChatCompletionChunkEvent, + ChatCompletionTokenCountEvent, + createInferenceInternalError, +} from '@kbn/inference-common'; +import { createTokenLimitReachedError } from '../../errors'; +import { tokenCountFromOpenAI, chunkFromOpenAI } from './from_openai'; + +export function processOpenAIStream() { + return (source: Observable<string>) => { + return source.pipe( + filter((line) => !!line && line !== '[DONE]'), + map( + (line) => JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } } + ), + tap((line) => { + if ('error' in line) { + throw createInferenceInternalError(line.error.message); + } + if ( + 'choices' in line && + line.choices.length && + line.choices[0].finish_reason === 'length' + ) { + throw createTokenLimitReachedError(); + } + }), + filter((line): line is OpenAI.ChatCompletionChunk => { + return 'object' in line && line.object === 'chat.completion.chunk'; + }), + mergeMap((chunk): Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> => { + const events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> = []; + if (chunk.usage) { + events.push(tokenCountFromOpenAI(chunk.usage)); + } + if (chunk.choices?.length) { + events.push(chunkFromOpenAI(chunk)); + } + return from(events); + }) + ); + }; +}
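A short sketch of the extracted operator in isolation (illustrative, not part of the diff; the inline SSE payload is fabricated): it drops `[DONE]` markers, surfaces API errors and token-limit truncations as inference errors, and fans each parsed chunk out into chunk and token-count events:

import { of } from 'rxjs';
import { processOpenAIStream } from './process_openai_stream';

// Two raw SSE data lines, as emitted by eventSourceStreamIntoObservable.
const lines$ = of(
  '{"id":"1","object":"chat.completion.chunk","model":"gpt-4o","created":0,"choices":[{"index":0,"finish_reason":null,"delta":{"content":"Hello"}}]}',
  '[DONE]'
);

// Logs a single ChatCompletionChunk event with content "Hello"; '[DONE]' is filtered out.
lines$.pipe(processOpenAIStream()).subscribe((event) => console.log(event));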
diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.test.ts new file mode 100644 index 000000000000..978f775c5d3d --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.test.ts @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { MessageRole, ToolChoiceType } from '@kbn/inference-common'; +import { messagesToOpenAI, toolChoiceToOpenAI, toolsToOpenAI } from './to_openai'; + +describe('toolChoiceToOpenAI', () => { + it('returns the right value for tool choice types', () => { + expect(toolChoiceToOpenAI(ToolChoiceType.none)).toEqual('none'); + expect(toolChoiceToOpenAI(ToolChoiceType.auto)).toEqual('auto'); + expect(toolChoiceToOpenAI(ToolChoiceType.required)).toEqual('required'); + }); + + it('returns the right value for undefined', () => { + expect(toolChoiceToOpenAI(undefined)).toBeUndefined(); + }); + + it('returns the right value for named functions', () => { + expect(toolChoiceToOpenAI({ function: 'foo' })).toEqual({ + type: 'function', + function: { name: 'foo' }, + }); + }); +}); + +describe('toolsToOpenAI', () => { + it('converts tools to the expected format', () => { + expect( + toolsToOpenAI({ + myTool: { + description: 'my tool', + schema: { + type: 'object', + description: 'my tool schema', + properties: { + foo: { + type: 'string', + }, + }, + }, + }, + }) + ).toMatchInlineSnapshot(` + Array [ + Object { + "function": Object { + "description": "my tool", + "name": "myTool", + "parameters": Object { + "description": "my tool schema", + "properties": Object { + "foo": Object { + "type": "string", + }, + }, + "type": "object", + }, + }, + "type": "function", + }, + ] + `); + }); +}); + +describe('messagesToOpenAI', () => { + it('converts a user message', () => { + expect( + messagesToOpenAI({ + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }) + ).toEqual([ + { + content: 'question', + role: 'user', + }, + ]); + }); + + it('converts single message and system', () => { + expect( + messagesToOpenAI({ + system: 'system message', + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }) + ).toEqual([ + { + content: 'system message', + role: 'system', + }, + { + content: 'question', + role: 'user', + }, + ]); + }); + + it('converts a tool call', () => { + expect( + messagesToOpenAI({ + messages: [ + { + role: MessageRole.Tool, + name: 'tool', + response: {}, + toolCallId: 'callId', + }, + ], + }) + ).toEqual([ + { + content: '{}', + role: 'tool', + tool_call_id: 'callId', + }, + ]); + }); + + it('converts an assistant message', () => { + expect( + messagesToOpenAI({ + messages: [ + { + role: MessageRole.Assistant, + content: 'response', + }, + ], + }) + ).toEqual([ + { + role: 'assistant', + content: 'response', + }, + ]); + }); + + it('converts an assistant tool call', () => { + expect( + messagesToOpenAI({ + messages: [ + { + role: MessageRole.Assistant, + content: null, + toolCalls: [ + { + toolCallId: 'id', + function: { + name: 'function', + arguments: {}, + }, + }, + ], + }, + ], + }) + ).toEqual([ + { + role: 'assistant', + content: '', + tool_calls: [ + { + function: { + arguments: '{}', + name: 'function', + }, + id: 'id', + type: 'function', + }, + ], + }, + ]); + }); +}); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.ts new file mode 100644 index 000000000000..709b1fd4c6bf --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.ts @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type OpenAI from 'openai'; +import type { + ChatCompletionAssistantMessageParam, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam, +} from 'openai/resources'; +import { Message, MessageRole, ToolOptions } from '@kbn/inference-common'; + +export function toolsToOpenAI( + tools: ToolOptions['tools'] +): OpenAI.ChatCompletionCreateParams['tools'] { + return tools + ? Object.entries(tools).map(([toolName, { description, schema }]) => { + return { + type: 'function', + function: { + name: toolName, + description, + parameters: (schema ?? { + type: 'object' as const, + properties: {}, + }) as unknown as Record<string, unknown>, + }, + }; + }) + : undefined; +} + +export function toolChoiceToOpenAI( + toolChoice: ToolOptions['toolChoice'] +): OpenAI.ChatCompletionCreateParams['tool_choice'] { + return typeof toolChoice === 'string' + ? toolChoice + : toolChoice + ? { + function: { + name: toolChoice.function, + }, + type: 'function' as const, + } + : undefined; +} + +export function messagesToOpenAI({ + system, + messages, +}: { + system?: string; + messages: Message[]; +}): OpenAI.ChatCompletionMessageParam[] { + const systemMessage: ChatCompletionSystemMessageParam | undefined = system + ? { role: 'system', content: system } + : undefined; + + return [ + ...(systemMessage ? [systemMessage] : []), + ...messages.map((message): ChatCompletionMessageParam => { + const role = message.role; + + switch (role) { + case MessageRole.Assistant: + const assistantMessage: ChatCompletionAssistantMessageParam = { + role: 'assistant', + content: message.content ?? '', + tool_calls: message.toolCalls?.map((toolCall) => { + return { + function: { + name: toolCall.function.name, + arguments: + 'arguments' in toolCall.function + ? 
JSON.stringify(toolCall.function.arguments) + : '{}', + }, + id: toolCall.toolCallId, + type: 'function', + }; + }), + }; + return assistantMessage; + + case MessageRole.User: + const userMessage: ChatCompletionUserMessageParam = { + role: 'user', + content: message.content, + }; + return userMessage; + + case MessageRole.Tool: + const toolMessage: ChatCompletionToolMessageParam = { + role: 'tool', + content: JSON.stringify(message.response), + tool_call_id: message.toolCallId, + }; + return toolMessage; + } + }), + ]; +} diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts index 1821b553dd6a..1965d731885a 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts @@ -6,7 +6,7 @@ */ import { actionsClientMock } from '@kbn/actions-plugin/server/mocks'; -import { InferenceConnector, InferenceConnectorType } from '../../../common/connectors'; +import { InferenceConnector, InferenceConnectorType } from '@kbn/inference-common'; import { createInferenceExecutor, type InferenceExecutor } from './inference_executor'; describe('createInferenceExecutor', () => { diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts index c461e6b6cdfb..0849e71ccf97 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.ts @@ -11,7 +11,7 @@ import type { ActionsClient, PluginStartContract as ActionsPluginStart, } from '@kbn/actions-plugin/server'; -import type { InferenceConnector } from '../../../common/connectors'; +import type { InferenceConnector } from '@kbn/inference-common'; import { getConnectorById } from '../../util/get_connector_by_id'; export interface InferenceInvokeOptions { @@ -28,7 +28,7 @@ export type InferenceInvokeResult<Data = unknown> = ActionTypeExecutorResult<Data>; export interface InferenceExecutor { getConnector: () => InferenceConnector; - invoke(params: InferenceInvokeOptions): Promise<InferenceInvokeResult>; + invoke(params: InferenceInvokeOptions): Promise<InferenceInvokeResult<unknown>>; } export const createInferenceExecutor = ({ @@ -40,7 +40,7 @@ }): InferenceExecutor => { return { getConnector: () => connector, - async invoke({ subAction, subActionParams }): Promise<InferenceInvokeResult> { + async invoke({ subAction, subActionParams }): Promise<InferenceInvokeResult<unknown>> { return await actionsClient.execute({ actionId: connector.connectorId, params: { diff --git a/x-pack/platform/plugins/shared/inference/server/inference_client/types.ts b/x-pack/platform/plugins/shared/inference/server/inference_client/types.ts index 193ce83f6d7b..4037eac3fb7c 100644 --- a/x-pack/platform/plugins/shared/inference/server/inference_client/types.ts +++ b/x-pack/platform/plugins/shared/inference/server/inference_client/types.ts @@ -10,8 +10,8 @@ import type { ChatCompleteAPI, BoundOutputAPI, OutputAPI, + InferenceConnector, } from '@kbn/inference-common'; -import type { InferenceConnector } from '../../common/connectors'; /** * An inference client, scoped to a request, that can be used to interact with LLMs. 
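Putting the executor and adapter plumbing together, a hedged wiring sketch (illustrative, not part of the diff; the import path and connector values are assumed):

import { InferenceConnectorType, type InferenceConnector } from '@kbn/inference-common';
import { getInferenceAdapter } from '../chat_complete/adapters/get_inference_adapter';

// e.g. as returned by getConnectorById (hypothetical values)
const connector: InferenceConnector = {
  type: InferenceConnectorType.Inference,
  name: 'Elastic Inference Service',
  connectorId: 'my-connector-id',
};

// Resolves to inferenceAdapter for the '.inference' type, undefined for unknown types.
const adapter = getInferenceAdapter(connector.type);
// An executor built with createInferenceExecutor({ connector, actionsClient }) would then
// be handed to adapter?.chatComplete({ executor, messages, logger, ... }).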
diff --git a/x-pack/platform/plugins/shared/inference/server/routes/connectors.ts b/x-pack/platform/plugins/shared/inference/server/routes/connectors.ts index 240e11a37f20..d28dfc6780af 100644 --- a/x-pack/platform/plugins/shared/inference/server/routes/connectors.ts +++ b/x-pack/platform/plugins/shared/inference/server/routes/connectors.ts @@ -10,7 +10,7 @@ import { InferenceConnector, InferenceConnectorType, isSupportedConnectorType, -} from '../../common/connectors'; +} from '@kbn/inference-common'; import type { InferenceServerStart, InferenceStartDependencies } from '../types'; export function registerConnectorsRoute({ diff --git a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts index af7f35115325..2ef7d05bdbd5 100644 --- a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts +++ b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { InferenceConnector, InferenceConnectorType } from '../../common/connectors'; +import { InferenceConnector, InferenceConnectorType } from '@kbn/inference-common'; export const createInferenceConnectorMock = ( parts: Partial<InferenceConnector> = {} diff --git a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts index 64b5100a9db3..9203f5eacf0d 100644 --- a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts +++ b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_executor.ts @@ -5,7 +5,7 @@ * 2.0. */ -import type { InferenceConnector } from '../../common/connectors'; +import type { InferenceConnector } from '@kbn/inference-common'; import { InferenceExecutor } from '../chat_complete/utils'; import { createInferenceConnectorMock } from './inference_connector'; diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts index 7387944950f4..17b5cbe86d7f 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts @@ -7,7 +7,7 @@ import type { ActionResult as ActionConnector } from '@kbn/actions-plugin/server'; import { actionsClientMock } from '@kbn/actions-plugin/server/mocks'; -import { InferenceConnectorType } from '../../common/connectors'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { getConnectorById } from './get_connector_by_id'; describe('getConnectorById', () => { @@ -68,7 +68,7 @@ describe('getConnectorById', () => { await expect(() => getConnectorById({ actionsClient, connectorId }) ).rejects.toThrowErrorMatchingInlineSnapshot( - `"Type '.tcp-pigeon' not recognized as a supported connector type"` + `"Connector 'tcp-pigeon-3-0' of type '.tcp-pigeon' not recognized as a supported connector"` ); }); diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts index 1dbf9a6f0d75..4bdbff0e1fec 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts @@ -6,8 +6,11 @@ */ import type { ActionsClient, ActionResult as ActionConnector } 
from '@kbn/actions-plugin/server'; -import { createInferenceRequestError } from '@kbn/inference-common'; -import { isSupportedConnectorType, type InferenceConnector } from '../../common/connectors'; +import { + createInferenceRequestError, + isSupportedConnector, + type InferenceConnector, +} from '@kbn/inference-common'; /** * Retrieves a connector given the provided `connectorId` and asserts it's an inference connector @@ -29,11 +32,9 @@ export const getConnectorById = async ({ throw createInferenceRequestError(`No connector found for id '${connectorId}'`, 400); } - const actionTypeId = connector.actionTypeId; - - if (!isSupportedConnectorType(actionTypeId)) { + if (!isSupportedConnector(connector)) { throw createInferenceRequestError( - `Type '${actionTypeId}' not recognized as a supported connector type`, + `Connector '${connector.id}' of type '${connector.actionTypeId}' not recognized as a supported connector`, 400 ); } @@ -41,6 +42,6 @@ export const getConnectorById = async ({ return { connectorId: connector.id, name: connector.name, - type: actionTypeId, + type: connector.actionTypeId, }; }; diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/common/connectors.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/common/connectors.ts deleted file mode 100644 index f176f4009ac8..000000000000 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/common/connectors.ts +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -export enum ObservabilityAIAssistantConnectorType { - Bedrock = '.bedrock', - OpenAI = '.gen-ai', - Gemini = '.gemini', -} - -export function isSupportedConnectorType( - type: string -): type is ObservabilityAIAssistantConnectorType { - return ( - type === ObservabilityAIAssistantConnectorType.Bedrock || - type === ObservabilityAIAssistantConnectorType.OpenAI || - type === ObservabilityAIAssistantConnectorType.Gemini - ); -} diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/common/index.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/common/index.ts index 52afdf95d4a4..0157a6a2b0aa 100644 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/common/index.ts +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/common/index.ts @@ -47,8 +47,6 @@ export { export { concatenateChatCompletionChunks } from './utils/concatenate_chat_completion_chunks'; -export { isSupportedConnectorType } from './connectors'; - export { ShortIdTable } from './utils/short_id_table'; export { KnowledgeBaseType } from './types'; diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/public/index.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/public/index.ts index 76e643c6ae0d..f8ca9709a6e2 100644 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/public/index.ts +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/public/index.ts @@ -62,7 +62,6 @@ export { } from '../common/functions/visualize_esql'; export { - isSupportedConnectorType, FunctionVisibility, MessageRole, KnowledgeBaseEntryRole, diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/routes/connectors/route.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/routes/connectors/route.ts index 80bc877e6f5f..78e713b42e9f 100644 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/routes/connectors/route.ts +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/routes/connectors/route.ts @@ -5,7 +5,7 @@ * 2.0. 
*/ import { FindActionResult } from '@kbn/actions-plugin/server'; -import { isSupportedConnectorType } from '../../../common/connectors'; +import { isSupportedConnector } from '@kbn/inference-common'; import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; const listConnectorsRoute = createObservabilityAIAssistantServerRoute({ @@ -37,8 +37,7 @@ const listConnectorsRoute = createObservabilityAIAssistantServerRoute({ return connectors.filter( (connector) => - availableTypes.includes(connector.actionTypeId) && - isSupportedConnectorType(connector.actionTypeId) + availableTypes.includes(connector.actionTypeId) && isSupportedConnector(connector) ); }, }); diff --git a/x-pack/plugins/stack_connectors/common/inference/schema.ts b/x-pack/plugins/stack_connectors/common/inference/schema.ts index c62e9782bb51..2213efef1d6e 100644 --- a/x-pack/plugins/stack_connectors/common/inference/schema.ts +++ b/x-pack/plugins/stack_connectors/common/inference/schema.ts @@ -26,7 +26,7 @@ export const ChatCompleteParamsSchema = schema.object({ // subset of OpenAI.ChatCompletionMessageParam https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts const AIMessage = schema.object({ role: schema.string(), - content: schema.maybe(schema.string()), + content: schema.maybe(schema.nullable(schema.string())), name: schema.maybe(schema.string()), tool_calls: schema.maybe( schema.arrayOf( diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts index 4aa28d2952db..febec4d27ff5 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.test.ts @@ -60,11 +60,13 @@ describe('InferenceConnector', () => { }); it('uses the completion task_type is supplied', async () => { - const stream = Readable.from([ - `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`, - `data: [DONE]\n\n`, - ]); - mockEsClient.transport.request.mockResolvedValue(stream); + mockEsClient.transport.request.mockResolvedValue({ + body: Readable.from([ + `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`, + `data: [DONE]\n\n`, + ]), + statusCode: 200, + }); const response = await connector.performApiUnifiedCompletion({ body: { messages: [{ content: 'What is Elastic?', role: 'user' }] }, @@ -84,7 +86,7 @@ describe('InferenceConnector', () => { method: 'POST', path: '_inference/completion/test/_unified', }, - { asStream: true } + { asStream: true, meta: true } ); expect(response.choices[0].message.content).toEqual(' you'); }); @@ -264,6 +266,11 @@ describe('InferenceConnector', () => { }); it('the API call is successful with correct request parameters', async () => { + mockEsClient.transport.request.mockResolvedValue({ + body: Readable.from([`data: [DONE]\n\n`]), + statusCode: 200, + }); + await connector.performApiUnifiedCompletionStream({ body: { messages: [{ content: 'Hello world', role: 'user' }] }, }); @@ -282,11 +289,16 @@ describe('InferenceConnector', () => { method: 'POST', path: '_inference/completion/test/_unified', }, - { asStream: true } + { asStream: true, meta: true } ); }); it('signal is 
properly passed to streamApi', async () => { + mockEsClient.transport.request.mockResolvedValue({ + body: Readable.from([`data: [DONE]\n\n`]), + statusCode: 200, + }); + const signal = jest.fn() as unknown as AbortSignal; await connector.performApiUnifiedCompletionStream({ body: { messages: [{ content: 'Hello world', role: 'user' }] }, signal, }); @@ -299,7 +311,7 @@ describe('InferenceConnector', () => { method: 'POST', path: '_inference/completion/test/_unified', }, - { asStream: true } + { asStream: true, meta: true, signal } ); }); @@ -319,7 +331,10 @@ describe('InferenceConnector', () => { `data: {"id":"chatcmpl-AbLKRuRMZCAcMMQdl96KMTUgAfZNg","choices":[{"delta":{"content":" you"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}\n\n`, `data: [DONE]\n\n`, ]); - mockEsClient.transport.request.mockResolvedValue(stream); + mockEsClient.transport.request.mockResolvedValue({ + body: stream, + statusCode: 200, + }); const response = await connector.performApiUnifiedCompletionStream({ body: { messages: [{ content: 'What is Elastic?', role: 'user' }] }, }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts index d6c9af0e1365..63d8904a6af8 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/inference/inference.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { text as streamToString } from 'node:stream/consumers'; import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server'; import { Stream } from 'openai/streaming'; import { Readable } from 'stream'; @@ -181,7 +182,7 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> { * @signal abort signal */ public async performApiUnifiedCompletionStream(params: UnifiedChatCompleteParams) { - return await this.esClient.transport.request( + const response = await this.esClient.transport.request( { method: 'POST', path: `_inference/completion/${this.inferenceId}/_unified`, @@ -189,8 +190,18 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> { }, { asStream: true, + meta: true, + signal: params.signal, } ); + + // errors should be thrown as it will not be a stream response + if (response.statusCode >= 400) { + const error = await streamToString(response.body as unknown as Readable); + throw new Error(error); + } + + return response.body; + } /** diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/scripts/evaluation/kibana_client.ts b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/scripts/evaluation/kibana_client.ts index f3b5ca357231..69f6715da2db 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/scripts/evaluation/kibana_client.ts +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/scripts/evaluation/kibana_client.ts @@ -5,6 +5,7 @@ * 2.0. 
*/ +import { isSupportedConnectorType } from '@kbn/inference-common'; import { BufferFlushEvent, ChatCompletionChunkEvent, @@ -21,11 +22,7 @@ import { import type { ObservabilityAIAssistantScreenContext } from '@kbn/observability-ai-assistant-plugin/common/types'; import type { AssistantScope } from '@kbn/ai-assistant-common'; import { throwSerializedChatCompletionErrors } from '@kbn/observability-ai-assistant-plugin/common/utils/throw_serialized_chat_completion_errors'; -import { - isSupportedConnectorType, - Message, - MessageRole, -} from '@kbn/observability-ai-assistant-plugin/common'; +import { Message, MessageRole } from '@kbn/observability-ai-assistant-plugin/common'; import { streamIntoObservable } from '@kbn/observability-ai-assistant-plugin/server'; import { ToolingLog } from '@kbn/tooling-log'; import axios, { AxiosInstance, AxiosResponse, isAxiosError } from 'axios';
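Finally, a note on the stack connector change above: with `{ meta: true }` the transport returns the full response (status code plus body), so non-2xx replies, which are small JSON payloads rather than SSE streams, can be drained and rethrown. A minimal sketch of that pattern (illustrative; the helper name is invented):

import { text as streamToString } from 'node:stream/consumers';
import type { Readable } from 'stream';

// Shape of an Elasticsearch transport result requested with { asStream: true, meta: true }.
interface StreamedTransportResult {
  statusCode: number;
  body: Readable;
}

async function unwrapStreamedBody(response: StreamedTransportResult): Promise<Readable> {
  if (response.statusCode >= 400) {
    // Error responses are not streams: drain the body and surface it as an error.
    throw new Error(await streamToString(response.body));
  }
  return response.body;
}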