From eb0feab55edc05423cbef6314dc40ca2319b2a9f Mon Sep 17 00:00:00 2001 From: Pierre Gayvallet Date: Tue, 5 Nov 2024 15:54:41 +0100 Subject: [PATCH] Add `stream` param for inference APIs (#198646) ## Summary Fix https://github.com/elastic/kibana/issues/198644 Add a `stream` parameter to the `chatComplete` and `output` APIs, defaulting to `false`, to switch between "full content response as promise" and "event observable" responses. Note: at the moment, in non-stream mode, the implementation is simply constructing the response from the observable. It should be possible later to improve this by having the LLM adapters handle the stream/no-stream logic, but this is out of scope of the current PR. ### Normal mode ```ts const response = await chatComplete({ connectorId: 'my-connector', system: "You are a helpful assistant", messages: [ { role: MessageRole.User, content: "Some question?"}, ] }); const { content, toolCalls } = response; // do something ``` ### Stream mode ```ts const events$ = chatComplete({ stream: true, connectorId: 'my-connector', system: "You are a helpful assistant", messages: [ { role: MessageRole.User, content: "Some question?"}, ] }); events$.subscribe((event) => { // do something }); ``` --------- Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Co-authored-by: Elastic Machine --- ...use_observability_ai_assistant_context.tsx | 3 + .../ai-infra/inference-common/index.ts | 8 +- .../inference-common/src/chat_complete/api.ts | 93 +++++++- .../src/chat_complete/events.ts | 38 +-- .../src/chat_complete/index.ts | 5 +- .../src/chat_complete/tool_schema.ts | 7 +- .../inference-common/src/output/api.ts | 143 ++++++++++-- .../inference-common/src/output/index.ts | 8 +- x-pack/plugins/inference/README.md | 216 ++++++++++++++++-- .../common/create_output_api.test.ts | 122 ++++++++++ .../inference/common/create_output_api.ts | 69 ++++-- .../inference/public/chat_complete.test.ts | 63 +++++ .../plugins/inference/public/chat_complete.ts | 43 +++- .../scripts/evaluation/evaluation.ts | 4 +- .../scripts/evaluation/evaluation_client.ts | 73 +++--- .../evaluation/scenarios/esql/index.spec.ts | 54 ++--- .../load_esql_docs/utils/output_executor.ts | 15 +- .../inference/scripts/util/kibana_client.ts | 54 +++-- .../inference/server/chat_complete/api.ts | 35 +-- .../server/chat_complete/utils/index.ts | 1 + .../utils/stream_to_response.test.ts | 73 ++++++ .../chat_complete/utils/stream_to_response.ts | 42 ++++ .../inference/server/routes/chat_complete.ts | 65 ++++-- .../tasks/nl_to_esql/actions/generate_esql.ts | 1 + .../actions/request_documentation.ts | 4 +- .../server/test_utils/chat_complete_events.ts | 39 ++++ 26 files changed, 1050 insertions(+), 228 deletions(-) create mode 100644 x-pack/plugins/inference/common/create_output_api.test.ts create mode 100644 x-pack/plugins/inference/public/chat_complete.test.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.test.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.ts create mode 100644 x-pack/plugins/inference/server/test_utils/chat_complete_events.ts diff --git a/src/plugins/dashboard/public/dashboard_app/hooks/use_observability_ai_assistant_context.tsx b/src/plugins/dashboard/public/dashboard_app/hooks/use_observability_ai_assistant_context.tsx index c20e8fcd1dc76..39ae4594d5bc8 100644 --- a/src/plugins/dashboard/public/dashboard_app/hooks/use_observability_ai_assistant_context.tsx +++ 
b/src/plugins/dashboard/public/dashboard_app/hooks/use_observability_ai_assistant_context.tsx @@ -114,9 +114,11 @@ export function useObservabilityAIAssistantContext({ }, metric: { type: 'object', + properties: {}, }, gauge: { type: 'object', + properties: {}, }, pie: { type: 'object', @@ -158,6 +160,7 @@ export function useObservabilityAIAssistantContext({ }, table: { type: 'object', + properties: {}, }, tagcloud: { type: 'object', diff --git a/x-pack/packages/ai-infra/inference-common/index.ts b/x-pack/packages/ai-infra/inference-common/index.ts index 6de7ce3bb8008..502d8e86a0beb 100644 --- a/x-pack/packages/ai-infra/inference-common/index.ts +++ b/x-pack/packages/ai-infra/inference-common/index.ts @@ -25,12 +25,15 @@ export { type ToolChoice, type ChatCompleteAPI, type ChatCompleteOptions, - type ChatCompletionResponse, + type ChatCompleteCompositeResponse, type ChatCompletionTokenCountEvent, type ChatCompletionEvent, type ChatCompletionChunkEvent, type ChatCompletionChunkToolCall, type ChatCompletionMessageEvent, + type ChatCompleteStreamResponse, + type ChatCompleteResponse, + type ChatCompletionTokenCount, withoutTokenCountEvents, withoutChunkEvents, isChatCompletionMessageEvent, @@ -48,7 +51,10 @@ export { export { OutputEventType, type OutputAPI, + type OutputOptions, type OutputResponse, + type OutputCompositeResponse, + type OutputStreamResponse, type OutputCompleteEvent, type OutputUpdateEvent, type Output, diff --git a/x-pack/packages/ai-infra/inference-common/src/chat_complete/api.ts b/x-pack/packages/ai-infra/inference-common/src/chat_complete/api.ts index c6ffa9d4c8d5d..cb91f4e53e8ae 100644 --- a/x-pack/packages/ai-infra/inference-common/src/chat_complete/api.ts +++ b/x-pack/packages/ai-infra/inference-common/src/chat_complete/api.ts @@ -6,16 +6,35 @@ */ import type { Observable } from 'rxjs'; -import type { ToolOptions } from './tools'; +import type { ToolCallsOf, ToolOptions } from './tools'; import type { Message } from './messages'; -import type { ChatCompletionEvent } from './events'; +import type { ChatCompletionEvent, ChatCompletionTokenCount } from './events'; /** * Request a completion from the LLM based on a prompt or conversation. * - * @example using the API to get an event observable. + * By default, The complete LLM response will be returned as a promise. + * + * @example using the API in default mode to get promise of the LLM response. + * ```ts + * const response = await chatComplete({ + * connectorId: 'my-connector', + * system: "You are a helpful assistant", + * messages: [ + * { role: MessageRole.User, content: "Some question?"}, + * ] + * }); + * + * const { content, tokens, toolCalls } = response; + * ``` + * + * Use `stream: true` to return an observable returning the full set + * of events in real time. + * + * @example using the API in stream mode to get an event observable. 
* ```ts * const events$ = chatComplete({ + * stream: true, * connectorId: 'my-connector', * system: "You are a helpful assistant", * messages: [ @@ -24,20 +43,44 @@ import type { ChatCompletionEvent } from './events'; * { role: MessageRole.User, content: "Another question?"}, * ] * }); + * + * // using the observable + * events$.pipe(withoutTokenCountEvents()).subscribe((event) => { + * if (isChatCompletionChunkEvent(event)) { + * // do something with the chunk event + * } else { + * // do something with the message event + * } + * }); + * ``` */ -export type ChatCompleteAPI = ( - options: ChatCompleteOptions -) => ChatCompletionResponse; +export type ChatCompleteAPI = < + TToolOptions extends ToolOptions = ToolOptions, + TStream extends boolean = false +>( + options: ChatCompleteOptions +) => ChatCompleteCompositeResponse; /** * Options used to call the {@link ChatCompleteAPI} */ -export type ChatCompleteOptions = { +export type ChatCompleteOptions< + TToolOptions extends ToolOptions = ToolOptions, + TStream extends boolean = false +> = { /** * The ID of the connector to use. - * Must be a genAI compatible connector, or an error will be thrown. + * Must be an inference connector, or an error will be thrown. */ connectorId: string; + /** + * Set to true to enable streaming, which will change the API response type from + * a single {@link ChatCompleteResponse} promise + * to a {@link ChatCompleteStreamResponse} event observable. + * + * Defaults to false. + */ + stream?: TStream; /** * Optional system message for the LLM. */ @@ -53,14 +96,44 @@ export type ChatCompleteOptions } & TToolOptions; /** - * Response from the {@link ChatCompleteAPI}. + * Composite response type from the {@link ChatCompleteAPI}, + * which can be either an observable or a promise depending on + * whether API was called with stream mode enabled or not. + */ +export type ChatCompleteCompositeResponse< + TToolOptions extends ToolOptions = ToolOptions, + TStream extends boolean = false +> = TStream extends true + ? ChatCompleteStreamResponse + : Promise>; + +/** + * Response from the {@link ChatCompleteAPI} when streaming is enabled. * * Observable of {@link ChatCompletionEvent} */ -export type ChatCompletionResponse = Observable< +export type ChatCompleteStreamResponse = Observable< ChatCompletionEvent >; +/** + * Response from the {@link ChatCompleteAPI} when streaming is not enabled. + */ +export interface ChatCompleteResponse { + /** + * The text content of the LLM response. + */ + content: string; + /** + * The eventual tool calls performed by the LLM. + */ + toolCalls: ToolCallsOf['toolCalls']; + /** + * Token counts + */ + tokens?: ChatCompletionTokenCount; +} + /** * Define the function calling mode when using inference APIs. * - native will use the LLM's native function calling (requires the LLM to have native support) diff --git a/x-pack/packages/ai-infra/inference-common/src/chat_complete/events.ts b/x-pack/packages/ai-infra/inference-common/src/chat_complete/events.ts index 92c49e6ee7fc0..73396b3e2b905 100644 --- a/x-pack/packages/ai-infra/inference-common/src/chat_complete/events.ts +++ b/x-pack/packages/ai-infra/inference-common/src/chat_complete/events.ts @@ -76,30 +76,38 @@ export type ChatCompletionChunkEvent = tool_calls: ChatCompletionChunkToolCall[]; }; +/** + * Token count structure for the chatComplete API. 
+ */ +export interface ChatCompletionTokenCount { + /** + * Input token count + */ + prompt: number; + /** + * Output token count + */ + completion: number; + /** + * Total token count + */ + total: number; +} + /** * Token count event, send only once, usually (but not necessarily) * before the message event */ export type ChatCompletionTokenCountEvent = InferenceTaskEventBase & { - tokens: { - /** - * Input token count - */ - prompt: number; - /** - * Output token count - */ - completion: number; - /** - * Total token count - */ - total: number; - }; + /** + * The token count structure + */ + tokens: ChatCompletionTokenCount; }; /** - * Events emitted from the {@link ChatCompletionResponse} observable + * Events emitted from the {@link ChatCompleteResponse} observable * returned from the {@link ChatCompleteAPI}. * * The chatComplete API returns 3 type of events: diff --git a/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts b/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts index 8199af4cf068b..ca69f39b273e5 100644 --- a/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts +++ b/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts @@ -6,10 +6,12 @@ */ export type { - ChatCompletionResponse, + ChatCompleteCompositeResponse, ChatCompleteAPI, ChatCompleteOptions, FunctionCallingMode, + ChatCompleteStreamResponse, + ChatCompleteResponse, } from './api'; export { ChatCompletionEventType, @@ -18,6 +20,7 @@ export { type ChatCompletionEvent, type ChatCompletionChunkToolCall, type ChatCompletionTokenCountEvent, + type ChatCompletionTokenCount, } from './events'; export { MessageRole, diff --git a/x-pack/packages/ai-infra/inference-common/src/chat_complete/tool_schema.ts b/x-pack/packages/ai-infra/inference-common/src/chat_complete/tool_schema.ts index bb4742c6b74d9..fd935785f74f5 100644 --- a/x-pack/packages/ai-infra/inference-common/src/chat_complete/tool_schema.ts +++ b/x-pack/packages/ai-infra/inference-common/src/chat_complete/tool_schema.ts @@ -11,9 +11,9 @@ interface ToolSchemaFragmentBase { description?: string; } -interface ToolSchemaTypeObject extends ToolSchemaFragmentBase { +export interface ToolSchemaTypeObject extends ToolSchemaFragmentBase { type: 'object'; - properties?: Record; + properties: Record; required?: string[] | readonly string[]; } @@ -40,6 +40,9 @@ interface ToolSchemaTypeArray extends ToolSchemaFragmentBase { items: Exclude; } +/** + * A tool schema property's possible types. + */ export type ToolSchemaType = | ToolSchemaTypeObject | ToolSchemaTypeString diff --git a/x-pack/packages/ai-infra/inference-common/src/output/api.ts b/x-pack/packages/ai-infra/inference-common/src/output/api.ts index 677d2f7015c2a..3355042910a61 100644 --- a/x-pack/packages/ai-infra/inference-common/src/output/api.ts +++ b/x-pack/packages/ai-infra/inference-common/src/output/api.ts @@ -6,39 +6,140 @@ */ import type { Observable } from 'rxjs'; -import type { Message, FunctionCallingMode, FromToolSchema, ToolSchema } from '../chat_complete'; -import type { OutputEvent } from './events'; +import { Message, FunctionCallingMode, FromToolSchema, ToolSchema } from '../chat_complete'; +import { Output, OutputEvent } from './events'; /** * Generate a response with the LLM for a prompt, optionally based on a schema. * - * @param {string} id The id of the operation - * @param {string} options.connectorId The ID of the connector that is to be used. - * @param {string} options.input The prompt for the LLM. 
- * @param {string} options.messages Previous messages in a conversation. - * @param {ToolSchema} [options.schema] The schema the response from the LLM should adhere to. + * @example + * ```ts + * // schema must be defined as full const or using the `satisfies ToolSchema` modifier for TS type inference to work + * const mySchema = { + * type: 'object', + * properties: { + * animals: { + * description: 'the list of animals that are mentioned in the provided article', + * type: 'array', + * items: { + * type: 'string', + * }, + * }, + * }, + * } as const; + * + * const response = outputApi({ + * id: 'extract_from_article', + * connectorId: 'my-connector connector', + * schema: mySchema, + * input: ` + * Please find all the animals that are mentioned in the following document: + * ## Document¬ + * ${theDoc} + * `, + * }); + * + * // output is properly typed from the provided schema + * const { animals } = response.output; + * ``` */ export type OutputAPI = < TId extends string = string, - TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined + TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined, + TStream extends boolean = false >( - id: TId, - options: { - connectorId: string; - system?: string; - input: string; - schema?: TOutputSchema; - previousMessages?: Message[]; - functionCalling?: FunctionCallingMode; - } -) => OutputResponse; + options: OutputOptions +) => OutputCompositeResponse; + +/** + * Options for the {@link OutputAPI} + */ +export interface OutputOptions< + TId extends string = string, + TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined, + TStream extends boolean = false +> { + /** + * The id of the operation. + */ + id: TId; + /** + * The ID of the connector to use. + * Must be an inference connector, or an error will be thrown. + */ + connectorId: string; + /** + * Optional system message for the LLM. + */ + system?: string; + /** + * The prompt for the LLM. + */ + input: string; + /** + * The schema the response from the LLM should adhere to. + */ + schema?: TOutputSchema; + /** + * Previous messages in the conversation. + * If provided, will be passed to the LLM in addition to `input`. + */ + previousMessages?: Message[]; + /** + * Function calling mode, defaults to "native". + */ + functionCalling?: FunctionCallingMode; + /** + * Set to true to enable streaming, which will change the API response type from + * a single promise to an event observable. + * + * Defaults to false. + */ + stream?: TStream; +} + +/** + * Composite response type from the {@link OutputAPI}, + * which can be either an observable or a promise depending on + * whether API was called with stream mode enabled or not. + */ +export type OutputCompositeResponse< + TId extends string = string, + TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined, + TStream extends boolean = false +> = TStream extends true + ? OutputStreamResponse + : Promise< + OutputResponse< + TId, + TOutputSchema extends ToolSchema ? FromToolSchema : undefined + > + >; + +/** + * Response from the {@link OutputAPI} when streaming is not enabled. + */ +export interface OutputResponse { + /** + * The id of the operation, as specified when calling the API. + */ + id: TId; + /** + * The task output, following the schema specified as input. + */ + output: TOutput; + /** + * Potential text content provided by the LLM, if it was provided in addition to the tool call. + */ + content: string; +} /** - * Response from the {@link OutputAPI}. 
+ * Response from the {@link OutputAPI} in streaming mode. * - * Observable of {@link OutputEvent} + * @returns Observable of {@link OutputEvent} */ -export type OutputResponse< +export type OutputStreamResponse< TId extends string = string, TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined > = Observable< diff --git a/x-pack/packages/ai-infra/inference-common/src/output/index.ts b/x-pack/packages/ai-infra/inference-common/src/output/index.ts index ceac178f47faa..a3039005b2f7c 100644 --- a/x-pack/packages/ai-infra/inference-common/src/output/index.ts +++ b/x-pack/packages/ai-infra/inference-common/src/output/index.ts @@ -5,7 +5,13 @@ * 2.0. */ -export type { OutputAPI, OutputResponse } from './api'; +export type { + OutputAPI, + OutputOptions, + OutputCompositeResponse, + OutputResponse, + OutputStreamResponse, +} from './api'; export { OutputEventType, type OutputCompleteEvent, diff --git a/x-pack/plugins/inference/README.md b/x-pack/plugins/inference/README.md index 1807da7f29faa..935ae31bd6bc6 100644 --- a/x-pack/plugins/inference/README.md +++ b/x-pack/plugins/inference/README.md @@ -4,13 +4,12 @@ The inference plugin is a central place to handle all interactions with the Elas external LLM APIs. Its goals are: - Provide a single place for all interactions with large language models and other generative AI adjacent tasks. -- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini -- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall. +- Abstract away differences between different LLM providers like OpenAI, Bedrock and Gemini. - Allow us to move gradually to the \_inference endpoint without disrupting engineers. ## Architecture and examples -![CleanShot 2024-07-14 at 14 45 27@2x](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f) +![architecture-schema](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f) ## Terminology @@ -21,8 +20,22 @@ The following concepts are commonly used throughout the plugin: - **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, it allows the consumer of the API to define a schema for structured output instead of plain text, and having the LLM select the most appropriate one. - **tool call**: when the LLM has chosen a tool (schema) to use for its output, and returns a document that matches the schema, this is referred to as a tool call. +## Inference connectors + +Performing inference, or more globally communicating with the LLM, is done using stack connectors. + +The subset of connectors that can be used for inference are called `genAI`, or `inference` connectors. +Calling any inference APIs with the ID of a connector that is not inference-compatible will result in the API throwing an error. 
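+
+For illustration, here is a minimal, hypothetical sketch of what that failure looks like from a consumer's perspective: `'my-email-connector'` is an assumed ID pointing at a connector that is not inference-compatible, and the error is inspected with the `isInferenceError` helper described in the Errors section below.
+
+```ts
+import { isInferenceError } from '@kbn/inference-common';
+
+try {
+  // assumed: 'my-email-connector' is not an inference-compatible connector
+  await inferenceClient.chatComplete({
+    connectorId: 'my-email-connector',
+    messages: [{ role: MessageRole.User, content: 'Hello' }],
+  });
+} catch (e) {
+  if (isInferenceError(e)) {
+    // the error exposes a code, a message and meta information (see the Errors section below)
+  }
+}
+```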
+ +The list of inference connector types: +- `.gen-ai`: OpenAI connector +- `.bedrock`: Bedrock Claude connector +- `.gemini`: Vertex Gemini connector + ## Usage examples +The inference APIs are available via the inference client, which can be created using the inference plugin's start contract: + ```ts class MyPlugin { setup(coreSetup, pluginsSetup) { @@ -40,9 +53,9 @@ class MyPlugin { async (context, request, response) => { const [coreStart, pluginsStart] = await coreSetup.getStartServices(); - const inferenceClient = pluginsSetup.inference.getClient({ request }); + const inferenceClient = pluginsStart.inference.getClient({ request }); - const chatComplete$ = inferenceClient.chatComplete({ + const chatResponse = inferenceClient.chatComplete({ connectorId: request.body.connectorId, system: `Here is my system message`, messages: [ @@ -53,13 +66,9 @@ class MyPlugin { ], }); - const message = await lastValueFrom( - chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents()) - ); - return response.ok({ body: { - message, + chatResponse, }, }); } @@ -68,33 +77,190 @@ class MyPlugin { } ``` -## Services +## APIs -### `chatComplete`: +### `chatComplete` API: `chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported: -- Normalizing request and response formats from different connector types (e.g. OpenAI, Bedrock, Claude, Elastic Inference Service) +- Normalizing request and response formats from all supported connector types - Tool calling and validation of tool calls -- Emits token count events -- Emits message events, which is the concatenated message based on the response chunks +- Token usage stats / events +- Streaming mode to work with chunks in real time instead of waiting for the full response + +#### Standard usage + +In standard mode, the API returns a promise resolving with the full LLM response once the generation is complete. +The response will also contain the token count info, if available. + +```ts +const chatResponse = inferenceClient.chatComplete({ + connectorId: 'some-gen-ai-connector', + system: `Here is my system message`, + messages: [ + { + role: MessageRole.User, + content: 'Do something', + }, + ], +}); + +const { content, tokens } = chatResponse; +// do something with the output +``` + +#### Streaming mode + +Passing `stream: true` when calling the API enables streaming mode. +In that mode, the API returns an observable instead of a promise, emitting chunks in real time. 
+ +That observable emits three types of events: + +- `chunk` the completion chunks, emitted in real time +- `tokenCount` token count event, containing info about token usages, eventually emitted after the chunks +- `message` full message event, emitted once the source is done sending chunks + +The `@kbn/inference-common` package exposes various utilities to work with this multi-events observable: + +- `isChatCompletionChunkEvent`, `isChatCompletionMessageEvent` and `isChatCompletionTokenCountEvent` which are type guard for the corresponding event types +- `withoutChunkEvents` and `withoutTokenCountEvents` + +```ts +import { + isChatCompletionChunkEvent, + isChatCompletionMessageEvent, + withoutTokenCountEvents, + withoutChunkEvents, +} from '@kbn/inference-common'; -### `output` +const chatComplete$ = inferenceClient.chatComplete({ + connectorId: 'some-gen-ai-connector', + stream: true, + system: `Here is my system message`, + messages: [ + { + role: MessageRole.User, + content: 'Do something', + }, + ], +}); -`output` is a wrapper around `chatComplete` that is catered towards a single use case: having the LLM output a structured response, based on a schema. It also drops the token count events to simplify usage. +// using and filtering the events +chatComplete$.pipe(withoutTokenCountEvents()).subscribe((event) => { + if (isChatCompletionChunkEvent(event)) { + // do something with the chunk event + } else { + // do something with the message event + } +}); + +// or retrieving the final message +const message = await lastValueFrom( + chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents()) +); +``` + +#### Defining and using tools + +Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety. +This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`). + +The description and schema of a tool will be converted and sent to the LLM, so it's important +to be explicit about what each tool does. + +```ts +const chatResponse = inferenceClient.chatComplete({ + connectorId: 'some-gen-ai-connector', + system: `Here is my system message`, + messages: [ + { + role: MessageRole.User, + content: 'How much is 4 plus 9?', + }, + ], + toolChoice: ToolChoiceType.required, // MUST call a tool + tools: { + date: { + description: 'Call this tool if you need to know the current date' + }, + add: { + description: 'This tool can be used to add two numbers', + schema: { + type: 'object', + properties: { + a: { type: 'number', description: 'the first number' }, + b: { type: 'number', description: 'the second number'} + }, + required: ['a', 'b'] + } + } + } as const // as const is required to have type inference on the schema +}); -### Observable event streams +const { content, toolCalls } = chatResponse; +const toolCall = toolCalls[0]; +// process the tool call and eventually continue the conversation with the LLM +``` + +### `output` API + +`output` is a wrapper around the `chatComplete` API that is catered towards a specific use case: having the LLM output a structured response, based on a schema. +It's basically just making sure that the LLM will call the single tool that is exposed via the provided `schema`. +It also drops the token count info to simplify usage. + +Similar to `chatComplete`, `output` supports two modes: normal full response mode by default, and optional streaming mode by passing the `stream: true` parameter. 
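+
+When `stream: true` is passed, `output` returns an observable of update and complete events instead of a promise. Below is a minimal sketch of the streaming variant, reusing the `mySchema` definition and the `outputApi` entry point from the full-response example that follows, together with the `isOutputCompleteEvent` type guard from `@kbn/inference-common`:
+
+```ts
+import { isOutputCompleteEvent } from '@kbn/inference-common';
+
+const events$ = inferenceClient.outputApi({
+  id: 'extract_from_article',
+  stream: true,
+  connectorId: 'some-gen-ai-connector',
+  schema: mySchema,
+  input: `Please find all the animals and vegetables mentioned in the following document: ${theDoc}`,
+});
+
+events$.subscribe((event) => {
+  if (isOutputCompleteEvent(event)) {
+    // the final event: event.output is typed from the provided schema
+    const { animals, vegetables } = event.output;
+  } else {
+    // partial update event, emitted while generation is in progress
+  }
+});
+```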
+ +```ts +import { ToolSchema } from '@kbn/inference-common'; -These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen: +// schema must be defined as full const or using the `satisfies ToolSchema` modifier for TS type inference to work +const mySchema = { + type: 'object', + properties: { + animals: { + description: 'the list of animals that are mentioned in the provided article', + type: 'array', + items: { + type: 'string', + }, + }, + vegetables: { + description: 'the list of vegetables that are mentioned in the provided article', + type: 'array', + items: { + type: 'string', + }, + }, + }, +} as const; -- Errors are caught and serialized as events sent over the stream (after an error, the stream ends). -- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) -- The client that reads the stream, parses the event source as an Observable, and if it encounters a serialized error, it deserializes it and throws an error in the Observable. +const response = inferenceClient.outputApi({ + id: 'extract_from_article', + connectorId: 'some-gen-ai-connector', + schema: mySchema, + system: + 'You are a helpful assistant and your current task is to extract informations from the provided document', + input: ` + Please find all the animals and vegetables that are mentioned in the following document: + + ## Document + + ${theDoc} + `, +}); + +// output is properly typed from the provided schema +const { animals, vegetables } = response.output; +``` ### Errors -All known errors are instances, and not extensions, from the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern. +All known errors are instances, and not extensions, of the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. +This allows us to serialize and deserialize errors over the wire without a complicated factory pattern. -### Tools +Type guards for each type of error are exposed from the `@kbn/inference-common` package, such as: -Tools are defined as a record, with a `description` and optionally a `schema`. The reason why it's a record is because of type-safety. This allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`). +- `isInferenceError` +- `isInferenceInternalError` +- `isInferenceRequestError` +- ...`isXXXError` diff --git a/x-pack/plugins/inference/common/create_output_api.test.ts b/x-pack/plugins/inference/common/create_output_api.test.ts new file mode 100644 index 0000000000000..b5d380fa9aac6 --- /dev/null +++ b/x-pack/plugins/inference/common/create_output_api.test.ts @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { firstValueFrom, isObservable, of, toArray } from 'rxjs'; +import { + ChatCompleteResponse, + ChatCompletionEvent, + ChatCompletionEventType, +} from '@kbn/inference-common'; +import { createOutputApi } from './create_output_api'; + +describe('createOutputApi', () => { + let chatComplete: jest.Mock; + + beforeEach(() => { + chatComplete = jest.fn(); + }); + + it('calls `chatComplete` with the right parameters', async () => { + chatComplete.mockResolvedValue(Promise.resolve({ content: 'content', toolCalls: [] })); + + const output = createOutputApi(chatComplete); + + await output({ + id: 'id', + stream: false, + functionCalling: 'native', + connectorId: '.my-connector', + system: 'system', + input: 'input message', + }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + connectorId: '.my-connector', + functionCalling: 'native', + stream: false, + system: 'system', + messages: [ + { + content: 'input message', + role: 'user', + }, + ], + }); + }); + + it('returns the expected value when stream=false', async () => { + const chatCompleteResponse: ChatCompleteResponse = { + content: 'content', + toolCalls: [{ toolCallId: 'a', function: { name: 'foo', arguments: { arg: 1 } } }], + }; + + chatComplete.mockResolvedValue(Promise.resolve(chatCompleteResponse)); + + const output = createOutputApi(chatComplete); + + const response = await output({ + id: 'my-id', + stream: false, + connectorId: '.my-connector', + input: 'input message', + }); + + expect(response).toEqual({ + id: 'my-id', + content: chatCompleteResponse.content, + output: chatCompleteResponse.toolCalls[0].function.arguments, + }); + }); + + it('returns the expected value when stream=true', async () => { + const sourceEvents: ChatCompletionEvent[] = [ + { type: ChatCompletionEventType.ChatCompletionChunk, content: 'chunk-1', tool_calls: [] }, + { type: ChatCompletionEventType.ChatCompletionChunk, content: 'chunk-2', tool_calls: [] }, + { + type: ChatCompletionEventType.ChatCompletionMessage, + content: 'message', + toolCalls: [{ toolCallId: 'a', function: { name: 'foo', arguments: { arg: 1 } } }], + }, + ]; + + chatComplete.mockReturnValue(of(...sourceEvents)); + + const output = createOutputApi(chatComplete); + + const response$ = await output({ + id: 'my-id', + stream: true, + connectorId: '.my-connector', + input: 'input message', + }); + + expect(isObservable(response$)).toEqual(true); + const events = await firstValueFrom(response$.pipe(toArray())); + + expect(events).toEqual([ + { + content: 'chunk-1', + id: 'my-id', + type: 'output', + }, + { + content: 'chunk-2', + id: 'my-id', + type: 'output', + }, + { + content: 'message', + id: 'my-id', + output: { + arg: 1, + }, + type: 'complete', + }, + ]); + }); +}); diff --git a/x-pack/plugins/inference/common/create_output_api.ts b/x-pack/plugins/inference/common/create_output_api.ts index 450114c892cba..e5dd2eeda2cbd 100644 --- a/x-pack/plugins/inference/common/create_output_api.ts +++ b/x-pack/plugins/inference/common/create_output_api.ts @@ -5,24 +5,36 @@ * 2.0. 
*/ -import { map } from 'rxjs'; import { - OutputAPI, - OutputEvent, - OutputEventType, ChatCompleteAPI, ChatCompletionEventType, MessageRole, + OutputAPI, + OutputEventType, + OutputOptions, + ToolSchema, withoutTokenCountEvents, } from '@kbn/inference-common'; +import { isObservable, map } from 'rxjs'; import { ensureMultiTurn } from './utils/ensure_multi_turn'; -export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI { - return (id, { connectorId, input, schema, system, previousMessages, functionCalling }) => { - return chatCompleteApi({ +export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI; +export function createOutputApi(chatCompleteApi: ChatCompleteAPI) { + return ({ + id, + connectorId, + input, + schema, + system, + previousMessages, + functionCalling, + stream, + }: OutputOptions) => { + const response = chatCompleteApi({ connectorId, - system, + stream, functionCalling, + system, messages: ensureMultiTurn([ ...(previousMessages || []), { @@ -41,27 +53,42 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI { toolChoice: { function: 'structuredOutput' as const }, } : {}), - }).pipe( - withoutTokenCountEvents(), - map((event): OutputEvent => { - if (event.type === ChatCompletionEventType.ChatCompletionChunk) { + }); + + if (isObservable(response)) { + return response.pipe( + withoutTokenCountEvents(), + map((event) => { + if (event.type === ChatCompletionEventType.ChatCompletionChunk) { + return { + type: OutputEventType.OutputUpdate, + id, + content: event.content, + }; + } + return { - type: OutputEventType.OutputUpdate, id, + output: + event.toolCalls.length && 'arguments' in event.toolCalls[0].function + ? event.toolCalls[0].function.arguments + : undefined, content: event.content, + type: OutputEventType.OutputComplete, }; - } - + }) + ); + } else { + return response.then((chatResponse) => { return { id, + content: chatResponse.content, output: - event.toolCalls.length && 'arguments' in event.toolCalls[0].function - ? event.toolCalls[0].function.arguments + chatResponse.toolCalls.length && 'arguments' in chatResponse.toolCalls[0].function + ? chatResponse.toolCalls[0].function.arguments : undefined, - content: event.content, - type: OutputEventType.OutputComplete, }; - }) - ); + }); + } }; } diff --git a/x-pack/plugins/inference/public/chat_complete.test.ts b/x-pack/plugins/inference/public/chat_complete.test.ts new file mode 100644 index 0000000000000..b297db1f2fb2c --- /dev/null +++ b/x-pack/plugins/inference/public/chat_complete.test.ts @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { omit } from 'lodash'; +import { httpServiceMock } from '@kbn/core/public/mocks'; +import { ChatCompleteAPI, MessageRole, ChatCompleteOptions } from '@kbn/inference-common'; +import { createChatCompleteApi } from './chat_complete'; + +describe('createChatCompleteApi', () => { + let http: ReturnType; + let chatComplete: ChatCompleteAPI; + + beforeEach(() => { + http = httpServiceMock.createStartContract(); + chatComplete = createChatCompleteApi({ http }); + }); + + it('calls http.post with the right parameters when stream is not true', async () => { + const params = { + connectorId: 'my-connector', + functionCalling: 'native', + system: 'system', + messages: [{ role: MessageRole.User, content: 'question' }], + }; + await chatComplete(params as ChatCompleteOptions); + + expect(http.post).toHaveBeenCalledTimes(1); + expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete', { + body: expect.any(String), + }); + const callBody = http.post.mock.lastCall!; + + expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(params); + }); + + it('calls http.post with the right parameters when stream is true', async () => { + http.post.mockResolvedValue({}); + + const params = { + connectorId: 'my-connector', + functionCalling: 'native', + stream: true, + system: 'system', + messages: [{ role: MessageRole.User, content: 'question' }], + }; + + await chatComplete(params as ChatCompleteOptions); + + expect(http.post).toHaveBeenCalledTimes(1); + expect(http.post).toHaveBeenCalledWith('/internal/inference/chat_complete/stream', { + asResponse: true, + rawResponse: true, + body: expect.any(String), + }); + const callBody = http.post.mock.lastCall!; + + expect(JSON.parse((callBody as any[])[1].body as string)).toEqual(omit(params, 'stream')); + }); +}); diff --git a/x-pack/plugins/inference/public/chat_complete.ts b/x-pack/plugins/inference/public/chat_complete.ts index 5319f7c31c381..64cb5533f20be 100644 --- a/x-pack/plugins/inference/public/chat_complete.ts +++ b/x-pack/plugins/inference/public/chat_complete.ts @@ -5,14 +5,31 @@ * 2.0. 
*/ -import { from } from 'rxjs'; import type { HttpStart } from '@kbn/core/public'; -import type { ChatCompleteAPI } from '@kbn/inference-common'; +import { + ChatCompleteAPI, + ChatCompleteCompositeResponse, + ChatCompleteOptions, + ToolOptions, +} from '@kbn/inference-common'; +import { from } from 'rxjs'; import type { ChatCompleteRequestBody } from '../common/http_apis'; import { httpResponseIntoObservable } from './util/http_response_into_observable'; -export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI { - return ({ connectorId, messages, system, toolChoice, tools, functionCalling }) => { +export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI; +export function createChatCompleteApi({ http }: { http: HttpStart }) { + return ({ + connectorId, + messages, + system, + toolChoice, + tools, + functionCalling, + stream, + }: ChatCompleteOptions): ChatCompleteCompositeResponse< + ToolOptions, + boolean + > => { const body: ChatCompleteRequestBody = { connectorId, system, @@ -22,12 +39,18 @@ export function createChatCompleteApi({ http }: { http: HttpStart }): ChatComple functionCalling, }; - return from( - http.post('/internal/inference/chat_complete', { - asResponse: true, - rawResponse: true, + if (stream) { + return from( + http.post('/internal/inference/chat_complete/stream', { + asResponse: true, + rawResponse: true, + body: JSON.stringify(body), + }) + ).pipe(httpResponseIntoObservable()); + } else { + return http.post('/internal/inference/chat_complete', { body: JSON.stringify(body), - }) - ).pipe(httpResponseIntoObservable()); + }); + } }; } diff --git a/x-pack/plugins/inference/scripts/evaluation/evaluation.ts b/x-pack/plugins/inference/scripts/evaluation/evaluation.ts index 0c70bf81eddad..425b7e334074a 100644 --- a/x-pack/plugins/inference/scripts/evaluation/evaluation.ts +++ b/x-pack/plugins/inference/scripts/evaluation/evaluation.ts @@ -89,8 +89,8 @@ function runEvaluations() { const evaluationClient = createInferenceEvaluationClient({ connectorId: connector.connectorId, evaluationConnectorId, - outputApi: (id, parameters) => - chatClient.output(id, { + outputApi: (parameters) => + chatClient.output({ ...parameters, connectorId: evaluationConnectorId, }) as any, diff --git a/x-pack/plugins/inference/scripts/evaluation/evaluation_client.ts b/x-pack/plugins/inference/scripts/evaluation/evaluation_client.ts index d35c214542255..99eaf02494af7 100644 --- a/x-pack/plugins/inference/scripts/evaluation/evaluation_client.ts +++ b/x-pack/plugins/inference/scripts/evaluation/evaluation_client.ts @@ -6,8 +6,7 @@ */ import { remove } from 'lodash'; -import { lastValueFrom } from 'rxjs'; -import { type OutputAPI, withoutOutputUpdateEvents } from '@kbn/inference-common'; +import type { OutputAPI } from '@kbn/inference-common'; import type { EvaluationResult } from './types'; export interface InferenceEvaluationClient { @@ -67,11 +66,12 @@ export function createInferenceEvaluationClient({ output: outputApi, getEvaluationConnectorId: () => evaluationConnectorId, evaluate: async ({ input, criteria = [], system }) => { - const evaluation = await lastValueFrom( - outputApi('evaluate', { - connectorId, - system: withAdditionalSystemContext( - `You are a helpful, respected assistant for evaluating task + const evaluation = await outputApi({ + id: 'evaluate', + stream: false, + connectorId, + system: withAdditionalSystemContext( + `You are a helpful, respected assistant for evaluating task inputs and outputs in the Elastic Platform. 
Your goal is to verify whether the output of a task @@ -84,10 +84,10 @@ export function createInferenceEvaluationClient({ quoting what the assistant did wrong, where it could improve, and what the root cause was in case of a failure. `, - system - ), + system + ), - input: ` + input: ` ## Criteria ${criteria @@ -99,37 +99,36 @@ export function createInferenceEvaluationClient({ ## Input ${input}`, - schema: { - type: 'object', - properties: { - criteria: { - type: 'array', - items: { - type: 'object', - properties: { - index: { - type: 'number', - description: 'The number of the criterion', - }, - score: { - type: 'number', - description: - 'The score you calculated for the criterion, between 0 (criterion fully failed) and 1 (criterion fully succeeded).', - }, - reasoning: { - type: 'string', - description: - 'Your reasoning for the score. Explain your score by mentioning what you expected to happen and what did happen.', - }, + schema: { + type: 'object', + properties: { + criteria: { + type: 'array', + items: { + type: 'object', + properties: { + index: { + type: 'number', + description: 'The number of the criterion', + }, + score: { + type: 'number', + description: + 'The score you calculated for the criterion, between 0 (criterion fully failed) and 1 (criterion fully succeeded).', + }, + reasoning: { + type: 'string', + description: + 'Your reasoning for the score. Explain your score by mentioning what you expected to happen and what did happen.', }, - required: ['index', 'score', 'reasoning'], }, + required: ['index', 'score', 'reasoning'], }, }, - required: ['criteria'], - } as const, - }).pipe(withoutOutputUpdateEvents()) - ); + }, + required: ['criteria'], + } as const, + }); const scoredCriteria = evaluation.output.criteria; diff --git a/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts b/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts index d9071b3f0ae3f..49a82db8124e9 100644 --- a/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts +++ b/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts @@ -9,8 +9,7 @@ import expect from '@kbn/expect'; import type { Logger } from '@kbn/logging'; -import { firstValueFrom, lastValueFrom, filter } from 'rxjs'; -import { isOutputCompleteEvent } from '@kbn/inference-common'; +import { lastValueFrom } from 'rxjs'; import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql'; import { chatClient, evaluationClient, logger } from '../../services'; import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base'; @@ -66,11 +65,10 @@ const retrieveUsedCommands = async ({ answer: string; esqlDescription: string; }) => { - const commandsListOutput = await firstValueFrom( - evaluationClient - .output('retrieve_commands', { - connectorId: evaluationClient.getEvaluationConnectorId(), - system: ` + const commandsListOutput = await evaluationClient.output({ + id: 'retrieve_commands', + connectorId: evaluationClient.getEvaluationConnectorId(), + system: ` You are a helpful, respected Elastic ES|QL assistant. 
Your role is to enumerate the list of ES|QL commands and functions that were used @@ -82,34 +80,32 @@ const retrieveUsedCommands = async ({ ${esqlDescription} `, - input: ` + input: ` # Question ${question} # Answer ${answer} `, - schema: { - type: 'object', - properties: { - commands: { - description: - 'The list of commands that were used in the provided ES|QL question and answer', - type: 'array', - items: { type: 'string' }, - }, - functions: { - description: - 'The list of functions that were used in the provided ES|QL question and answer', - type: 'array', - items: { type: 'string' }, - }, - }, - required: ['commands', 'functions'], - } as const, - }) - .pipe(filter(isOutputCompleteEvent)) - ); + schema: { + type: 'object', + properties: { + commands: { + description: + 'The list of commands that were used in the provided ES|QL question and answer', + type: 'array', + items: { type: 'string' }, + }, + functions: { + description: + 'The list of functions that were used in the provided ES|QL question and answer', + type: 'array', + items: { type: 'string' }, + }, + }, + required: ['commands', 'functions'], + } as const, + }); const output = commandsListOutput.output; diff --git a/x-pack/plugins/inference/scripts/load_esql_docs/utils/output_executor.ts b/x-pack/plugins/inference/scripts/load_esql_docs/utils/output_executor.ts index 62cfd8f877e3f..f4014db0e6e8d 100644 --- a/x-pack/plugins/inference/scripts/load_esql_docs/utils/output_executor.ts +++ b/x-pack/plugins/inference/scripts/load_esql_docs/utils/output_executor.ts @@ -5,7 +5,6 @@ * 2.0. */ -import { lastValueFrom } from 'rxjs'; import type { OutputAPI } from '@kbn/inference-common'; export interface Prompt { @@ -27,13 +26,13 @@ export type PromptCallerFactory = ({ export const bindOutput: PromptCallerFactory = ({ connectorId, output }) => { return async ({ input, system }) => { - const response = await lastValueFrom( - output('', { - connectorId, - input, - system, - }) - ); + const response = await output({ + id: 'output', + connectorId, + input, + system, + }); + return response.content ?? 
''; }; }; diff --git a/x-pack/plugins/inference/scripts/util/kibana_client.ts b/x-pack/plugins/inference/scripts/util/kibana_client.ts index b599ab81a4af4..ad6c21cf4b248 100644 --- a/x-pack/plugins/inference/scripts/util/kibana_client.ts +++ b/x-pack/plugins/inference/scripts/util/kibana_client.ts @@ -15,6 +15,7 @@ import { inspect } from 'util'; import { isReadable } from 'stream'; import { ChatCompleteAPI, + ChatCompleteCompositeResponse, OutputAPI, ChatCompletionEvent, InferenceTaskError, @@ -22,6 +23,8 @@ import { InferenceTaskEventType, createInferenceInternalError, withoutOutputUpdateEvents, + type ToolOptions, + ChatCompleteOptions, } from '@kbn/inference-common'; import type { ChatCompleteRequestBody } from '../../common/http_apis'; import type { InferenceConnector } from '../../common/connectors'; @@ -154,7 +157,7 @@ export class KibanaClient { } createInferenceClient({ connectorId }: { connectorId: string }): ScriptInferenceClient { - function stream(responsePromise: Promise) { + function streamResponse(responsePromise: Promise) { return from(responsePromise).pipe( switchMap((response) => { if (isReadable(response.data)) { @@ -174,14 +177,18 @@ export class KibanaClient { ); } - const chatCompleteApi: ChatCompleteAPI = ({ + const chatCompleteApi: ChatCompleteAPI = < + TToolOptions extends ToolOptions = ToolOptions, + TStream extends boolean = false + >({ connectorId: chatCompleteConnectorId, messages, system, toolChoice, tools, functionCalling, - }) => { + stream, + }: ChatCompleteOptions) => { const body: ChatCompleteRequestBody = { connectorId: chatCompleteConnectorId, system, @@ -191,15 +198,29 @@ export class KibanaClient { functionCalling, }; - return stream( - this.axios.post( - this.getUrl({ - pathname: `/internal/inference/chat_complete`, - }), - body, - { responseType: 'stream', timeout: NaN } - ) - ); + if (stream) { + return streamResponse( + this.axios.post( + this.getUrl({ + pathname: `/internal/inference/chat_complete/stream`, + }), + body, + { responseType: 'stream', timeout: NaN } + ) + ) as ChatCompleteCompositeResponse; + } else { + return this.axios + .post( + this.getUrl({ + pathname: `/internal/inference/chat_complete/stream`, + }), + body, + { responseType: 'stream', timeout: NaN } + ) + .then((response) => { + return response.data; + }) as ChatCompleteCompositeResponse; + } }; const outputApi: OutputAPI = createOutputApi(chatCompleteApi); @@ -211,8 +232,13 @@ export class KibanaClient { ...options, }); }, - output: (id, options) => { - return outputApi(id, { ...options }).pipe(withoutOutputUpdateEvents()); + output: (options) => { + const response = outputApi({ ...options }); + if (options.stream) { + return (response as any).pipe(withoutOutputUpdateEvents()); + } else { + return response; + } }, }; } diff --git a/x-pack/plugins/inference/server/chat_complete/api.ts b/x-pack/plugins/inference/server/chat_complete/api.ts index 62a1ea8b26146..cf325e72ddf3a 100644 --- a/x-pack/plugins/inference/server/chat_complete/api.ts +++ b/x-pack/plugins/inference/server/chat_complete/api.ts @@ -11,32 +11,37 @@ import type { Logger } from '@kbn/logging'; import type { KibanaRequest } from '@kbn/core-http-server'; import { type ChatCompleteAPI, - type ChatCompletionResponse, + type ChatCompleteCompositeResponse, createInferenceRequestError, + type ToolOptions, + ChatCompleteOptions, } from '@kbn/inference-common'; import type { InferenceStartDependencies } from '../types'; import { getConnectorById } from '../util/get_connector_by_id'; import { getInferenceAdapter } from 
'./adapters'; -import { createInferenceExecutor, chunksIntoMessage } from './utils'; +import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils'; -export function createChatCompleteApi({ - request, - actions, - logger, -}: { +interface CreateChatCompleteApiOptions { request: KibanaRequest; actions: InferenceStartDependencies['actions']; logger: Logger; -}) { - const chatCompleteAPI: ChatCompleteAPI = ({ +} + +export function createChatCompleteApi(options: CreateChatCompleteApiOptions): ChatCompleteAPI; +export function createChatCompleteApi({ request, actions, logger }: CreateChatCompleteApiOptions) { + return ({ connectorId, messages, toolChoice, tools, system, functionCalling, - }): ChatCompletionResponse => { - return defer(async () => { + stream, + }: ChatCompleteOptions): ChatCompleteCompositeResponse< + ToolOptions, + boolean + > => { + const obs$ = defer(async () => { const actionsClient = await actions.getActionsClientWithRequest(request); const connector = await getConnectorById({ connectorId, actionsClient }); const executor = createInferenceExecutor({ actionsClient, connector }); @@ -73,7 +78,11 @@ export function createChatCompleteApi({ logger, }) ); - }; - return chatCompleteAPI; + if (stream) { + return obs$; + } else { + return streamToResponse(obs$); + } + }; } diff --git a/x-pack/plugins/inference/server/chat_complete/utils/index.ts b/x-pack/plugins/inference/server/chat_complete/utils/index.ts index dea2ac65f4755..d3dc2010cba3a 100644 --- a/x-pack/plugins/inference/server/chat_complete/utils/index.ts +++ b/x-pack/plugins/inference/server/chat_complete/utils/index.ts @@ -12,3 +12,4 @@ export { type InferenceExecutor, } from './inference_executor'; export { chunksIntoMessage } from './chunks_into_message'; +export { streamToResponse } from './stream_to_response'; diff --git a/x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.test.ts b/x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.test.ts new file mode 100644 index 0000000000000..939997a5fef15 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.test.ts @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { of } from 'rxjs'; +import { ChatCompletionEvent } from '@kbn/inference-common'; +import { chunkEvent, tokensEvent, messageEvent } from '../../test_utils/chat_complete_events'; +import { streamToResponse } from './stream_to_response'; + +describe('streamToResponse', () => { + function fromEvents(...events: ChatCompletionEvent[]) { + return of(...events); + } + + it('returns a response with token count if both message and token events got emitted', async () => { + const response = await streamToResponse( + fromEvents( + chunkEvent('chunk_1'), + chunkEvent('chunk_2'), + tokensEvent({ prompt: 1, completion: 2, total: 3 }), + messageEvent('message') + ) + ); + + expect(response).toEqual({ + content: 'message', + tokens: { + completion: 2, + prompt: 1, + total: 3, + }, + toolCalls: [], + }); + }); + + it('returns a response with tool calls if present', async () => { + const someToolCall = { + toolCallId: '42', + function: { + name: 'my_tool', + arguments: {}, + }, + }; + const response = await streamToResponse( + fromEvents(chunkEvent('chunk_1'), messageEvent('message', [someToolCall])) + ); + + expect(response).toEqual({ + content: 'message', + toolCalls: [someToolCall], + }); + }); + + it('returns a response without token count if only message got emitted', async () => { + const response = await streamToResponse( + fromEvents(chunkEvent('chunk_1'), messageEvent('message')) + ); + + expect(response).toEqual({ + content: 'message', + toolCalls: [], + }); + }); + + it('rejects an error if message event is not emitted', async () => { + await expect( + streamToResponse(fromEvents(chunkEvent('chunk_1'), tokensEvent())) + ).rejects.toThrowErrorMatchingInlineSnapshot(`"No message event found"`); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.ts b/x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.ts new file mode 100644 index 0000000000000..4bae4fda767cb --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/utils/stream_to_response.ts @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { toArray, map, firstValueFrom } from 'rxjs'; +import { + ChatCompleteResponse, + ChatCompleteStreamResponse, + createInferenceInternalError, + isChatCompletionMessageEvent, + isChatCompletionTokenCountEvent, + ToolOptions, + withoutChunkEvents, +} from '@kbn/inference-common'; + +export const streamToResponse = ( + streamResponse$: ChatCompleteStreamResponse +): Promise> => { + return firstValueFrom( + streamResponse$.pipe( + withoutChunkEvents(), + toArray(), + map((events) => { + const messageEvent = events.find(isChatCompletionMessageEvent); + const tokenEvent = events.find(isChatCompletionTokenCountEvent); + + if (!messageEvent) { + throw createInferenceInternalError('No message event found'); + } + + return { + content: messageEvent.content, + toolCalls: messageEvent.toolCalls, + tokens: tokenEvent?.tokens, + }; + }) + ) + ); +}; diff --git a/x-pack/plugins/inference/server/routes/chat_complete.ts b/x-pack/plugins/inference/server/routes/chat_complete.ts index d4d0d012a78cd..582d4ceb97d45 100644 --- a/x-pack/plugins/inference/server/routes/chat_complete.ts +++ b/x-pack/plugins/inference/server/routes/chat_complete.ts @@ -5,8 +5,14 @@ * 2.0. 
*/ -import { schema, Type } from '@kbn/config-schema'; -import type { CoreSetup, IRouter, Logger, RequestHandlerContext } from '@kbn/core/server'; +import { schema, Type, TypeOf } from '@kbn/config-schema'; +import type { + CoreSetup, + IRouter, + Logger, + RequestHandlerContext, + KibanaRequest, +} from '@kbn/core/server'; import { MessageRole, ToolCall, ToolChoiceType } from '@kbn/inference-common'; import type { ChatCompleteRequestBody } from '../../common/http_apis'; import { createInferenceClient } from '../inference_client'; @@ -84,6 +90,32 @@ export function registerChatCompleteRoute({ router: IRouter; logger: Logger; }) { + async function callChatComplete({ + request, + stream, + }: { + request: KibanaRequest>; + stream: T; + }) { + const actions = await coreSetup + .getStartServices() + .then(([coreStart, pluginsStart]) => pluginsStart.actions); + + const client = createInferenceClient({ request, actions, logger }); + + const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body; + + return client.chatComplete({ + connectorId, + messages, + system, + toolChoice, + tools, + functionCalling, + stream, + }); + } + router.post( { path: '/internal/inference/chat_complete', @@ -92,23 +124,22 @@ export function registerChatCompleteRoute({ }, }, async (context, request, response) => { - const actions = await coreSetup - .getStartServices() - .then(([coreStart, pluginsStart]) => pluginsStart.actions); - - const client = createInferenceClient({ request, actions, logger }); - - const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body; - - const chatCompleteResponse = client.chatComplete({ - connectorId, - messages, - system, - toolChoice, - tools, - functionCalling, + const chatCompleteResponse = await callChatComplete({ request, stream: false }); + return response.ok({ + body: chatCompleteResponse, }); + } + ); + router.post( + { + path: '/internal/inference/chat_complete/stream', + validate: { + body: chatCompleteBodySchema, + }, + }, + async (context, request, response) => { + const chatCompleteResponse = await callChatComplete({ request, stream: true }); return response.ok({ body: observableIntoEventSourceStream(chatCompleteResponse, logger), }); diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts index 26a8fb63ce013..3d8701eba72db 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts @@ -71,6 +71,7 @@ export const generateEsqlTask = ({ chatCompleteApi({ connectorId, functionCalling, + stream: true, system: `${systemMessage} # Current task diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts index aea428208be1d..06e75db09bdc9 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts @@ -33,8 +33,10 @@ export const requestDocumentation = ({ }) => { const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; - return outputApi('request_documentation', { + return outputApi({ + id: 'request_documentation', connectorId, + stream: true, functionCalling, system, previousMessages: messages, diff --git 
a/x-pack/plugins/inference/server/test_utils/chat_complete_events.ts b/x-pack/plugins/inference/server/test_utils/chat_complete_events.ts new file mode 100644 index 0000000000000..4b09ca9c4dc5a --- /dev/null +++ b/x-pack/plugins/inference/server/test_utils/chat_complete_events.ts @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { + ChatCompletionChunkEvent, + ChatCompletionEventType, + ChatCompletionTokenCountEvent, + ChatCompletionMessageEvent, + ChatCompletionTokenCount, + ToolCall, +} from '@kbn/inference-common'; + +export const chunkEvent = (content: string = 'chunk'): ChatCompletionChunkEvent => ({ + type: ChatCompletionEventType.ChatCompletionChunk, + content, + tool_calls: [], +}); + +export const messageEvent = ( + content: string = 'message', + toolCalls: Array> = [] +): ChatCompletionMessageEvent => ({ + type: ChatCompletionEventType.ChatCompletionMessage, + content, + toolCalls, +}); + +export const tokensEvent = (tokens?: ChatCompletionTokenCount): ChatCompletionTokenCountEvent => ({ + type: ChatCompletionEventType.ChatCompletionTokenCount, + tokens: { + prompt: tokens?.prompt ?? 10, + completion: tokens?.completion ?? 20, + total: tokens?.total ?? 30, + }, +});