[Inference] Inference plugin + chatComplete API (elastic#188280)
This PR introduces an Inference plugin.

## Goals

- Provide a single place for all interactions with large language models
and other generative AI adjacent tasks.
- Abstract away differences between LLM providers like OpenAI, Bedrock,
and Gemini.
- Host commonly used LLM-based tasks like generating ES|QL from natural
language and knowledge base recall.
- Allow us to move gradually to the _inference endpoint without
disrupting engineers.

## Architecture and examples

![CleanShot 2024-07-14 at 14 45
27@2x](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f)

## Terminology

The following concepts are referenced throughout this POC:

- **chat completion**: the process in which the LLM generates the next
message in the conversation. This is sometimes referred to as inference,
text completion, text generation or content generation.
- **tasks**: higher-level tasks that, based on their input, use the LLM in
conjunction with other services like Elasticsearch to achieve a result.
The example in this POC is natural language to ES|QL.
- **tools**: a set of tools that the LLM can choose to use when
generating the next message. In essence, it allows the consumer of the
API to define a schema for structured output instead of plain text, and
have the LLM select the most appropriate one.
- **tool call**: when the LLM has chosen a tool (schema) to use for its
output and returns a document that matches that schema, this is referred
to as a tool call.

## Usage examples

```ts

class MyPlugin {
  setup(coreSetup, pluginsSetup) {
    const router = coreSetup.http.createRouter();

    router.post(
      {
        path: '/internal/my_plugin/do_something',
        validate: {
          body: schema.object({
            connectorId: schema.string(),
          }),
        },
      },
      async (context, request, response) => {
        const [coreStart, pluginsStart] = await coreSetup.getStartServices();

        const inferenceClient = pluginsSetup.inference.getClient({ request });

        const chatComplete$ = inferenceClient.chatComplete({
          connectorId: request.body.connectorId,
          system: `Here is my system message`,
          messages: [
            {
              role: MessageRole.User,
              content: 'Do something',
            },
          ],
        });

        const message = await lastValueFrom(
          chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
        );

        return response.ok({
          body: {
            message,
          },
        });
      }
    );
  }
}
```

## Implementation

The bulk of the work here is implementing a `chatComplete` API. Here's
what it does:

- Formats the request for the specific LLM that is being called (all
have different API specifications).
- Executes the specified connector with the formatted request.
- Creates and returns an Observable, and starts reading from the stream.
- Every event in the stream is normalized to a format that is close to
(but not exactly the same as) OpenAI's format, and emitted as a value
from the Observable.
- When the stream ends, the individual events (chunks) are concatenated
into a single message.
- If the LLM has called any tools, the tool call is validated according
to its schema.
- After emitting the message, the Observable completes. A consumer can
switch on the event type to handle each of these stages, as sketched below.
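
For illustration, here is a minimal sketch of consuming that event stream. It
assumes the event discriminator field is named `type` (the event shapes follow
the `ChatCompletionChunkEvent`, `ChatCompletionTokenCountEvent` and
`ChatCompletionMessageEvent` types added in this PR; the field name itself is
an assumption):

```ts
chatComplete$.subscribe((event) => {
  switch (event.type) {
    case ChatCompletionEventType.ChatCompletionChunk:
      // partial content as the LLM streams it
      process.stdout.write(event.content);
      break;
    case ChatCompletionEventType.ChatCompletionTokenCount:
      // prompt/completion/total token counts for the request
      console.log(`Used ${event.tokens.total} tokens`);
      break;
    case ChatCompletionEventType.ChatCompletionMessage:
      // the concatenated message, with validated tool calls
      console.log(event.content, event.toolCalls);
      break;
  }
});
```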

There's also a thin wrapper around this API, called the `output` API
(see the sketch after this list). It simplifies a few things:

- It doesn't require a conversation (list of messages); a simple `input`
string suffices.
- You can define a schema for the output of the LLM.
- It drops the token count events that are emitted.
- It simplifies the event format (update & complete).
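
A minimal sketch of what calling `output` could look like; the parameter names
(`input`, `schema`) are inferred from this description rather than taken from
the PR's type definitions, so treat them as assumptions:

```ts
const output$ = inferenceClient.output({
  connectorId: request.body.connectorId,
  input: 'Extract the cities mentioned in the following text: ...',
  // JSON Schema describing the structured response we want back
  schema: {
    type: 'object',
    properties: {
      cities: { type: 'array', items: { type: 'string' } },
    },
    required: ['cities'],
  } as const,
});

// token count events are already dropped, so the last emitted value
// is the completed, schema-conforming output
const result = await lastValueFrom(output$);
```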

### Observable event streams

These APIs, both on the client and the server, return Observables that
emit events. When converting the Observable into a stream (a minimal
sketch of that step follows the list), the following things happen:

- Errors are caught and serialized as events sent over the stream (after
an error, the stream ends).
- The response stream outputs data as [server-sent
events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
- The client that reads the stream parses the event source into an
Observable; if it encounters a serialized error, it deserializes it and
throws it in the Observable.
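
A minimal sketch of the server-side serialization step, written with plain
rxjs operators; the error-event shape shown here is illustrative, not the
plugin's actual wire format:

```ts
import { catchError, map, of } from 'rxjs';

// Each event (and any caught error) becomes a server-sent event line;
// after an error event, the stream completes.
const sse$ = chatComplete$.pipe(
  map((event) => `data: ${JSON.stringify(event)}\n\n`),
  catchError((error) =>
    of(`data: ${JSON.stringify({ type: 'error', error: { message: error.message } })}\n\n`)
  )
);
```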

### Errors

All known errors are instances of, rather than extensions of, the
`InferenceTaskError` base class, which has a `code`, a `message`, and
`meta` information about the error. This allows us to serialize and
deserialize errors over the wire without a complicated factory pattern.
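
For example, a consumer can narrow a caught error with the type guards added
in this PR (the import path and `logger` are placeholders for whatever the
consuming plugin has available):

```ts
import {
  isTokenLimitReachedError,
  isToolValidationError,
} from '@kbn/inference-plugin/common/chat_complete/errors';

try {
  const message = await lastValueFrom(chatComplete$);
  // ...
} catch (error) {
  if (isToolValidationError(error)) {
    // `meta` carries the tool name, raw arguments and validation errors
    logger.warn(`Tool call failed validation: ${error.meta.errorsText}`);
  } else if (isTokenLimitReachedError(error)) {
    logger.warn(`Conversation exceeded the token limit of ${error.meta.tokenLimit}`);
  } else {
    throw error;
  }
}
```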

### Tools

Tools are defined as a record, with a `description` and optionally a
`schema`. The reason it's a record is type safety: this allows us to
have fully typed tool calls (e.g. when the name of the tool being called
is `x`, its arguments are typed as the schema of `x`).
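
A minimal sketch of the idea, assuming tool definitions are passed via a
`tools` option on `chatComplete`; the `as const` and the narrowing described
in the comment are illustrative, while the real typing lives in the plugin's
`ToolOptions`/`ToolCallsOf` helpers:

```ts
const esqlHelp$ = inferenceClient.chatComplete({
  connectorId,
  messages,
  tools: {
    get_esql_docs: {
      description: 'Retrieve documentation for an ES|QL command',
      schema: {
        type: 'object',
        properties: {
          command: { type: 'string' },
        },
        required: ['command'],
      },
    },
  } as const,
});

// In the resulting message event, a tool call whose name is 'get_esql_docs'
// has its arguments typed according to the schema above, i.e. { command: string }.
```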

## Notes for reviewers

- I've only added one reference implementation for a connector adapter,
which is OpenAI. Adding more would create noise in the PR, but I can add
them as well. Bedrock would need simulated function calling, which I
would also expect to be handled by this plugin.
- Similarly, the natural language to ES|QL task just creates dummy
steps, as moving the entire implementation would mean thousands of
additional LOC, for instance because it needs the ES|QL documentation.
- Observables over promises/iterators: Observables are a well-defined
and widely-adopted solution for async programming. Promises are not
suitable for streamed/chunked responses because they have no
intermediate values. Async iterators are not widely adopted by Kibana
engineers.
- JSON Schema over Zod: I've tried using Zod, because I like its
ergonomics better than plain JSON Schema, but we need to convert it to
JSON Schema at some point, which is a lossy conversion, creating a risk
of using features that cannot be converted to JSON Schema. Additionally,
tools for converting Zod to and [from JSON Schema are not always
suitable](https://github.com/StefanTerdell/json-schema-to-zod#use-at-runtime).
I've implemented my own JSON Schema-to-type conversion, as
[json-schema-to-ts](https://github.com/ThomasAribart/json-schema-to-ts)
is very slow.
- There's no option for raw input or output. There could be, but it
would defeat the purpose of the normalization that the `chatComplete`
API handles. At that point it might be better to use the connector
directly.
- That also means that for LangChain, something would be needed to
convert the Observable into an async iterator that returns
OpenAI-compatible output. This is doable, although it would be nice if
we could just use the output from the OpenAI API in that case.
- I have not made room for any vendor-specific parameters in the
`chatComplete` API. We might need it, but hopefully not.
- I think type safety is critical here, so there is some TypeScript
voodoo in some places to make that happen.
- `system` is not a message in the conversation, but a separate
property. Given the semantics of a system message (there can only be
one, and only at the beginning of the conversation), I think it's easier
to make it a top-level property than a message type.

---------

Co-authored-by: kibanamachine <[email protected]>
Co-authored-by: Elastic Machine <[email protected]>
3 people authored Aug 6, 2024
1 parent a048ad1 commit 769fb99
Showing 50 changed files with 3,127 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
@@ -499,6 +499,7 @@ x-pack/packages/index-management @elastic/kibana-management
x-pack/plugins/index_management @elastic/kibana-management
test/plugin_functional/plugins/index_patterns @elastic/kibana-data-discovery
x-pack/packages/ml/inference_integration_flyout @elastic/ml-ui
x-pack/plugins/inference @elastic/kibana-core
x-pack/packages/kbn-infra-forge @elastic/obs-ux-management-team
x-pack/plugins/observability_solution/infra @elastic/obs-ux-logs-team @elastic/obs-ux-infra_services-team
x-pack/plugins/ingest_pipelines @elastic/kibana-management
@@ -1287,6 +1288,7 @@ x-pack/test/observability_ai_assistant_functional @elastic/obs-ai-assistant
/x-pack/test_serverless/**/test_suites/common/saved_objects_management/ @elastic/kibana-core
/x-pack/test_serverless/api_integration/test_suites/common/core/ @elastic/kibana-core
/x-pack/test_serverless/api_integration/test_suites/**/telemetry/ @elastic/kibana-core
/x-pack/plugins/inference @elastic/kibana-core @elastic/obs-ai-assistant @elastic/security-generative-ai
#CC# /src/core/server/csp/ @elastic/kibana-core
#CC# /src/plugins/saved_objects/ @elastic/kibana-core
#CC# /x-pack/plugins/cloud/ @elastic/kibana-core
5 changes: 5 additions & 0 deletions docs/developer/plugin-list.asciidoc
@@ -630,6 +630,11 @@ Index Management by running this series of requests in Console:
|This service is exposed from the Index Management setup contract and can be used to add content to the indices list and the index details page.
|{kib-repo}blob/{branch}/x-pack/plugins/inference/README.md[inference]
|The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and
external LLM APIs. Its goals are:
|{kib-repo}blob/{branch}/x-pack/plugins/observability_solution/infra/README.md[infra]
|This is the home of the infra plugin, which aims to provide a solution for
the infrastructure monitoring use-case within Kibana.
1 change: 1 addition & 0 deletions package.json
@@ -541,6 +541,7 @@
"@kbn/index-management": "link:x-pack/packages/index-management",
"@kbn/index-management-plugin": "link:x-pack/plugins/index_management",
"@kbn/index-patterns-test-plugin": "link:test/plugin_functional/plugins/index_patterns",
"@kbn/inference-plugin": "link:x-pack/plugins/inference",
"@kbn/inference_integration_flyout": "link:x-pack/packages/ml/inference_integration_flyout",
"@kbn/infra-forge": "link:x-pack/packages/kbn-infra-forge",
"@kbn/infra-plugin": "link:x-pack/plugins/observability_solution/infra",
1 change: 1 addition & 0 deletions packages/kbn-optimizer/limits.yml
@@ -79,6 +79,7 @@ pageLoadAssetSize:
imageEmbeddable: 12500
indexLifecycleManagement: 107090
indexManagement: 140608
inference: 20403
infra: 184320
ingestPipelines: 58003
inputControlVis: 172675
2 changes: 2 additions & 0 deletions tsconfig.base.json
@@ -992,6 +992,8 @@
"@kbn/index-patterns-test-plugin/*": ["test/plugin_functional/plugins/index_patterns/*"],
"@kbn/inference_integration_flyout": ["x-pack/packages/ml/inference_integration_flyout"],
"@kbn/inference_integration_flyout/*": ["x-pack/packages/ml/inference_integration_flyout/*"],
"@kbn/inference-plugin": ["x-pack/plugins/inference"],
"@kbn/inference-plugin/*": ["x-pack/plugins/inference/*"],
"@kbn/infra-forge": ["x-pack/packages/kbn-infra-forge"],
"@kbn/infra-forge/*": ["x-pack/packages/kbn-infra-forge/*"],
"@kbn/infra-plugin": ["x-pack/plugins/observability_solution/infra"],
1 change: 1 addition & 0 deletions x-pack/.i18nrc.json
@@ -54,6 +54,7 @@
"xpack.fleet": "plugins/fleet",
"xpack.ingestPipelines": "plugins/ingest_pipelines",
"xpack.integrationAssistant": "plugins/integration_assistant",
"xpack.inference": "plugins/inference",
"xpack.investigate": "plugins/observability_solution/investigate",
"xpack.investigateApp": "plugins/observability_solution/investigate_app",
"xpack.kubernetesSecurity": "plugins/kubernetes_security",
100 changes: 100 additions & 0 deletions x-pack/plugins/inference/README.md
@@ -0,0 +1,100 @@
# Inference plugin

The inference plugin is a central place to handle all interactions with the Elasticsearch Inference API and
external LLM APIs. Its goals are:

- Provide a single place for all interactions with large language models and other generative AI adjacent tasks.
- Abstract away differences between LLM providers like OpenAI, Bedrock, and Gemini.
- Host commonly used LLM-based tasks like generating ES|QL from natural language and knowledge base recall.
- Allow us to move gradually to the \_inference endpoint without disrupting engineers.

## Architecture and examples

![CleanShot 2024-07-14 at 14 45 27@2x](https://github.com/user-attachments/assets/e65a3e47-bce1-4dcf-bbed-4f8ac12a104f)

## Terminology

The following concepts are commonly used throughout the plugin:

- **chat completion**: the process in which the LLM generates the next message in the conversation. This is sometimes referred to as inference, text completion, text generation or content generation.
- **tasks**: higher-level tasks that, based on their input, use the LLM in conjunction with other services like Elasticsearch to achieve a result. The example in this POC is natural language to ES|QL.
- **tools**: a set of tools that the LLM can choose to use when generating the next message. In essence, it allows the consumer of the API to define a schema for structured output instead of plain text, and have the LLM select the most appropriate one.
- **tool call**: when the LLM has chosen a tool (schema) to use for its output and returns a document that matches that schema, this is referred to as a tool call.

## Usage examples

```ts
class MyPlugin {
  setup(coreSetup, pluginsSetup) {
    const router = coreSetup.http.createRouter();

    router.post(
      {
        path: '/internal/my_plugin/do_something',
        validate: {
          body: schema.object({
            connectorId: schema.string(),
          }),
        },
      },
      async (context, request, response) => {
        const [coreStart, pluginsStart] = await coreSetup.getStartServices();

        const inferenceClient = pluginsSetup.inference.getClient({ request });

        const chatComplete$ = inferenceClient.chatComplete({
          connectorId: request.body.connectorId,
          system: `Here is my system message`,
          messages: [
            {
              role: MessageRole.User,
              content: 'Do something',
            },
          ],
        });

        const message = await lastValueFrom(
          chatComplete$.pipe(withoutTokenCountEvents(), withoutChunkEvents())
        );

        return response.ok({
          body: {
            message,
          },
        });
      }
    );
  }
}
```

## Services

### `chatComplete`

`chatComplete` generates a response to a prompt or a conversation using the LLM. Here's what is supported:

- Normalizing request and response formats from different connector types (e.g. OpenAI, Bedrock, Claude, Elastic Inference Service)
- Tool calling and validation of tool calls
- Emitting token count events
- Emitting message events, which contain the concatenated message based on the response chunks

### `output`

`output` is a wrapper around `chatComplete` that is catered towards a single use case: having the LLM output a structured response, based on a schema. It also drops the token count events to simplify usage.

### Observable event streams

These APIs, both on the client and the server, return Observables that emit events. When converting the Observable into a stream, the following things happen:

- Errors are caught and serialized as events sent over the stream (after an error, the stream ends).
- The response stream outputs data as [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events).
- The client that reads the stream parses the event source into an Observable; if it encounters a serialized error, it deserializes it and throws it in the Observable.

### Errors

All known errors are instances of, rather than extensions of, the `InferenceTaskError` base class, which has a `code`, a `message`, and `meta` information about the error. This allows us to serialize and deserialize errors over the wire without a complicated factory pattern.

### Tools

Tools are defined as a record, with a `description` and optionally a `schema`. The reason it's a record is type safety: this allows us to have fully typed tool calls (e.g. when the name of the tool being called is `x`, its arguments are typed as the schema of `x`).
99 changes: 99 additions & 0 deletions x-pack/plugins/inference/common/chat_complete/errors.ts
@@ -0,0 +1,99 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { i18n } from '@kbn/i18n';
import { InferenceTaskError } from '../errors';
import type { UnvalidatedToolCall } from './tools';

export enum ChatCompletionErrorCode {
  TokenLimitReachedError = 'tokenLimitReachedError',
  ToolNotFoundError = 'toolNotFoundError',
  ToolValidationError = 'toolValidationError',
}

export type ChatCompletionTokenLimitReachedError = InferenceTaskError<
  ChatCompletionErrorCode.TokenLimitReachedError,
  {
    tokenLimit?: number;
    tokenCount?: number;
  }
>;

export type ChatCompletionToolNotFoundError = InferenceTaskError<
  ChatCompletionErrorCode.ToolNotFoundError,
  {
    name: string;
  }
>;

export type ChatCompletionToolValidationError = InferenceTaskError<
  ChatCompletionErrorCode.ToolValidationError,
  {
    name?: string;
    arguments?: string;
    errorsText?: string;
    toolCalls?: UnvalidatedToolCall[];
  }
>;

export function createTokenLimitReachedError(
  tokenLimit?: number,
  tokenCount?: number
): ChatCompletionTokenLimitReachedError {
  return new InferenceTaskError(
    ChatCompletionErrorCode.TokenLimitReachedError,
    i18n.translate('xpack.inference.chatCompletionError.tokenLimitReachedError', {
      defaultMessage: `Token limit reached. Token limit is {tokenLimit}, but the current conversation has {tokenCount} tokens.`,
      values: { tokenLimit, tokenCount },
    }),
    { tokenLimit, tokenCount }
  );
}

export function createToolNotFoundError(name: string): ChatCompletionToolNotFoundError {
  return new InferenceTaskError(
    ChatCompletionErrorCode.ToolNotFoundError,
    `Tool ${name} called but was not available`,
    {
      name,
    }
  );
}

export function createToolValidationError(
  message: string,
  meta: {
    name?: string;
    arguments?: string;
    errorsText?: string;
    toolCalls?: UnvalidatedToolCall[];
  }
): ChatCompletionToolValidationError {
  return new InferenceTaskError(ChatCompletionErrorCode.ToolValidationError, message, meta);
}

export function isToolValidationError(error?: Error): error is ChatCompletionToolValidationError {
  return (
    error instanceof InferenceTaskError &&
    error.code === ChatCompletionErrorCode.ToolValidationError
  );
}

export function isTokenLimitReachedError(
  error: Error
): error is ChatCompletionTokenLimitReachedError {
  return (
    error instanceof InferenceTaskError &&
    error.code === ChatCompletionErrorCode.TokenLimitReachedError
  );
}

export function isToolNotFoundError(error: Error): error is ChatCompletionToolNotFoundError {
  return (
    error instanceof InferenceTaskError && error.code === ChatCompletionErrorCode.ToolNotFoundError
  );
}
95 changes: 95 additions & 0 deletions x-pack/plugins/inference/common/chat_complete/index.ts
@@ -0,0 +1,95 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import type { Observable } from 'rxjs';
import type { InferenceTaskEventBase } from '../tasks';
import type { ToolCall, ToolCallsOf, ToolOptions } from './tools';

export enum MessageRole {
  User = 'user',
  Assistant = 'assistant',
  Tool = 'tool',
}

interface MessageBase<TRole extends MessageRole> {
  role: TRole;
}

export type UserMessage = MessageBase<MessageRole.User> & { content: string };

export type AssistantMessage = MessageBase<MessageRole.Assistant> & {
  content: string | null;
  toolCalls?: Array<ToolCall<string, Record<string, any> | undefined>>;
};

export type ToolMessage<TToolResponse extends Record<string, any> | unknown> =
  MessageBase<MessageRole.Tool> & {
    toolCallId: string;
    response: TToolResponse;
  };

export type Message = UserMessage | AssistantMessage | ToolMessage<unknown>;

export type ChatCompletionMessageEvent<TToolOptions extends ToolOptions> =
  InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionMessage> & {
    content: string;
  } & { toolCalls: ToolCallsOf<TToolOptions>['toolCalls'] };

export type ChatCompletionResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
  ChatCompletionEvent<TToolOptions>
>;

export enum ChatCompletionEventType {
  ChatCompletionChunk = 'chatCompletionChunk',
  ChatCompletionTokenCount = 'chatCompletionTokenCount',
  ChatCompletionMessage = 'chatCompletionMessage',
}

export interface ChatCompletionChunkToolCall {
  index: number;
  toolCallId: string;
  function: {
    name: string;
    arguments: string;
  };
}

export type ChatCompletionChunkEvent =
  InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionChunk> & {
    content: string;
    tool_calls: ChatCompletionChunkToolCall[];
  };

export type ChatCompletionTokenCountEvent =
  InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionTokenCount> & {
    tokens: {
      prompt: number;
      completion: number;
      total: number;
    };
  };

export type ChatCompletionEvent<TToolOptions extends ToolOptions = ToolOptions> =
  | ChatCompletionChunkEvent
  | ChatCompletionTokenCountEvent
  | ChatCompletionMessageEvent<TToolOptions>;

/**
 * Request a completion from the LLM based on a prompt or conversation.
 *
 * @param {string} options.connectorId The ID of the connector to use
 * @param {string} [options.system] A system message that defines the behavior of the LLM.
 * @param {Message[]} options.messages A list of messages that make up the conversation to be completed.
 * @param {ToolChoice} [options.toolChoice] Force the LLM to call a (specific) tool, or no tool
 * @param {Record<string, ToolDefinition>} [options.tools] A map of tools that can be called by the LLM
 */
export type ChatCompleteAPI<TToolOptions extends ToolOptions = ToolOptions> = (
  options: {
    connectorId: string;
    system?: string;
    messages: Message[];
  } & TToolOptions
) => ChatCompletionResponse<TToolOptions>;
16 changes: 16 additions & 0 deletions x-pack/plugins/inference/common/chat_complete/request.ts
@@ -0,0 +1,16 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import type { Message } from '.';
import { ToolOptions } from './tools';

export type ChatCompleteRequestBody = {
  connectorId: string;
  stream?: boolean;
  system?: string;
  messages: Message[];
} & ToolOptions;