Skip to content

Commit

Permalink
[AI Connector] Change completion subAction schema to be OpenAI compatible (#200249)
Browse files Browse the repository at this point in the history

…

## Summary

Summarize your PR. If it involves visual changes include a screenshot or
gif.


### Checklist

Check that the PR satisfies the following conditions.

Reviewers should verify this PR satisfies this list as well.

- [ ] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [ ] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [ ] If a plugin configuration key changed, check if it needs to be
allowlisted in the cloud and added to the [docker
list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)
- [ ] This was checked for breaking HTTP API changes, and any breaking
changes have been approved by the breaking-change committee. The
`release_note:breaking` label should be applied in these situations.
- [ ] [Flaky Test
Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was
used on any tests changed
- [ ] The PR description includes the appropriate Release Notes section,
and the correct `release_note:*` label is applied per the
[guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)

### Identify risks

Does this PR introduce any risks? For example, consider risks like hard-to-test bugs, performance regressions, or potential data loss.

Describe the risk, its severity, and mitigation for each identified
risk. Invite stakeholders and evaluate how to proceed before merging.

- [ ] [See some risk
examples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx)
- [ ] ...
  • Loading branch information
YulNaumenko authored Dec 18, 2024
1 parent 8ad2d50 commit 6eaa1d0
Show file tree
Hide file tree
Showing 14 changed files with 721 additions and 97 deletions.
5 changes: 4 additions & 1 deletion x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,7 @@ export const getGenAiTokenTracking = async ({
};

/**
 * Returns true when the given connector action type participates in
 * GenAI token-usage tracking. Kept as an explicit allowlist of connector
 * type ids; extend it when a new GenAI connector gains token reporting.
 * (The diff residue duplicating the old one-line body is removed.)
 */
export const shouldTrackGenAiToken = (actionTypeId: string) =>
  actionTypeId === '.gen-ai' ||
  actionTypeId === '.bedrock' ||
  actionTypeId === '.gemini' ||
  actionTypeId === '.inference';
3 changes: 3 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ export enum ServiceProviderKeys {

export const INFERENCE_CONNECTOR_ID = '.inference';
export enum SUB_ACTION {
UNIFIED_COMPLETION_ASYNC_ITERATOR = 'unified_completion_async_iterator',
UNIFIED_COMPLETION_STREAM = 'unified_completion_stream',
UNIFIED_COMPLETION = 'unified_completion',
COMPLETION = 'completion',
RERANK = 'rerank',
TEXT_EMBEDDING = 'text_embedding',
Expand Down
179 changes: 179 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,176 @@ export const ChatCompleteParamsSchema = schema.object({
input: schema.string(),
});

// subset of OpenAI.ChatCompletionMessageParam https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts
// A single tool invocation attached to an assistant message.
const AIMessageToolCall = schema.object({
  id: schema.string(),
  function: schema.object({
    arguments: schema.maybe(schema.string()),
    name: schema.maybe(schema.string()),
  }),
  type: schema.string(),
});

// One chat message in the conversation body sent to the model.
const AIMessage = schema.object({
  role: schema.string(),
  content: schema.maybe(schema.string()),
  name: schema.maybe(schema.string()),
  tool_calls: schema.maybe(schema.arrayOf(AIMessageToolCall)),
  tool_call_id: schema.maybe(schema.string()),
});

// Tool definition the model may call — presumably a subset of
// OpenAI.ChatCompletionTool (function tools); verify against the OpenAI SDK.
const AITool = schema.object({
  type: schema.string(),
  function: schema.object({
    name: schema.string(),
    description: schema.maybe(schema.string()),
    // free-form record — presumably a JSON Schema for the function arguments; TODO confirm
    parameters: schema.maybe(schema.recordOf(schema.string(), schema.any())),
  }),
});

// subset of OpenAI.ChatCompletionCreateParamsBase https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts
export const UnifiedChatCompleteParamsSchema = schema.object({
  body: schema.object({
    // Conversation so far; defaults to an empty list so callers may build it up incrementally.
    messages: schema.arrayOf(AIMessage, { defaultValue: [] }),
    // Model id; optional — presumably the connector/inference endpoint default applies when omitted. TODO confirm.
    model: schema.maybe(schema.string()),
    /**
     * The maximum number of [tokens](/tokenizer) that can be generated in the chat
     * completion. This value can be used to control
     * [costs](https://openai.com/api/pricing/) for text generated via API.
     *
     * This value is now deprecated in favor of `max_completion_tokens`, and is not
     * compatible with
     * [o1 series models](https://platform.openai.com/docs/guides/reasoning).
     */
    max_tokens: schema.maybe(schema.number()),
    /**
     * Developer-defined tags and values used for filtering completions in the
     * [dashboard](https://platform.openai.com/chat-completions).
     */
    metadata: schema.maybe(schema.recordOf(schema.string(), schema.string())),
    /**
     * How many chat completion choices to generate for each input message. Note that
     * you will be charged based on the number of generated tokens across all of the
     * choices. Keep `n` as `1` to minimize costs.
     */
    n: schema.maybe(schema.number()),
    /**
     * Up to 4 sequences where the API will stop generating further tokens.
     */
    stop: schema.maybe(
      schema.nullable(schema.oneOf([schema.string(), schema.arrayOf(schema.string())]))
    ),
    /**
     * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
     * make the output more random, while lower values like 0.2 will make it more
     * focused and deterministic.
     *
     * We generally recommend altering this or `top_p` but not both.
     */
    temperature: schema.maybe(schema.number()),
    /**
     * Controls which (if any) tool is called by the model. `none` means the model will
     * not call any tool and instead generates a message. `auto` means the model can
     * pick between generating a message or calling one or more tools. `required` means
     * the model must call one or more tools. Specifying a particular tool via
     * `{"type": "function", "function": {"name": "my_function"}}` forces the model to
     * call that tool.
     *
     * `none` is the default when no tools are present. `auto` is the default if tools
     * are present.
     */
    tool_choice: schema.maybe(
      schema.oneOf([
        schema.string(),
        schema.object({
          type: schema.string(),
          function: schema.object({
            name: schema.string(),
          }),
        }),
      ])
    ),
    /**
     * A list of tools the model may call. Currently, only functions are supported as a
     * tool. Use this to provide a list of functions the model may generate JSON inputs
     * for. A max of 128 functions are supported.
     */
    tools: schema.maybe(schema.arrayOf(AITool)),
    /**
     * An alternative to sampling with temperature, called nucleus sampling, where the
     * model considers the results of the tokens with top_p probability mass. So 0.1
     * means only the tokens comprising the top 10% probability mass are considered.
     *
     * We generally recommend altering this or `temperature` but not both.
     */
    top_p: schema.maybe(schema.number()),
    /**
     * A unique identifier representing your end-user, which can help OpenAI to monitor
     * and detect abuse.
     * [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
     */
    user: schema.maybe(schema.string()),
  }),
  // abort signal from client (opaque to the schema, hence any)
  signal: schema.maybe(schema.any()),
});

// Response shape for the unified completion sub-actions — presumably a subset of
// OpenAI.ChatCompletion; most fields are optional so non-OpenAI providers can omit them.
export const UnifiedChatCompleteResponseSchema = schema.object({
  // unique id of the completion
  id: schema.string(),
  // one entry per generated choice; defaults to an empty array
  choices: schema.arrayOf(
    schema.object({
      // reason generation stopped; nullable/absent when the provider does not report it
      finish_reason: schema.maybe(
        schema.nullable(
          schema.oneOf([
            schema.literal('stop'),
            schema.literal('length'),
            schema.literal('tool_calls'),
            schema.literal('content_filter'),
            schema.literal('function_call'),
          ])
        )
      ),
      index: schema.maybe(schema.number()),
      message: schema.object({
        content: schema.maybe(schema.nullable(schema.string())),
        refusal: schema.maybe(schema.nullable(schema.string())),
        role: schema.maybe(schema.string()),
        // tool invocations requested by the model, if any; defaults to an empty array
        tool_calls: schema.maybe(
          schema.arrayOf(
            schema.object({
              id: schema.maybe(schema.string()),
              index: schema.maybe(schema.number()),
              function: schema.maybe(
                schema.object({
                  arguments: schema.maybe(schema.string()),
                  name: schema.maybe(schema.string()),
                })
              ),
              type: schema.maybe(schema.string()),
            }),
            { defaultValue: [] }
          )
        ),
      }),
    }),
    { defaultValue: [] }
  ),
  // creation timestamp — presumably unix seconds as in the OpenAI API; TODO confirm per provider
  created: schema.maybe(schema.number()),
  model: schema.maybe(schema.string()),
  object: schema.maybe(schema.string()),
  // token accounting; nullable because not all providers report usage
  usage: schema.maybe(
    schema.nullable(
      schema.object({
        completion_tokens: schema.maybe(schema.number()),
        prompt_tokens: schema.maybe(schema.number()),
        total_tokens: schema.maybe(schema.number()),
      })
    )
  ),
});

export const ChatCompleteResponseSchema = schema.arrayOf(
schema.object({
result: schema.string(),
Expand Down Expand Up @@ -66,3 +236,12 @@ export const TextEmbeddingResponseSchema = schema.arrayOf(
);

export const StreamingResponseSchema = schema.stream();

// Params for the dashboard sub-action — presumably looks up a connector
// usage dashboard by saved-object id; verify against the executor.
export const DashboardActionParamsSchema = schema.object({
  dashboardId: schema.string(),
});

export const DashboardActionResponseSchema = schema.object({
  // whether the dashboard is available — NOTE(review): exact semantics (existence vs. access) not visible here
  available: schema.boolean(),
});
10 changes: 10 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ import {
SparseEmbeddingResponseSchema,
TextEmbeddingParamsSchema,
TextEmbeddingResponseSchema,
UnifiedChatCompleteParamsSchema,
UnifiedChatCompleteResponseSchema,
DashboardActionParamsSchema,
DashboardActionResponseSchema,
} from './schema';
import { ConfigProperties } from '../dynamic_config/types';

export type Config = TypeOf<typeof ConfigSchema>;
export type Secrets = TypeOf<typeof SecretsSchema>;

// Request/response types for the OpenAI-compatible unified completion sub-actions.
export type UnifiedChatCompleteParams = TypeOf<typeof UnifiedChatCompleteParamsSchema>;
export type UnifiedChatCompleteResponse = TypeOf<typeof UnifiedChatCompleteResponseSchema>;

export type ChatCompleteParams = TypeOf<typeof ChatCompleteParamsSchema>;
export type ChatCompleteResponse = TypeOf<typeof ChatCompleteResponseSchema>;

Expand All @@ -38,6 +45,9 @@ export type TextEmbeddingResponse = TypeOf<typeof TextEmbeddingResponseSchema>;

export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;

// Params/response types for the dashboard availability sub-action.
export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;

export type FieldsConfiguration = Record<string, ConfigProperties>;

export interface InferenceProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,27 @@ export const DEFAULT_TEXT_EMBEDDING_BODY = {
inputType: 'ingest',
};

// Default request body used for the unified completion sub-actions
// (see DEFAULTS_BY_TASK_TYPE below): a minimal single-message conversation.
export const DEFAULT_UNIFIED_CHAT_COMPLETE_BODY = {
  body: {
    messages: [
      {
        role: 'user',
        content: 'Hello world',
      },
    ],
  },
};

// Maps each sub-action (task type) to the default body pre-filled for it.
// All three unified-completion variants share the same OpenAI-compatible default body.
export const DEFAULTS_BY_TASK_TYPE: Record<string, unknown> = {
  [SUB_ACTION.COMPLETION]: DEFAULT_CHAT_COMPLETE_BODY,
  [SUB_ACTION.UNIFIED_COMPLETION]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
  [SUB_ACTION.UNIFIED_COMPLETION_STREAM]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
  [SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
  [SUB_ACTION.RERANK]: DEFAULT_RERANK_BODY,
  [SUB_ACTION.SPARSE_EMBEDDING]: DEFAULT_SPARSE_EMBEDDING_BODY,
  [SUB_ACTION.TEXT_EMBEDDING]: DEFAULT_TEXT_EMBEDDING_BODY,
};

// Default task type for new inference connectors; 'unified_completion' is the
// OpenAI-compatible chat completion task. (The diff residue duplicating the
// superseded 'completion' declaration is removed — two consts of the same
// name cannot coexist.)
export const DEFAULT_TASK_TYPE = 'unified_completion';

export const DEFAULT_PROVIDER = 'elasticsearch';
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@ describe('OpenAI action params validation', () => {
subActionParams: { input: ['message test'], query: 'foobar' },
},
{
subAction: SUB_ACTION.COMPLETION,
subActionParams: { input: 'message test' },
subAction: SUB_ACTION.UNIFIED_COMPLETION,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.UNIFIED_COMPLETION_STREAM,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.TEXT_EMBEDDING,
Expand All @@ -55,6 +63,10 @@ describe('OpenAI action params validation', () => {
subAction: SUB_ACTION.SPARSE_EMBEDDING,
subActionParams: { input: 'message test' },
},
{
subAction: SUB_ACTION.COMPLETION,
subActionParams: { input: 'message test' },
},
])(
'validation succeeds when params are valid for subAction $subAction',
async ({ subAction, subActionParams }) => {
Expand All @@ -63,19 +75,25 @@ describe('OpenAI action params validation', () => {
subActionParams,
};
expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: { input: [], subAction: [], inputType: [], query: [] },
errors: { body: [], input: [], subAction: [], inputType: [], query: [] },
});
}
);

test('params validation fails when params is a wrong object', async () => {
const actionParams = {
subAction: SUB_ACTION.COMPLETION,
subAction: SUB_ACTION.UNIFIED_COMPLETION,
subActionParams: { body: 'message {test}' },
};

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: { input: ['Input is required.'], inputType: [], query: [], subAction: [] },
errors: {
body: ['Messages is required.'],
inputType: [],
query: [],
subAction: [],
input: [],
},
});
});

Expand All @@ -86,6 +104,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: [],
query: [],
Expand All @@ -102,6 +121,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: [],
query: [],
Expand All @@ -118,6 +138,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: ['Input is required.', 'Input does not have a valid Array format.'],
inputType: [],
query: ['Query is required.'],
Expand All @@ -134,6 +155,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: ['Input type is required.'],
query: [],
Expand Down
Loading

0 comments on commit 6eaa1d0

Please sign in to comment.