Skip to content

Commit

Permalink
[8.x] [inference] add pre-bound versions of `chatComplete` …
Browse files Browse the repository at this point in the history
…and `output` APIs (#200568) (#201028)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[inference] add pre-bound versions of `chatComplete` and
`output` APIs
(#200568)](#200568)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Pierre
Gayvallet","email":"[email protected]"},"sourceCommit":{"committedDate":"2024-11-20T19:09:11Z","message":"[inference]
add pre-bound versions of `chatComplete` and `output` APIs
(#200568)\n\n## Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/199084\r\n\r\nIntroduce
pre-bound versions of the inference APIs.\r\n\r\nAccessing the bound
versions can be done using the same `getClient` API,\r\nvia an
additional `bindTo` parameter:\r\n\r\n**without
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({ request });\r\n\r\nconst chatResponse
= inferenceClient.chatComplete({\r\n connectorId: 'my-connector-id',\r\n
functionCalling: 'simulated',\r\n messages: [{ role: MessageRole.User,
content: 'Do something' }],\r\n});\r\n```\r\n\r\n**with
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({\r\n request,\r\n bindTo: {\r\n
connectorId: 'my-connector-id',\r\n functionCalling: 'simulated',\r\n
}\r\n});\r\n\r\nconst chatResponse = inferenceClient.chatComplete({\r\n
messages: [{ role: MessageRole.User, content: 'Do something'
}],\r\n});\r\n```\r\n\r\n*Note: this is only done for the server-side,
as there isn't much value\r\nin scoping APIs on the browser side in my
opinion*\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<[email protected]>","sha":"3c8f0777f4a4563824d0fb1f545524bf4346e3a2","branchLabelMapping":{"^v9.0.0$":"main","^v8.17.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","backport:prev-minor","Team:AI
Infra","v8.17.0"],"title":"[inference] add pre-bound versions of
`chatComplete` and `output`
APIs","number":200568,"url":"https://github.com/elastic/kibana/pull/200568","mergeCommit":{"message":"[inference]
add pre-bound versions of `chatComplete` and `output` APIs
(#200568)\n\n## Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/199084\r\n\r\nIntroduce
pre-bound versions of the inference APIs.\r\n\r\nAccessing the bound
versions can be done using the same `getClient` API,\r\nvia an
additional `bindTo` parameter:\r\n\r\n**without
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({ request });\r\n\r\nconst chatResponse
= inferenceClient.chatComplete({\r\n connectorId: 'my-connector-id',\r\n
functionCalling: 'simulated',\r\n messages: [{ role: MessageRole.User,
content: 'Do something' }],\r\n});\r\n```\r\n\r\n**with
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({\r\n request,\r\n bindTo: {\r\n
connectorId: 'my-connector-id',\r\n functionCalling: 'simulated',\r\n
}\r\n});\r\n\r\nconst chatResponse = inferenceClient.chatComplete({\r\n
messages: [{ role: MessageRole.User, content: 'Do something'
}],\r\n});\r\n```\r\n\r\n*Note: this is only done for the server-side,
as there isn't much value\r\nin scoping APIs on the browser side in my
opinion*\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<[email protected]>","sha":"3c8f0777f4a4563824d0fb1f545524bf4346e3a2"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/200568","number":200568,"mergeCommit":{"message":"[inference]
add pre-bound versions of `chatComplete` and `output` APIs
(#200568)\n\n## Summary\r\n\r\nFix
https://github.com/elastic/kibana/issues/199084\r\n\r\nIntroduce
pre-bound versions of the inference APIs.\r\n\r\nAccessing the bound
versions can be done using the same `getClient` API,\r\nvia an
additional `bindTo` parameter:\r\n\r\n**without
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({ request });\r\n\r\nconst chatResponse
= inferenceClient.chatComplete({\r\n connectorId: 'my-connector-id',\r\n
functionCalling: 'simulated',\r\n messages: [{ role: MessageRole.User,
content: 'Do something' }],\r\n});\r\n```\r\n\r\n**with
bindings**\r\n```ts\r\nconst inferenceClient =
myStartDeps.inference.getClient({\r\n request,\r\n bindTo: {\r\n
connectorId: 'my-connector-id',\r\n functionCalling: 'simulated',\r\n
}\r\n});\r\n\r\nconst chatResponse = inferenceClient.chatComplete({\r\n
messages: [{ role: MessageRole.User, content: 'Do something'
}],\r\n});\r\n```\r\n\r\n*Note: this is only done for the server-side,
as there isn't much value\r\nin scoping APIs on the browser side in my
opinion*\r\n\r\n---------\r\n\r\nCo-authored-by: Elastic Machine
<[email protected]>","sha":"3c8f0777f4a4563824d0fb1f545524bf4346e3a2"}},{"branch":"8.x","label":"v8.17.0","branchLabelMappingKey":"^v8.17.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Pierre Gayvallet <[email protected]>
  • Loading branch information
kibanamachine and pgayvallet authored Nov 20, 2024
1 parent 63934e8 commit 29209cb
Show file tree
Hide file tree
Showing 29 changed files with 811 additions and 61 deletions.
6 changes: 6 additions & 0 deletions x-pack/packages/ai-infra/inference-common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ export {
type ChatCompleteStreamResponse,
type ChatCompleteResponse,
type ChatCompletionTokenCount,
type BoundChatCompleteAPI,
type BoundChatCompleteOptions,
type UnboundChatCompleteOptions,
withoutTokenCountEvents,
withoutChunkEvents,
isChatCompletionMessageEvent,
Expand All @@ -59,6 +62,9 @@ export {
type OutputUpdateEvent,
type Output,
type OutputEvent,
type BoundOutputAPI,
type BoundOutputOptions,
type UnboundOutputOptions,
isOutputCompleteEvent,
isOutputUpdateEvent,
isOutputEvent,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { ChatCompleteOptions, ChatCompleteCompositeResponse } from './api';
import type { ToolOptions } from './tools';

/**
 * Static options that are fixed once, when a {@link BoundChatCompleteAPI} is created.
 *
 * This is the subset of {@link ChatCompleteOptions} restricted to the
 * `connectorId` and `functionCalling` properties; every other option is
 * supplied per call through {@link UnboundChatCompleteOptions}.
 */
export type BoundChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = Pick<ChatCompleteOptions<TToolOptions, TStream>, 'connectorId' | 'functionCalling'>;

/**
 * Per-call options accepted by a {@link BoundChatCompleteAPI}.
 *
 * This is {@link ChatCompleteOptions} with the `connectorId` and
 * `functionCalling` properties removed, i.e. the complement of
 * {@link BoundChatCompleteOptions}.
 */
export type UnboundChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = Omit<ChatCompleteOptions<TToolOptions, TStream>, 'connectorId' | 'functionCalling'>;

/**
 * Version of {@link ChatCompleteAPI} that has been pre-bound to a set of static parameters.
 *
 * Callers only provide the {@link UnboundChatCompleteOptions}; the return type
 * is the same {@link ChatCompleteCompositeResponse} as the unbound API, driven
 * by the `TToolOptions` and `TStream` type parameters.
 */
export type BoundChatCompleteAPI = <
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
>(
options: UnboundChatCompleteOptions<TToolOptions, TStream>
) => ChatCompleteCompositeResponse<TToolOptions, TStream>;
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ export type {
ChatCompleteStreamResponse,
ChatCompleteResponse,
} from './api';
export type {
BoundChatCompleteAPI,
BoundChatCompleteOptions,
UnboundChatCompleteOptions,
} from './bound_api';
export {
ChatCompletionEventType,
type ChatCompletionMessageEvent,
Expand Down
38 changes: 38 additions & 0 deletions x-pack/packages/ai-infra/inference-common/src/output/bound_api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { OutputOptions, OutputCompositeResponse } from './api';
import type { ToolSchema } from '../chat_complete/tool_schema';

/**
 * Static options that are fixed once, when a {@link BoundOutputAPI} is created.
 *
 * This is the subset of {@link OutputOptions} restricted to the
 * `connectorId` and `functionCalling` properties; every other option is
 * supplied per call through {@link UnboundOutputOptions}.
 */
export type BoundOutputOptions<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> = Pick<OutputOptions<TId, TOutputSchema, TStream>, 'connectorId' | 'functionCalling'>;

/**
 * Per-call options accepted by a {@link BoundOutputAPI}.
 *
 * This is {@link OutputOptions} with the `connectorId` and
 * `functionCalling` properties removed, i.e. the complement of
 * {@link BoundOutputOptions}.
 */
export type UnboundOutputOptions<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> = Omit<OutputOptions<TId, TOutputSchema, TStream>, 'connectorId' | 'functionCalling'>;

/**
 * Version of {@link OutputAPI} that has been pre-bound to a set of static parameters.
 *
 * Callers only provide the {@link UnboundOutputOptions}; the return type is
 * the same {@link OutputCompositeResponse} as the unbound API, driven by the
 * `TId`, `TOutputSchema` and `TStream` type parameters.
 */
export type BoundOutputAPI = <
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
>(
options: UnboundOutputOptions<TId, TOutputSchema, TStream>
) => OutputCompositeResponse<TId, TOutputSchema, TStream>;
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export type {
OutputResponse,
OutputStreamResponse,
} from './api';
export type { BoundOutputAPI, BoundOutputOptions, UnboundOutputOptions } from './bound_api';
export {
OutputEventType,
type OutputCompleteEvent,
Expand Down
19 changes: 19 additions & 0 deletions x-pack/plugins/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,25 @@ class MyPlugin {
}
```

### Binding common parameters

It is also possible to bind a client to a set of configuration parameters by using the `bindTo`
parameter when creating the client. This avoids, for example, having to pass `connectorId` to every call.

```ts
const inferenceClient = myStartDeps.inference.getClient({
request,
bindTo: {
connectorId: 'my-connector-id',
functionCalling: 'simulated',
}
});

const chatResponse = inferenceClient.chatComplete({
messages: [{ role: MessageRole.User, content: 'Do something' }],
});
```

## APIs

### `chatComplete` API:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import {
BoundChatCompleteOptions,
ChatCompleteAPI,
MessageRole,
UnboundChatCompleteOptions,
} from '@kbn/inference-common';
import { bindChatComplete } from './bind_chat_complete';

describe('bindChatComplete', () => {
  let chatComplete: ChatCompleteAPI & jest.MockedFn<ChatCompleteAPI>;

  // Fresh bound options for each test, so a mutation in one test cannot leak into another.
  const createBound = (): BoundChatCompleteOptions => ({
    connectorId: 'some-id',
    functionCalling: 'native',
  });

  // Fresh per-call options carrying a single user message.
  const createUnbound = (): UnboundChatCompleteOptions => ({
    messages: [{ role: MessageRole.User, content: 'hello there' }],
  });

  beforeEach(() => {
    chatComplete = jest.fn();
  });

  it('calls chatComplete with both bound and unbound params', async () => {
    const boundOptions = createBound();
    const callOptions = createUnbound();

    await bindChatComplete(chatComplete, boundOptions)(callOptions);

    expect(chatComplete).toHaveBeenCalledTimes(1);
    expect(chatComplete).toHaveBeenCalledWith({ ...boundOptions, ...callOptions });
  });

  it('forwards the response from chatComplete', async () => {
    const sentinel = Symbol('something');
    chatComplete.mockResolvedValue(sentinel as any);

    const response = await bindChatComplete(chatComplete, { connectorId: 'my-connector' })({
      messages: [{ role: MessageRole.User, content: 'hello there' }],
    });

    expect(response).toEqual(sentinel);
  });

  it('only passes the expected parameters from the bound param object', async () => {
    // The extra `foo` property must not reach the underlying API.
    const boundOptions = { ...createBound(), foo: 'bar' } as BoundChatCompleteOptions;
    const callOptions = createUnbound();

    await bindChatComplete(chatComplete, boundOptions)(callOptions);

    expect(chatComplete).toHaveBeenCalledTimes(1);
    expect(chatComplete).toHaveBeenCalledWith({
      connectorId: 'some-id',
      functionCalling: 'native',
      messages: callOptions.messages,
    });
  });

  it('ignores mutations of the bound parameters after binding', async () => {
    const boundOptions = createBound();
    const callOptions = createUnbound();

    const boundApi = bindChatComplete(chatComplete, boundOptions);

    // Mutating the source object after binding must have no effect.
    boundOptions.connectorId = 'some-other-id';

    await boundApi(callOptions);

    expect(chatComplete).toHaveBeenCalledTimes(1);
    expect(chatComplete).toHaveBeenCalledWith({
      connectorId: 'some-id',
      functionCalling: 'native',
      messages: callOptions.messages,
    });
  });

  it('does not allow overriding bound parameters with the unbound object', async () => {
    const boundOptions = createBound();
    // Smuggle a `connectorId` into the per-call options; the bound value must win.
    const callOptions = {
      ...createUnbound(),
      connectorId: 'overridden',
    } as UnboundChatCompleteOptions;

    await bindChatComplete(chatComplete, boundOptions)(callOptions);

    expect(chatComplete).toHaveBeenCalledTimes(1);
    expect(chatComplete).toHaveBeenCalledWith({
      connectorId: 'some-id',
      functionCalling: 'native',
      messages: callOptions.messages,
    });
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type {
ChatCompleteAPI,
ChatCompleteOptions,
BoundChatCompleteAPI,
BoundChatCompleteOptions,
UnboundChatCompleteOptions,
ToolOptions,
} from '@kbn/inference-common';

/**
 * Creates a pre-bound version of the `chatComplete` API.
 *
 * The `connectorId` and `functionCalling` values are read from `boundParams`
 * once, at bind time, so later mutations of that object are not observed.
 * On every call they are applied on top of the per-call options, which means
 * a caller cannot override them through the unbound options object.
 *
 * @param chatComplete the underlying API to delegate to
 * @param boundParams the static parameters to apply to every call
 * @returns a function accepting only the remaining (unbound) options
 */
export function bindChatComplete(
  chatComplete: ChatCompleteAPI,
  boundParams: BoundChatCompleteOptions
): BoundChatCompleteAPI;
export function bindChatComplete(
  chatComplete: ChatCompleteAPI,
  boundParams: BoundChatCompleteOptions
) {
  // Snapshot only the two supported properties; anything else on boundParams is dropped.
  const pinned = {
    connectorId: boundParams.connectorId,
    functionCalling: boundParams.functionCalling,
  };

  return (callOptions: UnboundChatCompleteOptions<ToolOptions, boolean>) => {
    // Pinned values are spread last so they always take precedence.
    const merged: ChatCompleteOptions<ToolOptions, boolean> = { ...callOptions, ...pinned };
    return chatComplete(merged);
  };
}
8 changes: 8 additions & 0 deletions x-pack/plugins/inference/common/chat_complete/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { bindChatComplete } from './bind_chat_complete';
2 changes: 1 addition & 1 deletion x-pack/plugins/inference/common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ export {

export { generateFakeToolCallId } from './utils/generate_fake_tool_call_id';

export { createOutputApi } from './create_output_api';
export { createOutputApi } from './output';

export type { ChatCompleteRequestBody, GetConnectorsResponseBody } from './http_apis';
Loading

0 comments on commit 29209cb

Please sign in to comment.