diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f60424da1c1da..8c044f50bfc7d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -499,6 +499,7 @@ x-pack/packages/index-management @elastic/kibana-management x-pack/plugins/index_management @elastic/kibana-management test/plugin_functional/plugins/index_patterns @elastic/kibana-data-discovery x-pack/packages/ml/inference_integration_flyout @elastic/ml-ui +x-pack/plugins/inference @elastic/kibana-core x-pack/packages/kbn-infra-forge @elastic/obs-ux-management-team x-pack/plugins/observability_solution/infra @elastic/obs-ux-logs-team @elastic/obs-ux-infra_services-team x-pack/plugins/ingest_pipelines @elastic/kibana-management @@ -1287,6 +1288,7 @@ x-pack/test/observability_ai_assistant_functional @elastic/obs-ai-assistant /x-pack/test_serverless/**/test_suites/common/saved_objects_management/ @elastic/kibana-core /x-pack/test_serverless/api_integration/test_suites/common/core/ @elastic/kibana-core /x-pack/test_serverless/api_integration/test_suites/**/telemetry/ @elastic/kibana-core +/x-pack/plugins/inference @elastic/kibana-core @elastic/obs-ai-assistant @elastic/security-generative-ai #CC# /src/core/server/csp/ @elastic/kibana-core #CC# /src/plugins/saved_objects/ @elastic/kibana-core #CC# /x-pack/plugins/cloud/ @elastic/kibana-core diff --git a/docs/developer/plugin-list.asciidoc b/docs/developer/plugin-list.asciidoc index 00b2bd8b6a624..72eee773793c6 100644 --- a/docs/developer/plugin-list.asciidoc +++ b/docs/developer/plugin-list.asciidoc @@ -630,6 +630,11 @@ Index Management by running this series of requests in Console: |This service is exposed from the Index Management setup contract and can be used to add content to the indices list and the index details page. +|{kib-repo}blob/{branch}/x-pack/plugins/inference/README.md[inference] +|The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and +external LLM APIs. Its goals are: + + |{kib-repo}blob/{branch}/x-pack/plugins/observability_solution/infra/README.md[infra] |This is the home of the infra plugin, which aims to provide a solution for the infrastructure monitoring use-case within Kibana. diff --git a/package.json b/package.json index 4eb7868550605..655bca42fc19e 100644 --- a/package.json +++ b/package.json @@ -541,6 +541,7 @@ "@kbn/index-management": "link:x-pack/packages/index-management", "@kbn/index-management-plugin": "link:x-pack/plugins/index_management", "@kbn/index-patterns-test-plugin": "link:test/plugin_functional/plugins/index_patterns", + "@kbn/inference-plugin": "link:x-pack/plugins/inference", "@kbn/inference_integration_flyout": "link:x-pack/packages/ml/inference_integration_flyout", "@kbn/infra-forge": "link:x-pack/packages/kbn-infra-forge", "@kbn/infra-plugin": "link:x-pack/plugins/observability_solution/infra", diff --git a/packages/kbn-optimizer/limits.yml b/packages/kbn-optimizer/limits.yml index b044ba8e093f4..8407645aae133 100644 --- a/packages/kbn-optimizer/limits.yml +++ b/packages/kbn-optimizer/limits.yml @@ -79,6 +79,7 @@ pageLoadAssetSize: imageEmbeddable: 12500 indexLifecycleManagement: 107090 indexManagement: 140608 + inference: 20403 infra: 184320 ingestPipelines: 58003 inputControlVis: 172675 diff --git a/tsconfig.base.json b/tsconfig.base.json index 1d8a9b6c83833..477e52bc568eb 100644 --- a/tsconfig.base.json +++ b/tsconfig.base.json @@ -992,6 +992,8 @@ "@kbn/index-patterns-test-plugin/*": ["test/plugin_functional/plugins/index_patterns/*"], "@kbn/inference_integration_flyout": ["x-pack/packages/ml/inference_integration_flyout"], "@kbn/inference_integration_flyout/*": ["x-pack/packages/ml/inference_integration_flyout/*"], + "@kbn/inference-plugin": ["x-pack/plugins/inference"], + "@kbn/inference-plugin/*": ["x-pack/plugins/inference/*"], "@kbn/infra-forge": ["x-pack/packages/kbn-infra-forge"], "@kbn/infra-forge/*": ["x-pack/packages/kbn-infra-forge/*"], "@kbn/infra-plugin": ["x-pack/plugins/observability_solution/infra"], diff --git a/x-pack/.i18nrc.json b/x-pack/.i18nrc.json index d0a336369ed29..7ff0f3e3ef766 100644 --- a/x-pack/.i18nrc.json +++ b/x-pack/.i18nrc.json @@ -54,6 +54,7 @@ "xpack.fleet": "plugins/fleet", "xpack.ingestPipelines": "plugins/ingest_pipelines", "xpack.integrationAssistant": "plugins/integration_assistant", + "xpack.inference": "plugins/inference", "xpack.investigate": "plugins/observability_solution/investigate", "xpack.investigateApp": "plugins/observability_solution/investigate_app", "xpack.kubernetesSecurity": "plugins/kubernetes_security", diff --git a/x-pack/plugins/inference/README.md b/x-pack/plugins/inference/README.md new file mode 100644 index 0000000000000..1807da7f29faa --- /dev/null +++ b/x-pack/plugins/inference/README.md @@ -0,0 +1,100 @@ +# Inference plugin + +The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and +external LLM APIs. Its goals are: + +- Provide a single place for all interactions with large language models and other generative AI adjacent tasks. +- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini +- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall. +- Allow us to move gradually to the \_inference endpoint without disrupting engineers. + +## Architecture and examples + +![CleanShot 2024-07-14 at 14 45 27@2x](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f) + +## Terminology + +The following concepts are commonly used throughout the plugin: + +- **chat completion**: the process in which the LLM generates the next message in the conversation. This is sometimes referred to as inference, text completion, text generation or content generation. +- **tasks**: higher level tasks that, based on its input, use the LLM in conjunction with other services like Elasticsearch to achieve a result. The example in this POC is natural language to ES|QL. +- **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, it allows the consumer of the API to define a schema for structured output instead of plain text, and having the LLM select the most appropriate one. +- **tool call**: when the LLM has chosen a tool (schema) to use for its output, and returns a document that matches the schema, this is referred to as a tool call. + +## Usage examples + +```ts +class MyPlugin { + setup(coreSetup, pluginsSetup) { + const router = coreSetup.http.createRouter(); + + router.post( + { + path: '/internal/my_plugin/do_something', + validate: { + body: schema.object({ + connectorId: schema.string(), + }), + }, + }, + async (context, request, response) => { + const [coreStart, pluginsStart] = await coreSetup.getStartServices(); + + const inferenceClient = pluginsSetup.inference.getClient({ request }); + + const chatComplete$ = inferenceClient.chatComplete({ + connectorId: request.body.connectorId, + system: `Here is my system message`, + messages: [ + { + role: MessageRole.User, + content: 'Do something', + }, + ], + }); + + const message = await lastValueFrom( + chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents()) + ); + + return response.ok({ + body: { + message, + }, + }); + } + ); + } +} +``` + +## Services + +### `chatComplete`: + +`chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported: + +- Normalizing request and response formats from different connector types (e.g. OpenAI, Bedrock, Claude, Elastic Inference Service) +- Tool calling and validation of tool calls +- Emits token count events +- Emits message events, which is the concatenated message based on the response chunks + +### `output` + +`output` is a wrapper around `chatComplete` that is catered towards a single use case: having the LLM output a structured response, based on a schema. It also drops the token count events to simplify usage. + +### Observable event streams + +These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen: + +- Errors are caught and serialized as events sent over the stream (after an error, the stream ends). +- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) +- The client that reads the stream, parses the event source as an Observable, and if it encounters a serialized error, it deserializes it and throws an error in the Observable. + +### Errors + +All known errors are instances, and not extensions, from the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern. + +### Tools + +Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety. This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`). diff --git a/x-pack/plugins/inference/common/chat_complete/errors.ts b/x-pack/plugins/inference/common/chat_complete/errors.ts new file mode 100644 index 0000000000000..8497350d7b49b --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/errors.ts @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { i18n } from '@kbn/i18n'; +import { InferenceTaskError } from '../errors'; +import type { UnvalidatedToolCall } from './tools'; + +export enum ChatCompletionErrorCode { + TokenLimitReachedError = 'tokenLimitReachedError', + ToolNotFoundError = 'toolNotFoundError', + ToolValidationError = 'toolValidationError', +} + +export type ChatCompletionTokenLimitReachedError = InferenceTaskError< + ChatCompletionErrorCode.TokenLimitReachedError, + { + tokenLimit?: number; + tokenCount?: number; + } +>; + +export type ChatCompletionToolNotFoundError = InferenceTaskError< + ChatCompletionErrorCode.ToolNotFoundError, + { + name: string; + } +>; + +export type ChatCompletionToolValidationError = InferenceTaskError< + ChatCompletionErrorCode.ToolValidationError, + { + name?: string; + arguments?: string; + errorsText?: string; + toolCalls?: UnvalidatedToolCall[]; + } +>; + +export function createTokenLimitReachedError( + tokenLimit?: number, + tokenCount?: number +): ChatCompletionTokenLimitReachedError { + return new InferenceTaskError( + ChatCompletionErrorCode.TokenLimitReachedError, + i18n.translate('xpack.inference.chatCompletionError.tokenLimitReachedError', { + defaultMessage: `Token limit reached. Token limit is {tokenLimit}, but the current conversation has {tokenCount} tokens.`, + values: { tokenLimit, tokenCount }, + }), + { tokenLimit, tokenCount } + ); +} + +export function createToolNotFoundError(name: string): ChatCompletionToolNotFoundError { + return new InferenceTaskError( + ChatCompletionErrorCode.ToolNotFoundError, + `Tool ${name} called but was not available`, + { + name, + } + ); +} + +export function createToolValidationError( + message: string, + meta: { + name?: string; + arguments?: string; + errorsText?: string; + toolCalls?: UnvalidatedToolCall[]; + } +): ChatCompletionToolValidationError { + return new InferenceTaskError(ChatCompletionErrorCode.ToolValidationError, message, meta); +} + +export function isToolValidationError(error?: Error): error is ChatCompletionToolValidationError { + return ( + error instanceof InferenceTaskError && + error.code === ChatCompletionErrorCode.ToolValidationError + ); +} + +export function isTokenLimitReachedError( + error: Error +): error is ChatCompletionTokenLimitReachedError { + return ( + error instanceof InferenceTaskError && + error.code === ChatCompletionErrorCode.TokenLimitReachedError + ); +} + +export function isToolNotFoundError(error: Error): error is ChatCompletionToolNotFoundError { + return ( + error instanceof InferenceTaskError && error.code === ChatCompletionErrorCode.ToolNotFoundError + ); +} diff --git a/x-pack/plugins/inference/common/chat_complete/index.ts b/x-pack/plugins/inference/common/chat_complete/index.ts new file mode 100644 index 0000000000000..175f86f74b5c4 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/index.ts @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { Observable } from 'rxjs'; +import type { InferenceTaskEventBase } from '../tasks'; +import type { ToolCall, ToolCallsOf, ToolOptions } from './tools'; + +export enum MessageRole { + User = 'user', + Assistant = 'assistant', + Tool = 'tool', +} + +interface MessageBase { + role: TRole; +} + +export type UserMessage = MessageBase & { content: string }; + +export type AssistantMessage = MessageBase & { + content: string | null; + toolCalls?: Array | undefined>>; +}; + +export type ToolMessage | unknown> = + MessageBase & { + toolCallId: string; + response: TToolResponse; + }; + +export type Message = UserMessage | AssistantMessage | ToolMessage; + +export type ChatCompletionMessageEvent = + InferenceTaskEventBase & { + content: string; + } & { toolCalls: ToolCallsOf['toolCalls'] }; + +export type ChatCompletionResponse = Observable< + ChatCompletionEvent +>; + +export enum ChatCompletionEventType { + ChatCompletionChunk = 'chatCompletionChunk', + ChatCompletionTokenCount = 'chatCompletionTokenCount', + ChatCompletionMessage = 'chatCompletionMessage', +} + +export interface ChatCompletionChunkToolCall { + index: number; + toolCallId: string; + function: { + name: string; + arguments: string; + }; +} + +export type ChatCompletionChunkEvent = + InferenceTaskEventBase & { + content: string; + tool_calls: ChatCompletionChunkToolCall[]; + }; + +export type ChatCompletionTokenCountEvent = + InferenceTaskEventBase & { + tokens: { + prompt: number; + completion: number; + total: number; + }; + }; + +export type ChatCompletionEvent = + | ChatCompletionChunkEvent + | ChatCompletionTokenCountEvent + | ChatCompletionMessageEvent; + +/** + * Request a completion from the LLM based on a prompt or conversation. + * + * @param {string} options.connectorId The ID of the connector to use + * @param {string} [options.system] A system message that defines the behavior of the LLM. + * @param {Message[]} options.message A list of messages that make up the conversation to be completed. + * @param {ToolChoice} [options.toolChoice] Force the LLM to call a (specific) tool, or no tool + * @param {Record} [options.tools] A map of tools that can be called by the LLM + */ +export type ChatCompleteAPI = ( + options: { + connectorId: string; + system?: string; + messages: Message[]; + } & TToolOptions +) => ChatCompletionResponse; diff --git a/x-pack/plugins/inference/common/chat_complete/request.ts b/x-pack/plugins/inference/common/chat_complete/request.ts new file mode 100644 index 0000000000000..104d1856c9c80 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/request.ts @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { Message } from '.'; +import { ToolOptions } from './tools'; + +export type ChatCompleteRequestBody = { + connectorId: string; + stream?: boolean; + system?: string; + messages: Message[]; +} & ToolOptions; diff --git a/x-pack/plugins/inference/common/chat_complete/tool_schema.ts b/x-pack/plugins/inference/common/chat_complete/tool_schema.ts new file mode 100644 index 0000000000000..5ca3e0ab57a49 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/tool_schema.ts @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Required, ValuesType, UnionToIntersection } from 'utility-types'; + +interface ToolSchemaFragmentBase { + description?: string; +} + +interface ToolSchemaTypeObject extends ToolSchemaFragmentBase { + type: 'object'; + properties: Record; + required?: string[] | readonly string[]; +} + +interface ToolSchemaTypeString extends ToolSchemaFragmentBase { + type: 'string'; + const?: string; + enum?: string[] | readonly string[]; +} + +interface ToolSchemaTypeBoolean extends ToolSchemaFragmentBase { + type: 'boolean'; + const?: string; + enum?: string[] | readonly string[]; +} + +interface ToolSchemaTypeNumber extends ToolSchemaFragmentBase { + type: 'number'; + const?: string; + enum?: string[] | readonly string[]; +} + +interface ToolSchemaAnyOf extends ToolSchemaFragmentBase { + anyOf: ToolSchemaType[]; +} + +interface ToolSchemaAllOf extends ToolSchemaFragmentBase { + allOf: ToolSchemaType[]; +} + +interface ToolSchemaTypeArray extends ToolSchemaFragmentBase { + type: 'array'; + items: Exclude; +} + +type ToolSchemaType = + | ToolSchemaTypeObject + | ToolSchemaTypeString + | ToolSchemaTypeBoolean + | ToolSchemaTypeNumber + | ToolSchemaTypeArray; + +type ToolSchemaFragment = ToolSchemaType | ToolSchemaAnyOf | ToolSchemaAllOf; + +type FromToolSchemaObject = Required< + { + [key in keyof TToolSchemaObject['properties']]?: FromToolSchema< + TToolSchemaObject['properties'][key] + >; + }, + TToolSchemaObject['required'] extends string[] | readonly string[] + ? ValuesType + : never +>; + +type FromToolSchemaArray = Array< + FromToolSchema +>; + +type FromToolSchemaString = + TToolSchemaString extends { const: string } + ? TToolSchemaString['const'] + : TToolSchemaString extends { enum: string[] } | { enum: readonly string[] } + ? ValuesType + : string; + +type FromToolSchemaAnyOf = FromToolSchema< + ValuesType +>; + +type FromToolSchemaAllOf = UnionToIntersection< + FromToolSchema> +>; + +export type ToolSchema = ToolSchemaTypeObject; + +export type FromToolSchema = + TToolSchema extends ToolSchemaTypeObject + ? FromToolSchemaObject + : TToolSchema extends ToolSchemaTypeArray + ? FromToolSchemaArray + : TToolSchema extends ToolSchemaTypeBoolean + ? boolean + : TToolSchema extends ToolSchemaTypeNumber + ? number + : TToolSchema extends ToolSchemaTypeString + ? FromToolSchemaString + : TToolSchema extends ToolSchemaAnyOf + ? FromToolSchemaAnyOf + : TToolSchema extends ToolSchemaAllOf + ? FromToolSchemaAllOf + : never; diff --git a/x-pack/plugins/inference/common/chat_complete/tools.ts b/x-pack/plugins/inference/common/chat_complete/tools.ts new file mode 100644 index 0000000000000..85fb4cd9d7020 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/tools.ts @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { ValuesType } from 'utility-types'; +import { FromToolSchema, ToolSchema } from './tool_schema'; + +type Assert = TValue extends TType ? TValue & TType : never; + +interface CustomToolChoice { + function: TName; +} + +type ToolsOfChoice = TToolOptions['toolChoice'] extends { + function: infer TToolName; +} + ? TToolName extends keyof TToolOptions['tools'] + ? Pick + : TToolOptions['tools'] + : TToolOptions['tools']; + +type ToolResponsesOf | undefined> = + TTools extends Record + ? Array< + ValuesType<{ + [TName in keyof TTools]: ToolResponseOf, TTools[TName]>; + }> + > + : never[]; + +type ToolResponseOf = ToolCall< + TName, + TToolDefinition extends { schema: ToolSchema } ? FromToolSchema : {} +>; + +export type ToolChoice = ToolChoiceType | CustomToolChoice; + +export interface ToolDefinition { + description: string; + schema?: ToolSchema; +} + +export type ToolCallsOf = TToolOptions extends { + tools?: Record; +} + ? TToolOptions extends { toolChoice: ToolChoiceType.none } + ? { toolCalls: [] } + : { + toolCalls: ToolResponsesOf< + Assert, Record | undefined> + >; + } + : { toolCalls: never[] }; + +export enum ToolChoiceType { + none = 'none', + auto = 'auto', + required = 'required', +} + +export interface UnvalidatedToolCall { + toolCallId: string; + function: { + name: string; + arguments: string; + }; +} + +export interface ToolCall< + TName extends string = string, + TArguments extends Record | undefined = undefined +> { + toolCallId: string; + function: { + name: TName; + } & (TArguments extends Record ? { arguments: TArguments } : {}); +} + +export interface ToolOptions { + toolChoice?: ToolChoice; + tools?: Record; +} diff --git a/x-pack/plugins/inference/common/chat_complete/without_chunk_events.ts b/x-pack/plugins/inference/common/chat_complete/without_chunk_events.ts new file mode 100644 index 0000000000000..58e72e2c90903 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/without_chunk_events.ts @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { filter, OperatorFunction } from 'rxjs'; +import { ChatCompletionChunkEvent, ChatCompletionEvent, ChatCompletionEventType } from '.'; + +export function withoutChunkEvents(): OperatorFunction< + T, + Exclude +> { + return filter( + (event): event is Exclude => + event.type !== ChatCompletionEventType.ChatCompletionChunk + ); +} diff --git a/x-pack/plugins/inference/common/chat_complete/without_token_count_events.ts b/x-pack/plugins/inference/common/chat_complete/without_token_count_events.ts new file mode 100644 index 0000000000000..1b7dbdb9c1372 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/without_token_count_events.ts @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { filter, OperatorFunction } from 'rxjs'; +import { ChatCompletionEvent, ChatCompletionEventType, ChatCompletionTokenCountEvent } from '.'; + +export function withoutTokenCountEvents(): OperatorFunction< + T, + Exclude +> { + return filter( + (event): event is Exclude => + event.type !== ChatCompletionEventType.ChatCompletionTokenCount + ); +} diff --git a/x-pack/plugins/inference/common/connectors.ts b/x-pack/plugins/inference/common/connectors.ts new file mode 100644 index 0000000000000..82baea2f83c39 --- /dev/null +++ b/x-pack/plugins/inference/common/connectors.ts @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export enum InferenceConnectorType { + OpenAI = '.gen-ai', + Bedrock = '.bedrock', + Gemini = '.gemini', +} + +export interface InferenceConnector { + type: InferenceConnectorType; + name: string; + connectorId: string; +} + +export function isSupportedConnectorType(id: string): id is InferenceConnectorType { + return ( + id === InferenceConnectorType.OpenAI || + id === InferenceConnectorType.Bedrock || + id === InferenceConnectorType.Gemini + ); +} diff --git a/x-pack/plugins/inference/common/errors.ts b/x-pack/plugins/inference/common/errors.ts new file mode 100644 index 0000000000000..fa063e1669936 --- /dev/null +++ b/x-pack/plugins/inference/common/errors.ts @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { i18n } from '@kbn/i18n'; +import { InferenceTaskEventBase, InferenceTaskEventType } from './tasks'; + +export enum InferenceTaskErrorCode { + internalError = 'internalError', + requestError = 'requestError', +} + +export class InferenceTaskError< + TCode extends string, + TMeta extends Record | undefined +> extends Error { + constructor(public code: TCode, message: string, public meta: TMeta) { + super(message); + } + + toJSON(): InferenceTaskErrorEvent { + return { + type: InferenceTaskEventType.error, + error: { + code: this.code, + message: this.message, + meta: this.meta, + }, + }; + } +} + +export type InferenceTaskErrorEvent = InferenceTaskEventBase & { + error: { + code: string; + message: string; + meta?: Record; + }; +}; + +export type InferenceTaskInternalError = InferenceTaskError< + InferenceTaskErrorCode.internalError, + {} +>; + +export type InferenceTaskRequestError = InferenceTaskError< + InferenceTaskErrorCode.requestError, + { status: number } +>; + +export function createInferenceInternalError( + message: string = i18n.translate('xpack.inference.internalError', { + defaultMessage: 'An internal error occurred', + }) +): InferenceTaskInternalError { + return new InferenceTaskError(InferenceTaskErrorCode.internalError, message, {}); +} + +export function createInferenceRequestError( + message: string, + status: number +): InferenceTaskRequestError { + return new InferenceTaskError(InferenceTaskErrorCode.requestError, message, { + status, + }); +} + +export function isInferenceError( + error: unknown +): error is InferenceTaskError | undefined> { + return error instanceof InferenceTaskError; +} + +export function isInferenceInternalError(error: unknown): error is InferenceTaskInternalError { + return isInferenceError(error) && error.code === InferenceTaskErrorCode.internalError; +} + +export function isInferenceRequestError(error: unknown): error is InferenceTaskRequestError { + return isInferenceError(error) && error.code === InferenceTaskErrorCode.requestError; +} diff --git a/x-pack/plugins/inference/common/output/create_output_api.ts b/x-pack/plugins/inference/common/output/create_output_api.ts new file mode 100644 index 0000000000000..9842f9635dea8 --- /dev/null +++ b/x-pack/plugins/inference/common/output/create_output_api.ts @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { map } from 'rxjs'; +import { ChatCompleteAPI, ChatCompletionEventType, MessageRole } from '../chat_complete'; +import { withoutTokenCountEvents } from '../chat_complete/without_token_count_events'; +import { OutputAPI, OutputEvent, OutputEventType } from '.'; + +export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI { + return (id, { connectorId, input, schema, system }) => { + return chatCompleteApi({ + connectorId, + system, + messages: [ + { + role: MessageRole.User, + content: input, + }, + ], + ...(schema + ? { + tools: { output: { description: `Output your response in the this format`, schema } }, + toolChoice: { function: 'output' }, + } + : {}), + }).pipe( + withoutTokenCountEvents(), + map((event): OutputEvent => { + if (event.type === ChatCompletionEventType.ChatCompletionChunk) { + return { + type: OutputEventType.OutputUpdate, + id, + content: event.content, + }; + } + return { + id, + type: OutputEventType.OutputComplete, + output: event.toolCalls[0].function.arguments, + }; + }) + ); + }; +} diff --git a/x-pack/plugins/inference/common/output/index.ts b/x-pack/plugins/inference/common/output/index.ts new file mode 100644 index 0000000000000..69df7fb3ecd9d --- /dev/null +++ b/x-pack/plugins/inference/common/output/index.ts @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Observable } from 'rxjs'; +import { FromToolSchema, ToolSchema } from '../chat_complete/tool_schema'; +import { InferenceTaskEventBase } from '../tasks'; + +export enum OutputEventType { + OutputUpdate = 'output', + OutputComplete = 'complete', +} + +type Output = Record | undefined; + +export type OutputUpdateEvent = + InferenceTaskEventBase & { + id: TId; + content: string; + }; + +export type OutputCompleteEvent< + TId extends string = string, + TOutput extends Output = Output +> = InferenceTaskEventBase & { + id: TId; + output: TOutput; +}; + +export type OutputEvent = + | OutputUpdateEvent + | OutputCompleteEvent; + +/** + * Generate a response with the LLM for a prompt, optionally based on a schema. + * + * @param {string} id The id of the operation + * @param {string} options.connectorId The ID of the connector that is to be used. + * @param {string} options.input The prompt for the LLM + * @param {ToolSchema} [options.schema] The schema the response from the LLM should adhere to. + */ +export type OutputAPI = < + TId extends string = string, + TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined +>( + id: TId, + options: { + connectorId: string; + system?: string; + input: string; + schema?: TOutputSchema; + } +) => Observable< + OutputEvent : undefined> +>; + +export function createOutputCompleteEvent( + id: TId, + output: TOutput +): OutputCompleteEvent { + return { + id, + type: OutputEventType.OutputComplete, + output, + }; +} diff --git a/x-pack/plugins/inference/common/output/without_output_update_events.ts b/x-pack/plugins/inference/common/output/without_output_update_events.ts new file mode 100644 index 0000000000000..38f26c8c8ece1 --- /dev/null +++ b/x-pack/plugins/inference/common/output/without_output_update_events.ts @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { filter, OperatorFunction } from 'rxjs'; +import { OutputEvent, OutputEventType, OutputUpdateEvent } from '.'; + +export function withoutOutputUpdateEvents(): OperatorFunction< + T, + Exclude +> { + return filter( + (event): event is Exclude => event.type !== OutputEventType.OutputUpdate + ); +} diff --git a/x-pack/plugins/inference/common/tasks.ts b/x-pack/plugins/inference/common/tasks.ts new file mode 100644 index 0000000000000..7b8f65b7af2c9 --- /dev/null +++ b/x-pack/plugins/inference/common/tasks.ts @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export interface InferenceTaskEventBase { + type: TEventType; +} + +export enum InferenceTaskEventType { + error = 'error', +} + +export type InferenceTaskEvent = InferenceTaskEventBase; diff --git a/x-pack/plugins/inference/jest.config.js b/x-pack/plugins/inference/jest.config.js new file mode 100644 index 0000000000000..3bc2142bcdfc3 --- /dev/null +++ b/x-pack/plugins/inference/jest.config.js @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +module.exports = { + preset: '@kbn/test', + rootDir: '../../..', + roots: ['/x-pack/plugins/inference/public', '/x-pack/plugins/inference/server'], + setupFiles: [], + collectCoverage: true, + collectCoverageFrom: [ + '/x-pack/plugins/inference/{public,server,common}/**/*.{js,ts,tsx}', + ], + + coverageReporters: ['html'], +}; diff --git a/x-pack/plugins/inference/kibana.jsonc b/x-pack/plugins/inference/kibana.jsonc new file mode 100644 index 0000000000000..c52b194be7dc7 --- /dev/null +++ b/x-pack/plugins/inference/kibana.jsonc @@ -0,0 +1,18 @@ +{ + "type": "plugin", + "id": "@kbn/inference-plugin", + "owner": "@elastic/kibana-core", + "plugin": { + "id": "inference", + "server": true, + "browser": true, + "configPath": ["xpack", "inference"], + "requiredPlugins": [ + "actions" + ], + "requiredBundles": [ + ], + "optionalPlugins": [], + "extraPublicDirs": [] + } +} diff --git a/x-pack/plugins/inference/public/chat_complete/index.ts b/x-pack/plugins/inference/public/chat_complete/index.ts new file mode 100644 index 0000000000000..2509ea2dc1222 --- /dev/null +++ b/x-pack/plugins/inference/public/chat_complete/index.ts @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { HttpStart } from '@kbn/core/public'; +import { from } from 'rxjs'; +import { ChatCompleteAPI } from '../../common/chat_complete'; +import type { ChatCompleteRequestBody } from '../../common/chat_complete/request'; +import { httpResponseIntoObservable } from '../util/http_response_into_observable'; + +export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI { + return ({ connectorId, messages, system, toolChoice, tools }) => { + const body: ChatCompleteRequestBody = { + connectorId, + system, + messages, + toolChoice, + tools, + }; + + return from( + http.post('/internal/inference/chat_complete', { + asResponse: true, + rawResponse: true, + body: JSON.stringify(body), + }) + ).pipe(httpResponseIntoObservable()); + }; +} diff --git a/x-pack/plugins/inference/public/index.ts b/x-pack/plugins/inference/public/index.ts new file mode 100644 index 0000000000000..82d36a7abe82d --- /dev/null +++ b/x-pack/plugins/inference/public/index.ts @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/public'; + +import { InferencePlugin } from './plugin'; +import type { + InferencePublicSetup, + InferencePublicStart, + InferenceSetupDependencies, + InferenceStartDependencies, + ConfigSchema, +} from './types'; + +export { httpResponseIntoObservable } from './util/http_response_into_observable'; + +export type { InferencePublicSetup, InferencePublicStart }; + +export const plugin: PluginInitializer< + InferencePublicSetup, + InferencePublicStart, + InferenceSetupDependencies, + InferenceStartDependencies +> = (pluginInitializerContext: PluginInitializerContext) => + new InferencePlugin(pluginInitializerContext); diff --git a/x-pack/plugins/inference/public/plugin.tsx b/x-pack/plugins/inference/public/plugin.tsx new file mode 100644 index 0000000000000..9785efb7a8874 --- /dev/null +++ b/x-pack/plugins/inference/public/plugin.tsx @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public'; +import type { Logger } from '@kbn/logging'; +import { createOutputApi } from '../common/output/create_output_api'; +import { createChatCompleteApi } from './chat_complete'; +import type { + ConfigSchema, + InferencePublicSetup, + InferencePublicStart, + InferenceSetupDependencies, + InferenceStartDependencies, +} from './types'; + +export class InferencePlugin + implements + Plugin< + InferencePublicSetup, + InferencePublicStart, + InferenceSetupDependencies, + InferenceStartDependencies + > +{ + logger: Logger; + + constructor(context: PluginInitializerContext) { + this.logger = context.logger.get(); + } + setup( + coreSetup: CoreSetup, + pluginsSetup: InferenceSetupDependencies + ): InferencePublicSetup { + return {}; + } + + start(coreStart: CoreStart, pluginsStart: InferenceStartDependencies): InferencePublicStart { + const chatComplete = createChatCompleteApi({ http: coreStart.http }); + return { + chatComplete, + output: createOutputApi(chatComplete), + getConnectors: () => { + return coreStart.http.get('/internal/inference/connectors'); + }, + }; + } +} diff --git a/x-pack/plugins/inference/public/types.ts b/x-pack/plugins/inference/public/types.ts new file mode 100644 index 0000000000000..df80256679ab4 --- /dev/null +++ b/x-pack/plugins/inference/public/types.ts @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { ChatCompleteAPI } from '../common/chat_complete'; +import type { InferenceConnector } from '../common/connectors'; +import type { OutputAPI } from '../common/output'; + +/* eslint-disable @typescript-eslint/no-empty-interface*/ + +export interface ConfigSchema {} + +export interface InferenceSetupDependencies {} + +export interface InferenceStartDependencies {} + +export interface InferencePublicSetup {} + +export interface InferencePublicStart { + chatComplete: ChatCompleteAPI; + output: OutputAPI; + getConnectors: () => Promise; +} diff --git a/x-pack/plugins/inference/public/util/create_observable_from_http_response.ts b/x-pack/plugins/inference/public/util/create_observable_from_http_response.ts new file mode 100644 index 0000000000000..09e9b9b2d5f5e --- /dev/null +++ b/x-pack/plugins/inference/public/util/create_observable_from_http_response.ts @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { createParser } from 'eventsource-parser'; +import { Observable, throwError } from 'rxjs'; +import { createInferenceInternalError } from '../../common/errors'; + +export interface StreamedHttpResponse { + response?: { body: ReadableStream | null | undefined }; +} + +export function createObservableFromHttpResponse( + response: StreamedHttpResponse +): Observable { + const rawResponse = response.response; + + const body = rawResponse?.body; + if (!body) { + return throwError(() => { + throw createInferenceInternalError(`No readable stream found in response`); + }); + } + + return new Observable((subscriber) => { + const parser = createParser((event) => { + if (event.type === 'event') { + subscriber.next(event.data); + } + }); + + const readStream = async () => { + const reader = body.getReader(); + const decoder = new TextDecoder(); + + // Function to process each chunk + const processChunk = ({ + done, + value, + }: ReadableStreamReadResult): Promise => { + if (done) { + return Promise.resolve(); + } + + parser.feed(decoder.decode(value, { stream: true })); + + return reader.read().then(processChunk); + }; + + // Start reading the stream + return reader.read().then(processChunk); + }; + + readStream() + .then(() => { + subscriber.complete(); + }) + .catch((error) => { + subscriber.error(error); + }); + }); +} diff --git a/x-pack/plugins/inference/public/util/http_response_into_observable.test.ts b/x-pack/plugins/inference/public/util/http_response_into_observable.test.ts new file mode 100644 index 0000000000000..f50ec402bdd77 --- /dev/null +++ b/x-pack/plugins/inference/public/util/http_response_into_observable.test.ts @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { lastValueFrom, of, toArray } from 'rxjs'; +import { httpResponseIntoObservable } from './http_response_into_observable'; +import type { StreamedHttpResponse } from './create_observable_from_http_response'; +import { ChatCompletionEventType } from '../../common/chat_complete'; +import { InferenceTaskEventType } from '../../common/tasks'; +import { InferenceTaskErrorCode } from '../../common/errors'; + +function toSse(...events: Array>) { + return events.map((event) => new TextEncoder().encode(`data: ${JSON.stringify(event)}\n\n`)); +} + +describe('httpResponseIntoObservable', () => { + it('parses SSE output', async () => { + const events = [ + { + type: ChatCompletionEventType.ChatCompletionChunk, + content: 'Hello', + }, + { + type: ChatCompletionEventType.ChatCompletionChunk, + content: 'Hello again', + }, + ]; + + const messages = await lastValueFrom( + of({ + response: { + // @ts-expect-error + body: ReadableStream.from(toSse(...events)), + }, + }).pipe(httpResponseIntoObservable(), toArray()) + ); + + expect(messages).toEqual(events); + }); + + it('throws serialized errors', async () => { + const events = [ + { + type: InferenceTaskEventType.error, + error: { + code: InferenceTaskErrorCode.internalError, + message: 'Internal error', + }, + }, + ]; + + await expect(async () => { + await lastValueFrom( + of({ + response: { + // @ts-expect-error + body: ReadableStream.from(toSse(...events)), + }, + }).pipe(httpResponseIntoObservable(), toArray()) + ); + }).rejects.toThrowError(`Internal error`); + }); +}); diff --git a/x-pack/plugins/inference/public/util/http_response_into_observable.ts b/x-pack/plugins/inference/public/util/http_response_into_observable.ts new file mode 100644 index 0000000000000..5b0929762e25d --- /dev/null +++ b/x-pack/plugins/inference/public/util/http_response_into_observable.ts @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { map, OperatorFunction, pipe, switchMap, tap } from 'rxjs'; +import { InferenceTaskEvent, InferenceTaskEventType } from '../../common/tasks'; +import { + createObservableFromHttpResponse, + StreamedHttpResponse, +} from './create_observable_from_http_response'; +import { + createInferenceInternalError, + InferenceTaskError, + InferenceTaskErrorEvent, +} from '../../common/errors'; + +export function httpResponseIntoObservable< + T extends InferenceTaskEvent = never +>(): OperatorFunction { + return pipe( + switchMap((response) => createObservableFromHttpResponse(response)), + map((line): T => { + try { + return JSON.parse(line); + } catch (error) { + throw createInferenceInternalError(`Failed to parse JSON`); + } + }), + tap((event) => { + if (event.type === InferenceTaskEventType.error) { + const errorEvent = event as unknown as InferenceTaskErrorEvent; + throw new InferenceTaskError( + errorEvent.error.code, + errorEvent.error.message, + errorEvent.error.meta + ); + } + }) + ); +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/chunks_into_message.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/chunks_into_message.test.ts new file mode 100644 index 0000000000000..a3fdcd33bee94 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/chunks_into_message.test.ts @@ -0,0 +1,267 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { lastValueFrom, of } from 'rxjs'; +import { + ChatCompletionChunkEvent, + ChatCompletionEventType, + ChatCompletionTokenCountEvent, +} from '../../../common/chat_complete'; +import { ToolChoiceType } from '../../../common/chat_complete/tools'; +import { chunksIntoMessage } from './chunks_into_message'; + +describe('chunksIntoMessage', () => { + function fromEvents(...events: Array) { + return of(...events); + } + + it('concatenates content chunks into a single message', async () => { + const message = await lastValueFrom( + chunksIntoMessage({})( + fromEvents( + { + content: 'Hey', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [], + }, + { + content: ' how is it', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [], + }, + { + content: ' going', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [], + } + ) + ) + ); + + expect(message).toEqual({ + content: 'Hey how is it going', + toolCalls: [], + type: ChatCompletionEventType.ChatCompletionMessage, + }); + }); + + it('parses tool calls', async () => { + const message = await lastValueFrom( + chunksIntoMessage({ + toolChoice: ToolChoiceType.auto, + tools: { + myFunction: { + description: 'myFunction', + schema: { + type: 'object', + properties: { + foo: { + type: 'string', + const: 'bar', + }, + }, + }, + }, + }, + })( + fromEvents( + { + content: '', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [ + { + function: { + name: 'myFunction', + arguments: '', + }, + index: 0, + toolCallId: '0', + }, + ], + }, + { + content: '', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [ + { + function: { + name: '', + arguments: '{', + }, + index: 0, + toolCallId: '0', + }, + ], + }, + { + content: '', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [ + { + function: { + name: '', + arguments: '"foo": "bar" }', + }, + index: 0, + toolCallId: '1', + }, + ], + } + ) + ) + ); + + expect(message).toEqual({ + content: '', + toolCalls: [ + { + function: { + name: 'myFunction', + arguments: { + foo: 'bar', + }, + }, + toolCallId: '001', + }, + ], + type: ChatCompletionEventType.ChatCompletionMessage, + }); + }); + + it('validates tool calls', async () => { + async function getMessage() { + return await lastValueFrom( + chunksIntoMessage({ + toolChoice: ToolChoiceType.auto, + tools: { + myFunction: { + description: 'myFunction', + schema: { + type: 'object', + properties: { + foo: { + type: 'string', + const: 'bar', + }, + }, + }, + }, + }, + })( + fromEvents({ + content: '', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [ + { + function: { + name: 'myFunction', + arguments: '{ "foo": "baz" }', + }, + index: 0, + toolCallId: '001', + }, + ], + }) + ) + ); + } + + await expect(async () => getMessage()).rejects.toThrowErrorMatchingInlineSnapshot( + `"Tool call arguments for myFunction were invalid"` + ); + }); + + it('concatenates multiple tool calls into a single message', async () => { + const message = await lastValueFrom( + chunksIntoMessage({ + toolChoice: ToolChoiceType.auto, + tools: { + myFunction: { + description: 'myFunction', + schema: { + type: 'object', + properties: { + foo: { + type: 'string', + }, + }, + }, + }, + }, + })( + fromEvents( + { + content: '', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [ + { + function: { + name: 'myFunction', + arguments: '', + }, + index: 0, + toolCallId: '001', + }, + ], + }, + { + content: '', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [ + { + function: { + name: '', + arguments: '{"foo": "bar"}', + }, + index: 0, + toolCallId: '', + }, + ], + }, + { + content: '', + type: ChatCompletionEventType.ChatCompletionChunk, + tool_calls: [ + { + function: { + name: 'myFunction', + arguments: '{ "foo": "baz" }', + }, + index: 1, + toolCallId: '002', + }, + ], + } + ) + ) + ); + + expect(message).toEqual({ + content: '', + toolCalls: [ + { + function: { + name: 'myFunction', + arguments: { + foo: 'bar', + }, + }, + toolCallId: '001', + }, + { + function: { + name: 'myFunction', + arguments: { + foo: 'baz', + }, + }, + toolCallId: '002', + }, + ], + type: ChatCompletionEventType.ChatCompletionMessage, + }); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/chunks_into_message.ts b/x-pack/plugins/inference/server/chat_complete/adapters/chunks_into_message.ts new file mode 100644 index 0000000000000..786a4c4ff7fb3 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/chunks_into_message.ts @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { last, map, merge, OperatorFunction, scan, share } from 'rxjs'; +import type { UnvalidatedToolCall, ToolOptions } from '../../../common/chat_complete/tools'; +import { + ChatCompletionChunkEvent, + ChatCompletionEventType, + ChatCompletionMessageEvent, + ChatCompletionTokenCountEvent, +} from '../../../common/chat_complete'; +import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events'; +import { validateToolCalls } from '../../util/validate_tool_calls'; + +export function chunksIntoMessage( + toolOptions: TToolOptions +): OperatorFunction< + ChatCompletionChunkEvent | ChatCompletionTokenCountEvent, + | ChatCompletionChunkEvent + | ChatCompletionTokenCountEvent + | ChatCompletionMessageEvent +> { + return (chunks$) => { + const shared$ = chunks$.pipe(share()); + + return merge( + shared$, + shared$.pipe( + withoutTokenCountEvents(), + scan( + (prev, chunk) => { + prev.content += chunk.content ?? ''; + + chunk.tool_calls?.forEach((toolCall) => { + let prevToolCall = prev.tool_calls[toolCall.index]; + if (!prevToolCall) { + prev.tool_calls[toolCall.index] = { + function: { + name: '', + arguments: '', + }, + toolCallId: '', + }; + + prevToolCall = prev.tool_calls[toolCall.index]; + } + + prevToolCall.function.name += toolCall.function.name; + prevToolCall.function.arguments += toolCall.function.arguments; + prevToolCall.toolCallId += toolCall.toolCallId; + }); + + return prev; + }, + { + content: '', + tool_calls: [] as UnvalidatedToolCall[], + } + ), + last(), + map((concatenatedChunk): ChatCompletionMessageEvent => { + const validatedToolCalls = validateToolCalls({ + ...toolOptions, + toolCalls: concatenatedChunk.tool_calls, + }); + + return { + type: ChatCompletionEventType.ChatCompletionMessage, + content: concatenatedChunk.content, + toolCalls: validatedToolCalls, + }; + }) + ) + ); + }; +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/event_source_stream_into_observable.ts b/x-pack/plugins/inference/server/chat_complete/adapters/event_source_stream_into_observable.ts new file mode 100644 index 0000000000000..ece32d76222cc --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/event_source_stream_into_observable.ts @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { createParser } from 'eventsource-parser'; +import { Readable } from 'node:stream'; +import { Observable } from 'rxjs'; + +export function eventSourceStreamIntoObservable(readable: Readable) { + return new Observable((subscriber) => { + const parser = createParser((event) => { + if (event.type === 'event') { + subscriber.next(event.data); + } + }); + + async function processStream() { + for await (const chunk of readable) { + parser.feed(chunk.toString()); + } + } + + processStream().then( + () => { + subscriber.complete(); + }, + (error) => { + subscriber.error(error); + } + ); + }); +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.test.ts new file mode 100644 index 0000000000000..7f55f8a8faa48 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.test.ts @@ -0,0 +1,382 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import OpenAI from 'openai'; +import { openAIAdapter } from '.'; +import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client'; +import { ChatCompletionEventType, MessageRole } from '../../../../common/chat_complete'; +import { PassThrough } from 'stream'; +import { pick } from 'lodash'; +import { lastValueFrom, Subject, toArray } from 'rxjs'; +import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream'; +import { v4 } from 'uuid'; + +function createOpenAIChunk({ + delta, + usage, +}: { + delta: OpenAI.ChatCompletionChunk['choices'][number]['delta']; + usage?: OpenAI.ChatCompletionChunk['usage']; +}): OpenAI.ChatCompletionChunk { + return { + choices: [ + { + finish_reason: null, + index: 0, + delta, + }, + ], + created: new Date().getTime(), + id: v4(), + model: 'gpt-4o', + object: 'chat.completion.chunk', + usage, + }; +} + +describe('openAIAdapter', () => { + const actionsClientMock = { + execute: jest.fn(), + } as ActionsClient & { execute: jest.MockedFn }; + + beforeEach(() => { + actionsClientMock.execute.mockReset(); + }); + + const defaultArgs = { + connector: { + id: 'foo', + actionTypeId: '.gen-ai', + name: 'OpenAI', + isPreconfigured: false, + isDeprecated: false, + isSystemAction: false, + }, + actionsClient: actionsClientMock, + }; + + describe('when creating the request', () => { + function getRequest() { + const params = actionsClientMock.execute.mock.calls[0][0].params.subActionParams as Record< + string, + any + >; + + return { stream: params.stream, body: JSON.parse(params.body) }; + } + + beforeEach(() => { + actionsClientMock.execute.mockImplementation(async () => { + return { + actionId: '', + status: 'ok', + data: new PassThrough(), + }; + }); + }); + it('correctly formats messages ', () => { + openAIAdapter.chatComplete({ + ...defaultArgs, + system: 'system', + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + { + role: MessageRole.Assistant, + content: 'answer', + }, + { + role: MessageRole.User, + content: 'another question', + }, + ], + }); + + expect(getRequest().body.messages).toEqual([ + { + content: 'system', + role: 'system', + }, + { + content: 'question', + role: 'user', + }, + { + content: 'answer', + role: 'assistant', + }, + { + content: 'another question', + role: 'user', + }, + ]); + }); + + it('correctly formats tools and tool choice', () => { + openAIAdapter.chatComplete({ + ...defaultArgs, + system: 'system', + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + { + role: MessageRole.Assistant, + content: 'answer', + toolCalls: [ + { + function: { + name: 'my_function', + arguments: { + foo: 'bar', + }, + }, + toolCallId: '0', + }, + ], + }, + { + role: MessageRole.Tool, + toolCallId: '0', + response: { + bar: 'foo', + }, + }, + ], + toolChoice: { function: 'myFunction' }, + tools: { + myFunction: { + description: 'myFunction', + }, + myFunctionWithArgs: { + description: 'myFunctionWithArgs', + schema: { + type: 'object', + properties: { + foo: { + type: 'string', + description: 'foo', + }, + }, + required: ['foo'], + }, + }, + }, + }); + + expect(pick(getRequest().body, 'messages', 'tools', 'tool_choice')).toEqual({ + messages: [ + { + content: 'system', + role: 'system', + }, + { + content: 'question', + role: 'user', + }, + { + content: 'answer', + role: 'assistant', + tool_calls: [ + { + function: { + name: 'my_function', + arguments: JSON.stringify({ foo: 'bar' }), + }, + id: '0', + type: 'function', + }, + ], + }, + { + role: 'tool', + tool_call_id: '0', + content: JSON.stringify({ bar: 'foo' }), + }, + ], + tools: [ + { + function: { + name: 'myFunction', + description: 'myFunction', + parameters: { + type: 'object', + properties: {}, + }, + }, + type: 'function', + }, + { + function: { + name: 'myFunctionWithArgs', + description: 'myFunctionWithArgs', + parameters: { + type: 'object', + properties: { + foo: { + type: 'string', + description: 'foo', + }, + }, + required: ['foo'], + }, + }, + type: 'function', + }, + ], + tool_choice: { + function: { + name: 'myFunction', + }, + type: 'function', + }, + }); + }); + + it('always sets streaming to true', () => { + openAIAdapter.chatComplete({ + ...defaultArgs, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }); + + expect(getRequest().stream).toBe(true); + expect(getRequest().body.stream).toBe(true); + }); + }); + + describe('when handling the response', () => { + let source$: Subject>; + + beforeEach(() => { + source$ = new Subject>(); + + actionsClientMock.execute.mockImplementation(async () => { + return { + actionId: '', + status: 'ok', + data: observableIntoEventSourceStream(source$), + }; + }); + }); + + it('emits chunk events', async () => { + const response$ = openAIAdapter.chatComplete({ + ...defaultArgs, + messages: [ + { + role: MessageRole.User, + content: 'Hello', + }, + ], + }); + + source$.next( + createOpenAIChunk({ + delta: { + content: 'First', + }, + }) + ); + + source$.next( + createOpenAIChunk({ + delta: { + content: ', second', + }, + }) + ); + + source$.complete(); + + const allChunks = await lastValueFrom(response$.pipe(toArray())); + + expect(allChunks).toEqual([ + { + content: 'First', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + { + content: ', second', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + ]); + }); + + it('emits token events', async () => { + const response$ = openAIAdapter.chatComplete({ + ...defaultArgs, + messages: [ + { + role: MessageRole.User, + content: 'Hello', + }, + ], + }); + + source$.next( + createOpenAIChunk({ + delta: { + content: 'First', + }, + }) + ); + + source$.next( + createOpenAIChunk({ + delta: { + tool_calls: [ + { + index: 0, + id: '0', + function: { + name: 'my_function', + arguments: '{}', + }, + }, + ], + }, + }) + ); + + source$.complete(); + + const allChunks = await lastValueFrom(response$.pipe(toArray())); + + expect(allChunks).toEqual([ + { + content: 'First', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + { + content: '', + tool_calls: [ + { + function: { + name: 'my_function', + arguments: '{}', + }, + index: 0, + toolCallId: '0', + }, + ], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + ]); + }); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.ts b/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.ts new file mode 100644 index 0000000000000..c811ed9f400ea --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.ts @@ -0,0 +1,182 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import OpenAI from 'openai'; +import type { + ChatCompletionAssistantMessageParam, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam, +} from 'openai/resources'; +import { filter, from, map, switchMap, tap } from 'rxjs'; +import { Readable } from 'stream'; +import { + ChatCompletionChunkEvent, + ChatCompletionEventType, + Message, + MessageRole, +} from '../../../../common/chat_complete'; +import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors'; +import { createInferenceInternalError } from '../../../../common/errors'; +import { InferenceConnectorAdapter } from '../../types'; +import { eventSourceStreamIntoObservable } from '../event_source_stream_into_observable'; + +export const openAIAdapter: InferenceConnectorAdapter = { + chatComplete: ({ connector, actionsClient, system, messages, toolChoice, tools }) => { + const openAIMessages = messagesToOpenAI({ system, messages }); + + const toolChoiceForOpenAI = + typeof toolChoice === 'string' + ? toolChoice + : toolChoice + ? { + function: { + name: toolChoice.function, + }, + type: 'function' as const, + } + : undefined; + + const stream = true; + + const request: Omit & { model?: string } = { + stream, + messages: openAIMessages, + temperature: 0, + tool_choice: toolChoiceForOpenAI, + tools: tools + ? Object.entries(tools).map(([toolName, { description, schema }]) => { + return { + type: 'function', + function: { + name: toolName, + description, + parameters: (schema ?? { + type: 'object' as const, + properties: {}, + }) as unknown as Record, + }, + }; + }) + : undefined, + }; + + return from( + actionsClient.execute({ + actionId: connector.id, + params: { + subAction: 'stream', + subActionParams: { + body: JSON.stringify(request), + stream, + }, + }, + }) + ).pipe( + switchMap((response) => { + const readable = response.data as Readable; + return eventSourceStreamIntoObservable(readable); + }), + filter((line) => !!line && line !== '[DONE]'), + map( + (line) => JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } } + ), + tap((line) => { + if ('error' in line) { + throw createInferenceInternalError(line.error.message); + } + if ( + 'choices' in line && + line.choices.length && + line.choices[0].finish_reason === 'length' + ) { + throw createTokenLimitReachedError(); + } + }), + filter( + (line): line is OpenAI.ChatCompletionChunk => + 'object' in line && line.object === 'chat.completion.chunk' + ), + map((chunk): ChatCompletionChunkEvent => { + const delta = chunk.choices[0].delta; + + return { + content: delta.content ?? '', + tool_calls: + delta.tool_calls?.map((toolCall) => { + return { + function: { + name: toolCall.function?.name ?? '', + arguments: toolCall.function?.arguments ?? '', + }, + toolCallId: toolCall.id ?? '', + index: toolCall.index, + }; + }) ?? [], + type: ChatCompletionEventType.ChatCompletionChunk, + }; + }) + ); + }, +}; + +function messagesToOpenAI({ + system, + messages, +}: { + system?: string; + messages: Message[]; +}): OpenAI.ChatCompletionMessageParam[] { + const systemMessage: ChatCompletionSystemMessageParam | undefined = system + ? { role: 'system', content: system } + : undefined; + + return [ + ...(systemMessage ? [systemMessage] : []), + ...messages.map((message): ChatCompletionMessageParam => { + const role = message.role; + + switch (role) { + case MessageRole.Assistant: + const assistantMessage: ChatCompletionAssistantMessageParam = { + role: 'assistant', + content: message.content, + tool_calls: message.toolCalls?.map((toolCall) => { + return { + function: { + name: toolCall.function.name, + arguments: + 'arguments' in toolCall.function + ? JSON.stringify(toolCall.function.arguments) + : '{}', + }, + id: toolCall.toolCallId, + type: 'function', + }; + }), + }; + return assistantMessage; + + case MessageRole.User: + const userMessage: ChatCompletionUserMessageParam = { + role: 'user', + content: message.content, + }; + return userMessage; + + case MessageRole.Tool: + const toolMessage: ChatCompletionToolMessageParam = { + role: 'tool', + content: JSON.stringify(message.response), + tool_call_id: message.toolCallId, + }; + return toolMessage; + } + }), + ]; +} diff --git a/x-pack/plugins/inference/server/chat_complete/index.ts b/x-pack/plugins/inference/server/chat_complete/index.ts new file mode 100644 index 0000000000000..e30afb58ca25a --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/index.ts @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { KibanaRequest } from '@kbn/core-http-server'; +import { defer, switchMap, throwError } from 'rxjs'; +import type { ChatCompleteAPI, ChatCompletionResponse } from '../../common/chat_complete'; +import type { ToolOptions } from '../../common/chat_complete/tools'; +import { InferenceConnectorType } from '../../common/connectors'; +import { createInferenceRequestError } from '../../common/errors'; +import type { InferenceStartDependencies } from '../types'; +import { chunksIntoMessage } from './adapters/chunks_into_message'; +import { openAIAdapter } from './adapters/openai'; + +export function createChatCompleteApi({ + request, + actions, +}: { + request: KibanaRequest; + actions: InferenceStartDependencies['actions']; +}) { + const chatCompleteAPI: ChatCompleteAPI = ({ + connectorId, + messages, + toolChoice, + tools, + system, + }): ChatCompletionResponse => { + return defer(async () => { + const actionsClient = await actions.getActionsClientWithRequest(request); + + const connector = await actionsClient.get({ id: connectorId, throwIfSystemAction: true }); + + return { actionsClient, connector }; + }).pipe( + switchMap(({ actionsClient, connector }) => { + switch (connector.actionTypeId) { + case InferenceConnectorType.OpenAI: + return openAIAdapter.chatComplete({ + system, + connector, + actionsClient, + messages, + toolChoice, + tools, + }); + + case InferenceConnectorType.Bedrock: + break; + + case InferenceConnectorType.Gemini: + break; + } + + return throwError(() => + createInferenceRequestError( + `Adapter for type ${connector.actionTypeId} not implemented`, + 400 + ) + ); + }), + chunksIntoMessage({ + toolChoice, + tools, + }) + ); + }; + + return chatCompleteAPI; +} diff --git a/x-pack/plugins/inference/server/chat_complete/types.ts b/x-pack/plugins/inference/server/chat_complete/types.ts new file mode 100644 index 0000000000000..6c89df1498646 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/types.ts @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { ActionsClient } from '@kbn/actions-plugin/server'; +import type { Observable } from 'rxjs'; +import type { + ChatCompleteAPI, + ChatCompletionChunkEvent, + ChatCompletionTokenCountEvent, +} from '../../common/chat_complete'; + +type Connector = Awaited>; + +export interface InferenceConnectorAdapter { + chatComplete: ( + options: Omit[0], 'connectorId'> & { + actionsClient: ActionsClient; + connector: Connector; + } + ) => Observable; +} diff --git a/x-pack/plugins/inference/server/config.ts b/x-pack/plugins/inference/server/config.ts new file mode 100644 index 0000000000000..f4cd1f886581b --- /dev/null +++ b/x-pack/plugins/inference/server/config.ts @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { schema, type TypeOf } from '@kbn/config-schema'; + +export const config = schema.object({ + enabled: schema.boolean({ defaultValue: true }), +}); + +export type InferenceConfig = TypeOf; diff --git a/x-pack/plugins/inference/server/index.ts b/x-pack/plugins/inference/server/index.ts new file mode 100644 index 0000000000000..721aa05d06023 --- /dev/null +++ b/x-pack/plugins/inference/server/index.ts @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/server'; +import type { InferenceConfig } from './config'; +import { InferencePlugin } from './plugin'; +import type { + InferenceServerSetup, + InferenceServerStart, + InferenceSetupDependencies, + InferenceStartDependencies, +} from './types'; + +export { withoutTokenCountEvents } from '../common/chat_complete/without_token_count_events'; +export { withoutChunkEvents } from '../common/chat_complete/without_chunk_events'; +export { withoutOutputUpdateEvents } from '../common/output/without_output_update_events'; + +export type { InferenceServerSetup, InferenceServerStart }; + +export const plugin: PluginInitializer< + InferenceServerSetup, + InferenceServerStart, + InferenceSetupDependencies, + InferenceStartDependencies +> = async (pluginInitializerContext: PluginInitializerContext) => + new InferencePlugin(pluginInitializerContext); diff --git a/x-pack/plugins/inference/server/inference_client/index.ts b/x-pack/plugins/inference/server/inference_client/index.ts new file mode 100644 index 0000000000000..3c25cf29f6280 --- /dev/null +++ b/x-pack/plugins/inference/server/inference_client/index.ts @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { KibanaRequest } from '@kbn/core-http-server'; +import { ActionsClient } from '@kbn/actions-plugin/server'; +import { isSupportedConnectorType } from '../../common/connectors'; +import { createInferenceRequestError } from '../../common/errors'; +import { createChatCompleteApi } from '../chat_complete'; +import type { InferenceClient, InferenceStartDependencies } from '../types'; +import { createOutputApi } from '../../common/output/create_output_api'; + +export function createInferenceClient({ + request, + actions, +}: { request: KibanaRequest } & Pick): InferenceClient { + const chatComplete = createChatCompleteApi({ request, actions }); + return { + chatComplete, + output: createOutputApi(chatComplete), + getConnectorById: async (id: string) => { + const actionsClient = await actions.getActionsClientWithRequest(request); + let connector: Awaited>; + + try { + connector = await actionsClient.get({ + id, + throwIfSystemAction: true, + }); + } catch (error) { + throw createInferenceRequestError(`No connector found for id ${id}`, 400); + } + + const actionTypeId = connector.id; + + if (!isSupportedConnectorType(actionTypeId)) { + throw createInferenceRequestError( + `Type ${actionTypeId} not recognized as a supported connector type`, + 400 + ); + } + + return { + connectorId: connector.id, + name: connector.name, + type: actionTypeId, + }; + }, + }; +} diff --git a/x-pack/plugins/inference/server/plugin.ts b/x-pack/plugins/inference/server/plugin.ts new file mode 100644 index 0000000000000..26c56209df8ce --- /dev/null +++ b/x-pack/plugins/inference/server/plugin.ts @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server'; +import type { Logger } from '@kbn/logging'; +import { createInferenceClient } from './inference_client'; +import { registerChatCompleteRoute } from './routes/chat_complete'; +import { registerConnectorsRoute } from './routes/connectors'; +import type { + ConfigSchema, + InferenceServerSetup, + InferenceServerStart, + InferenceSetupDependencies, + InferenceStartDependencies, +} from './types'; + +export class InferencePlugin + implements + Plugin< + InferenceServerSetup, + InferenceServerStart, + InferenceSetupDependencies, + InferenceStartDependencies + > +{ + logger: Logger; + + constructor(context: PluginInitializerContext) { + this.logger = context.logger.get(); + } + setup( + coreSetup: CoreSetup, + pluginsSetup: InferenceSetupDependencies + ): InferenceServerSetup { + const router = coreSetup.http.createRouter(); + + registerChatCompleteRoute({ + router, + coreSetup, + }); + + registerConnectorsRoute({ + router, + coreSetup, + }); + return {}; + } + + start(core: CoreStart, pluginsStart: InferenceStartDependencies): InferenceServerStart { + return { + getClient: ({ request }) => { + return createInferenceClient({ request, actions: pluginsStart.actions }); + }, + }; + } +} diff --git a/x-pack/plugins/inference/server/routes/chat_complete.ts b/x-pack/plugins/inference/server/routes/chat_complete.ts new file mode 100644 index 0000000000000..6c840f80466c2 --- /dev/null +++ b/x-pack/plugins/inference/server/routes/chat_complete.ts @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { schema, Type } from '@kbn/config-schema'; +import type { CoreSetup, IRouter, RequestHandlerContext } from '@kbn/core/server'; +import { isObservable } from 'rxjs'; +import { MessageRole } from '../../common/chat_complete'; +import type { ChatCompleteRequestBody } from '../../common/chat_complete/request'; +import { ToolCall, ToolChoiceType } from '../../common/chat_complete/tools'; +import { createInferenceClient } from '../inference_client'; +import { InferenceServerStart, InferenceStartDependencies } from '../types'; +import { observableIntoEventSourceStream } from '../util/observable_into_event_source_stream'; + +const toolCallSchema: Type = schema.arrayOf( + schema.object({ + toolCallId: schema.string(), + function: schema.object({ + name: schema.string(), + arguments: schema.maybe(schema.object({}, { unknowns: 'allow' })), + }), + }) +); + +const chatCompleteBodySchema: Type = schema.object({ + connectorId: schema.string(), + system: schema.maybe(schema.string()), + tools: schema.maybe( + schema.recordOf( + schema.string(), + schema.object({ + description: schema.string(), + schema: schema.maybe( + schema.object({ + type: schema.literal('object'), + properties: schema.recordOf(schema.string(), schema.any()), + required: schema.maybe(schema.arrayOf(schema.string())), + }) + ), + }) + ) + ), + toolChoice: schema.maybe( + schema.oneOf([ + schema.literal(ToolChoiceType.auto), + schema.literal(ToolChoiceType.none), + schema.literal(ToolChoiceType.required), + schema.object({ + function: schema.string(), + }), + ]) + ), + messages: schema.arrayOf( + schema.oneOf([ + schema.object({ + role: schema.literal(MessageRole.Assistant), + content: schema.string(), + toolCalls: toolCallSchema, + }), + schema.object({ + role: schema.literal(MessageRole.User), + content: schema.string(), + name: schema.maybe(schema.string()), + }), + schema.object({ + role: schema.literal(MessageRole.Tool), + toolCallId: schema.string(), + response: schema.object({}, { unknowns: 'allow' }), + }), + ]) + ), +}); + +export function registerChatCompleteRoute({ + coreSetup, + router, +}: { + coreSetup: CoreSetup; + router: IRouter; +}) { + router.post( + { + path: '/internal/inference/chat_complete', + validate: { + body: chatCompleteBodySchema, + }, + }, + async (context, request, response) => { + const actions = await coreSetup + .getStartServices() + .then(([coreStart, pluginsStart]) => pluginsStart.actions); + + const client = createInferenceClient({ request, actions }); + + const { connectorId, messages, system, toolChoice, tools } = request.body; + + const chatCompleteResponse = await client.chatComplete({ + connectorId, + messages, + system, + toolChoice, + tools, + }); + + if (isObservable(chatCompleteResponse)) { + return response.ok({ + body: observableIntoEventSourceStream(chatCompleteResponse), + }); + } + + return response.ok({ body: chatCompleteResponse }); + } + ); +} diff --git a/x-pack/plugins/inference/server/routes/connectors.ts b/x-pack/plugins/inference/server/routes/connectors.ts new file mode 100644 index 0000000000000..8c69b68d55f14 --- /dev/null +++ b/x-pack/plugins/inference/server/routes/connectors.ts @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { CoreSetup, IRouter, RequestHandlerContext } from '@kbn/core/server'; +import { InferenceConnector, InferenceConnectorType } from '../../common/connectors'; +import type { InferenceServerStart, InferenceStartDependencies } from '../types'; + +export function registerConnectorsRoute({ + coreSetup, + router, +}: { + coreSetup: CoreSetup; + router: IRouter; +}) { + router.get( + { + path: '/internal/inference/connectors', + validate: {}, + }, + async (_context, request, response) => { + const actions = await coreSetup + .getStartServices() + .then(([_coreStart, pluginsStart]) => pluginsStart.actions); + + const client = await actions.getActionsClientWithRequest(request); + + const allConnectors = await client.getAll({ + includeSystemActions: false, + }); + + const connectorTypes: string[] = [ + InferenceConnectorType.OpenAI, + InferenceConnectorType.Bedrock, + InferenceConnectorType.Gemini, + ]; + + const connectors: InferenceConnector[] = allConnectors + .filter((connector) => connectorTypes.includes(connector.actionTypeId)) + .map((connector) => { + return { + connectorId: connector.id, + name: connector.name, + type: connector.actionTypeId as InferenceConnectorType, + }; + }); + + return response.ok({ body: { connectors } }); + } + ); +} diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts new file mode 100644 index 0000000000000..00ce53fe5d288 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { switchMap, map } from 'rxjs'; +import { MessageRole } from '../../../common/chat_complete'; +import { ToolOptions } from '../../../common/chat_complete/tools'; +import { withoutChunkEvents } from '../../../common/chat_complete/without_chunk_events'; +import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events'; +import { createOutputCompleteEvent } from '../../../common/output'; +import { withoutOutputUpdateEvents } from '../../../common/output/without_output_update_events'; +import { InferenceClient } from '../../types'; + +const ESQL_SYSTEM_MESSAGE = ''; + +async function getEsqlDocuments(documents: string[]) { + return [ + { + document: 'my-esql-function', + text: 'My ES|QL function', + }, + ]; +} + +export function naturalLanguageToEsql({ + client, + input, + connectorId, + tools, + toolChoice, +}: { + client: InferenceClient; + input: string; + connectorId: string; +} & TToolOptions) { + return client + .output('request_documentation', { + connectorId, + system: ESQL_SYSTEM_MESSAGE, + input: `Based on the following input, request documentation + from the ES|QL handbook to help you get the right information + needed to generate a query: + ${input} + `, + schema: { + type: 'object', + properties: { + documents: { + type: 'array', + items: { + type: 'string', + }, + }, + }, + required: ['documents'], + } as const, + }) + .pipe( + withoutOutputUpdateEvents(), + switchMap((event) => { + return getEsqlDocuments(event.output.documents ?? []); + }), + switchMap((documents) => { + return client + .chatComplete({ + connectorId, + system: `${ESQL_SYSTEM_MESSAGE} + + The following documentation is provided: + + ${documents}`, + messages: [ + { + role: MessageRole.User, + content: input, + }, + ], + tools, + toolChoice, + }) + .pipe( + withoutTokenCountEvents(), + withoutChunkEvents(), + map((message) => { + return createOutputCompleteEvent('generated_query', { + content: message.content, + toolCalls: message.toolCalls, + }); + }) + ); + }), + withoutOutputUpdateEvents() + ); +} diff --git a/x-pack/plugins/inference/server/types.ts b/x-pack/plugins/inference/server/types.ts new file mode 100644 index 0000000000000..609b719b15236 --- /dev/null +++ b/x-pack/plugins/inference/server/types.ts @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { + PluginStartContract as ActionsPluginStart, + PluginSetupContract as ActionsPluginSetup, +} from '@kbn/actions-plugin/server'; +import type { KibanaRequest } from '@kbn/core-http-server'; +import { ChatCompleteAPI } from '../common/chat_complete'; +import { InferenceConnector } from '../common/connectors'; +import { OutputAPI } from '../common/output'; + +/* eslint-disable @typescript-eslint/no-empty-interface*/ + +export interface ConfigSchema {} + +export interface InferenceSetupDependencies { + actions: ActionsPluginSetup; +} + +export interface InferenceStartDependencies { + actions: ActionsPluginStart; +} + +export interface InferenceServerSetup {} + +export interface InferenceClient { + /** + * `chatComplete` requests the LLM to generate a response to + * a prompt or conversation, which might be plain text + * or a tool call, or a combination of both. + */ + chatComplete: ChatCompleteAPI; + /** + * `output` asks the LLM to generate a structured (JSON) + * response based on a schema and a prompt or conversation. + */ + output: OutputAPI; + /** + * `getConnectorById` returns an inference connector by id. + * Non-inference connectors will throw an error. + */ + getConnectorById: (id: string) => Promise; +} + +interface InferenceClientCreateOptions { + request: KibanaRequest; +} + +export interface InferenceServerStart { + /** + * Creates an inference client, scoped to a request. + * + * @param options {@link InferenceClientCreateOptions} + * @returns {@link InferenceClient} + */ + getClient: (options: InferenceClientCreateOptions) => InferenceClient; +} diff --git a/x-pack/plugins/inference/server/util/observable_into_event_source_stream.test.ts b/x-pack/plugins/inference/server/util/observable_into_event_source_stream.test.ts new file mode 100644 index 0000000000000..d7972bb970317 --- /dev/null +++ b/x-pack/plugins/inference/server/util/observable_into_event_source_stream.test.ts @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { createParser } from 'eventsource-parser'; +import { partition } from 'lodash'; +import { merge, of, throwError } from 'rxjs'; +import { InferenceTaskEvent } from '../../common/tasks'; +import { observableIntoEventSourceStream } from './observable_into_event_source_stream'; + +describe('observableIntoEventSourceStream', () => { + function renderStream(events: Array) { + const [inferenceEvents, errors] = partition( + events, + (event): event is T => !(event instanceof Error) + ); + + const source$ = merge(of(...inferenceEvents), ...errors.map((error) => throwError(error))); + + const stream = observableIntoEventSourceStream(source$); + + return new Promise((resolve, reject) => { + const chunks: string[] = []; + stream.on('data', (chunk) => { + chunks.push(chunk.toString()); + }); + stream.on('error', (error) => { + reject(error); + }); + stream.on('end', () => { + resolve(chunks); + }); + }); + } + + it('serializes error events', async () => { + const chunks = await renderStream([ + { + type: 'chunk', + }, + new Error('foo'), + ]); + + expect(chunks.map((chunk) => chunk.trim())).toEqual([ + `data: ${JSON.stringify({ type: 'chunk' })}`, + `data: ${JSON.stringify({ + type: 'error', + error: { code: 'internalError', message: 'foo' }, + })}`, + ]); + }); + + it('outputs data in SSE-compatible format', async () => { + const chunks = await renderStream([ + { + type: 'chunk', + id: 0, + }, + { + type: 'chunk', + id: 1, + }, + ]); + + const events: Array> = []; + + const parser = createParser((event) => { + if (event.type === 'event') { + events.push(JSON.parse(event.data)); + } + }); + + chunks.forEach((chunk) => { + parser.feed(chunk); + }); + + expect(events).toEqual([ + { + type: 'chunk', + id: 0, + }, + { + type: 'chunk', + id: 1, + }, + ]); + }); +}); diff --git a/x-pack/plugins/inference/server/util/observable_into_event_source_stream.ts b/x-pack/plugins/inference/server/util/observable_into_event_source_stream.ts new file mode 100644 index 0000000000000..2007b9842db69 --- /dev/null +++ b/x-pack/plugins/inference/server/util/observable_into_event_source_stream.ts @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { catchError, map, Observable, of } from 'rxjs'; +import { PassThrough } from 'stream'; +import { + InferenceTaskErrorCode, + InferenceTaskErrorEvent, + isInferenceError, +} from '../../common/errors'; +import { InferenceTaskEventType } from '../../common/tasks'; + +export function observableIntoEventSourceStream(source$: Observable) { + const withSerializedErrors$ = source$.pipe( + catchError((error): Observable => { + if (isInferenceError(error)) { + return of({ + type: InferenceTaskEventType.error, + error: { + code: error.code, + message: error.message, + meta: error.meta, + }, + }); + } + + return of({ + type: InferenceTaskEventType.error, + error: { + code: InferenceTaskErrorCode.internalError, + message: error.message as string, + }, + }); + }), + map((event) => { + return `data: ${JSON.stringify(event)}\n\n`; + }) + ); + + const stream = new PassThrough(); + + withSerializedErrors$.subscribe({ + next: (line) => { + stream.write(line); + }, + complete: () => { + stream.end(); + }, + error: (error) => { + stream.write( + `data: ${JSON.stringify({ + type: InferenceTaskEventType.error, + error: { + code: InferenceTaskErrorCode.internalError, + message: error.message, + }, + })}\n\n` + ); + stream.end(); + }, + }); + + return stream; +} diff --git a/x-pack/plugins/inference/server/util/validate_tool_calls.test.ts b/x-pack/plugins/inference/server/util/validate_tool_calls.test.ts new file mode 100644 index 0000000000000..96bf202fa236b --- /dev/null +++ b/x-pack/plugins/inference/server/util/validate_tool_calls.test.ts @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { isToolValidationError } from '../../common/chat_complete/errors'; +import { ToolChoiceType } from '../../common/chat_complete/tools'; +import { validateToolCalls } from './validate_tool_calls'; + +describe('validateToolCalls', () => { + it('throws an error if tools were called but toolChoice == none', () => { + expect(() => { + validateToolCalls({ + toolCalls: [ + { + function: { + name: 'my_function', + arguments: '{}', + }, + toolCallId: '1', + }, + ], + + toolChoice: ToolChoiceType.none, + tools: { + my_function: { + description: 'description', + }, + }, + }); + }).toThrowErrorMatchingInlineSnapshot( + `"tool_choice was \\"none\\" but my_function was/were called"` + ); + }); + + it('throws an error if an unknown tool was called', () => { + expect(() => + validateToolCalls({ + toolCalls: [ + { + function: { + name: 'my_unknown_function', + arguments: '{}', + }, + toolCallId: '1', + }, + ], + + tools: { + my_function: { + description: 'description', + }, + }, + }) + ).toThrowErrorMatchingInlineSnapshot(`"Tool my_unknown_function called but was not available"`); + }); + + it('throws an error if invalid JSON was generated', () => { + expect(() => + validateToolCalls({ + toolCalls: [ + { + function: { + name: 'my_function', + arguments: '{[]}', + }, + toolCallId: '1', + }, + ], + + tools: { + my_function: { + description: 'description', + }, + }, + }) + ).toThrowErrorMatchingInlineSnapshot(`"Failed parsing arguments for my_function"`); + }); + + it('throws an error if the function call has invalid arguments', () => { + function validate() { + validateToolCalls({ + toolCalls: [ + { + function: { + name: 'my_function', + arguments: JSON.stringify({ foo: 'bar' }), + }, + toolCallId: '1', + }, + ], + + tools: { + my_function: { + description: 'description', + schema: { + type: 'object', + properties: { + bar: { + type: 'string', + }, + }, + required: ['bar'], + }, + }, + }, + }); + } + expect(() => validate()).toThrowErrorMatchingInlineSnapshot( + `"Tool call arguments for my_function were invalid"` + ); + + try { + validate(); + } catch (error) { + if (isToolValidationError(error)) { + expect(error.meta).toEqual({ + arguments: JSON.stringify({ foo: 'bar' }), + errorsText: `data must have required property 'bar'`, + name: 'my_function', + }); + } else { + fail('Expected toolValidationError'); + } + } + }); + + it('successfully validates and parses a valid tool call', () => { + function runValidation() { + return validateToolCalls({ + toolCalls: [ + { + function: { + name: 'my_function', + arguments: '{ "foo": "bar" }', + }, + toolCallId: '1', + }, + ], + + tools: { + my_function: { + description: 'description', + schema: { + type: 'object', + properties: { + foo: { + type: 'string', + }, + }, + required: ['foo'], + }, + }, + }, + }); + } + expect(() => runValidation()).not.toThrowError(); + + const validated = runValidation(); + + expect(validated).toEqual([ + { + function: { + name: 'my_function', + arguments: { + foo: 'bar', + }, + }, + toolCallId: '1', + }, + ]); + }); +}); diff --git a/x-pack/plugins/inference/server/util/validate_tool_calls.ts b/x-pack/plugins/inference/server/util/validate_tool_calls.ts new file mode 100644 index 0000000000000..5d1e659bc36f5 --- /dev/null +++ b/x-pack/plugins/inference/server/util/validate_tool_calls.ts @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import Ajv from 'ajv'; +import { + createToolNotFoundError, + createToolValidationError, +} from '../../common/chat_complete/errors'; +import { + ToolCallsOf, + ToolChoiceType, + ToolOptions, + UnvalidatedToolCall, +} from '../../common/chat_complete/tools'; + +export function validateToolCalls({ + toolCalls, + toolChoice, + tools, +}: TToolOptions & { toolCalls: UnvalidatedToolCall[] }): ToolCallsOf['toolCalls'] { + const validator = new Ajv(); + + if (toolCalls.length && toolChoice === ToolChoiceType.none) { + throw createToolValidationError( + `tool_choice was "none" but ${toolCalls + .map((toolCall) => toolCall.function.name) + .join(', ')} was/were called`, + { toolCalls } + ); + } + + return toolCalls.map((toolCall) => { + const tool = tools?.[toolCall.function.name]; + + if (!tool) { + throw createToolNotFoundError(toolCall.function.name); + } + + const toolSchema = tool.schema ?? { type: 'object', properties: {} }; + + let serializedArguments: ToolCallsOf['toolCalls'][0]['function']['arguments']; + + try { + serializedArguments = JSON.parse(toolCall.function.arguments); + } catch (error) { + throw createToolValidationError(`Failed parsing arguments for ${toolCall.function.name}`, { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + toolCalls: [toolCall], + }); + } + + const valid = validator.validate(toolSchema, serializedArguments); + + if (!valid) { + throw createToolValidationError( + `Tool call arguments for ${toolCall.function.name} were invalid`, + { + name: toolCall.function.name, + errorsText: validator.errorsText(), + arguments: toolCall.function.arguments, + } + ); + } + + return { + toolCallId: toolCall.toolCallId, + function: { + name: toolCall.function.name, + arguments: serializedArguments, + }, + }; + }); +} diff --git a/x-pack/plugins/inference/tsconfig.json b/x-pack/plugins/inference/tsconfig.json new file mode 100644 index 0000000000000..16d7ca041582c --- /dev/null +++ b/x-pack/plugins/inference/tsconfig.json @@ -0,0 +1,27 @@ +{ + "extends": "../../../tsconfig.base.json", + "compilerOptions": { + "outDir": "target/types" + }, + "include": [ + "../../../typings/**/*", + "common/**/*", + "public/**/*", + "typings/**/*", + "public/**/*.json", + "server/**/*", + ".storybook/**/*" + ], + "exclude": [ + "target/**/*", + ".storybook/**/*.js" + ], + "kbn_references": [ + "@kbn/core", + "@kbn/i18n", + "@kbn/logging", + "@kbn/core-http-server", + "@kbn/actions-plugin", + "@kbn/config-schema" + ] +} diff --git a/yarn.lock b/yarn.lock index d3b64c7a7c286..6de55f6d100cd 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5260,6 +5260,10 @@ version "0.0.0" uid "" +"@kbn/inference-plugin@link:x-pack/plugins/inference": + version "0.0.0" + uid "" + "@kbn/inference_integration_flyout@link:x-pack/packages/ml/inference_integration_flyout": version "0.0.0" uid ""