Skip to content

Commit

Permalink
[inference] add pre-bound versions of chatComplete and output APIs (
Browse files Browse the repository at this point in the history
#200568)

## Summary

Fix #199084

Introduce pre-bound versions of the inference APIs.

Accessing the bound versions can be done using the same `getClient` API,
via an additional `bindTo` parameter:

**without bindings**
```ts
const inferenceClient = myStartDeps.inference.getClient({ request });

const chatResponse = inferenceClient.chatComplete({
  connectorId: 'my-connector-id',
  functionCalling: 'simulated',
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});
```

**with bindings**
```ts
const inferenceClient = myStartDeps.inference.getClient({
  request,
  bindTo: {
   connectorId: 'my-connector-id',
   functionCalling: 'simulated',
  }
});

const chatResponse = inferenceClient.chatComplete({
  messages: [{ role: MessageRole.User, content: 'Do something' }],
});
```

*Note: this is only done for the server-side, as there isn't much value
in scoping APIs on the browser side in my opinion*

---------

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
pgayvallet and elasticmachine authored Nov 20, 2024
1 parent c04d80b commit 3c8f077
Show file tree
Hide file tree
Showing 29 changed files with 811 additions and 61 deletions.
6 changes: 6 additions & 0 deletions x-pack/packages/ai-infra/inference-common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ export {
type ChatCompleteStreamResponse,
type ChatCompleteResponse,
type ChatCompletionTokenCount,
type BoundChatCompleteAPI,
type BoundChatCompleteOptions,
type UnboundChatCompleteOptions,
withoutTokenCountEvents,
withoutChunkEvents,
isChatCompletionMessageEvent,
Expand All @@ -59,6 +62,9 @@ export {
type OutputUpdateEvent,
type Output,
type OutputEvent,
type BoundOutputAPI,
type BoundOutputOptions,
type UnboundOutputOptions,
isOutputCompleteEvent,
isOutputUpdateEvent,
isOutputEvent,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { ChatCompleteOptions, ChatCompleteCompositeResponse } from './api';
import type { ToolOptions } from './tools';

/**
* Static options used to call the {@link BoundChatCompleteAPI}
*/
export type BoundChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = Pick<ChatCompleteOptions<TToolOptions, TStream>, 'connectorId' | 'functionCalling'>;

/**
* Options used to call the {@link BoundChatCompleteAPI}
*/
export type UnboundChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = Omit<ChatCompleteOptions<TToolOptions, TStream>, 'connectorId' | 'functionCalling'>;

/**
* Version of {@link ChatCompleteAPI} that got pre-bound to a set of static parameters
*/
export type BoundChatCompleteAPI = <
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
>(
options: UnboundChatCompleteOptions<TToolOptions, TStream>
) => ChatCompleteCompositeResponse<TToolOptions, TStream>;
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ export type {
ChatCompleteStreamResponse,
ChatCompleteResponse,
} from './api';
export type {
BoundChatCompleteAPI,
BoundChatCompleteOptions,
UnboundChatCompleteOptions,
} from './bound_api';
export {
ChatCompletionEventType,
type ChatCompletionMessageEvent,
Expand Down
38 changes: 38 additions & 0 deletions x-pack/packages/ai-infra/inference-common/src/output/bound_api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { OutputOptions, OutputCompositeResponse } from './api';
import type { ToolSchema } from '../chat_complete/tool_schema';

/**
* Static options used to call the {@link BoundOutputAPI}
*/
export type BoundOutputOptions<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> = Pick<OutputOptions<TId, TOutputSchema, TStream>, 'connectorId' | 'functionCalling'>;

/**
* Options used to call the {@link BoundOutputAPI}
*/
export type UnboundOutputOptions<
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
> = Omit<OutputOptions<TId, TOutputSchema, TStream>, 'connectorId' | 'functionCalling'>;

/**
* Version of {@link OutputAPI} that got pre-bound to a set of static parameters
*/
export type BoundOutputAPI = <
TId extends string = string,
TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined,
TStream extends boolean = false
>(
options: UnboundOutputOptions<TId, TOutputSchema, TStream>
) => OutputCompositeResponse<TId, TOutputSchema, TStream>;
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export type {
OutputResponse,
OutputStreamResponse,
} from './api';
export type { BoundOutputAPI, BoundOutputOptions, UnboundOutputOptions } from './bound_api';
export {
OutputEventType,
type OutputCompleteEvent,
Expand Down
19 changes: 19 additions & 0 deletions x-pack/plugins/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,25 @@ class MyPlugin {
}
```

### Binding common parameters

It is also possible to bind a client to its configuration parameters, to avoid passing connectorId
to every call, for example, using the `bindTo` parameter when creating the client.

```ts
const inferenceClient = myStartDeps.inference.getClient({
request,
bindTo: {
connectorId: 'my-connector-id',
functionCalling: 'simulated',
}
});

const chatResponse = inferenceClient.chatComplete({
messages: [{ role: MessageRole.User, content: 'Do something' }],
});
```

## APIs

### `chatComplete` API:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import {
BoundChatCompleteOptions,
ChatCompleteAPI,
MessageRole,
UnboundChatCompleteOptions,
} from '@kbn/inference-common';
import { bindChatComplete } from './bind_chat_complete';

describe('bindChatComplete', () => {
let chatComplete: ChatCompleteAPI & jest.MockedFn<ChatCompleteAPI>;

beforeEach(() => {
chatComplete = jest.fn();
});

it('calls chatComplete with both bound and unbound params', async () => {
const bound: BoundChatCompleteOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};

const unbound: UnboundChatCompleteOptions = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
};

const boundApi = bindChatComplete(chatComplete, bound);

await boundApi({ ...unbound });

expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
...bound,
...unbound,
});
});

it('forwards the response from chatComplete', async () => {
const expectedReturnValue = Symbol('something');
chatComplete.mockResolvedValue(expectedReturnValue as any);

const boundApi = bindChatComplete(chatComplete, { connectorId: 'my-connector' });

const result = await boundApi({
messages: [{ role: MessageRole.User, content: 'hello there' }],
});

expect(result).toEqual(expectedReturnValue);
});

it('only passes the expected parameters from the bound param object', async () => {
const bound = {
connectorId: 'some-id',
functionCalling: 'native',
foo: 'bar',
} as BoundChatCompleteOptions;

const unbound: UnboundChatCompleteOptions = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
};

const boundApi = bindChatComplete(chatComplete, bound);

await boundApi({ ...unbound });

expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
messages: unbound.messages,
});
});

it('ignores mutations of the bound parameters after binding', async () => {
const bound: BoundChatCompleteOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};

const unbound: UnboundChatCompleteOptions = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
};

const boundApi = bindChatComplete(chatComplete, bound);

bound.connectorId = 'some-other-id';

await boundApi({ ...unbound });

expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
messages: unbound.messages,
});
});

it('does not allow overriding bound parameters with the unbound object', async () => {
const bound: BoundChatCompleteOptions = {
connectorId: 'some-id',
functionCalling: 'native',
};

const unbound = {
messages: [{ role: MessageRole.User, content: 'hello there' }],
connectorId: 'overridden',
} as UnboundChatCompleteOptions;

const boundApi = bindChatComplete(chatComplete, bound);

await boundApi({ ...unbound });

expect(chatComplete).toHaveBeenCalledTimes(1);
expect(chatComplete).toHaveBeenCalledWith({
connectorId: 'some-id',
functionCalling: 'native',
messages: unbound.messages,
});
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type {
ChatCompleteAPI,
ChatCompleteOptions,
BoundChatCompleteAPI,
BoundChatCompleteOptions,
UnboundChatCompleteOptions,
ToolOptions,
} from '@kbn/inference-common';

/**
* Bind chatComplete to the provided parameters,
* returning a bound version of the API.
*/
export function bindChatComplete(
chatComplete: ChatCompleteAPI,
boundParams: BoundChatCompleteOptions
): BoundChatCompleteAPI;
export function bindChatComplete(
chatComplete: ChatCompleteAPI,
boundParams: BoundChatCompleteOptions
) {
const { connectorId, functionCalling } = boundParams;
return (unboundParams: UnboundChatCompleteOptions<ToolOptions, boolean>) => {
const params: ChatCompleteOptions<ToolOptions, boolean> = {
...unboundParams,
connectorId,
functionCalling,
};
return chatComplete(params);
};
}
8 changes: 8 additions & 0 deletions x-pack/plugins/inference/common/chat_complete/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { bindChatComplete } from './bind_chat_complete';
2 changes: 1 addition & 1 deletion x-pack/plugins/inference/common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ export {

export { generateFakeToolCallId } from './utils/generate_fake_tool_call_id';

export { createOutputApi } from './create_output_api';
export { createOutputApi } from './output';

export type { ChatCompleteRequestBody, GetConnectorsResponseBody } from './http_apis';
Loading

0 comments on commit 3c8f077

Please sign in to comment.