diff --git a/x-pack/packages/ai-infra/inference-common/index.ts b/x-pack/packages/ai-infra/inference-common/index.ts index 2791896c801ef..4b5ef3a5cfda1 100644 --- a/x-pack/packages/ai-infra/inference-common/index.ts +++ b/x-pack/packages/ai-infra/inference-common/index.ts @@ -34,6 +34,9 @@ export { type ChatCompleteStreamResponse, type ChatCompleteResponse, type ChatCompletionTokenCount, + type BoundChatCompleteAPI, + type BoundChatCompleteOptions, + type UnboundChatCompleteOptions, withoutTokenCountEvents, withoutChunkEvents, isChatCompletionMessageEvent, @@ -59,6 +62,9 @@ export { type OutputUpdateEvent, type Output, type OutputEvent, + type BoundOutputAPI, + type BoundOutputOptions, + type UnboundOutputOptions, isOutputCompleteEvent, isOutputUpdateEvent, isOutputEvent, diff --git a/x-pack/packages/ai-infra/inference-common/src/chat_complete/bound_api.ts b/x-pack/packages/ai-infra/inference-common/src/chat_complete/bound_api.ts new file mode 100644 index 0000000000000..083620ed99a93 --- /dev/null +++ b/x-pack/packages/ai-infra/inference-common/src/chat_complete/bound_api.ts @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { ChatCompleteOptions, ChatCompleteCompositeResponse } from './api'; +import type { ToolOptions } from './tools'; + +/** + * Static options used to call the {@link BoundChatCompleteAPI} + */ +export type BoundChatCompleteOptions< + TToolOptions extends ToolOptions = ToolOptions, + TStream extends boolean = false +> = Pick, 'connectorId' | 'functionCalling'>; + +/** + * Options used to call the {@link BoundChatCompleteAPI} + */ +export type UnboundChatCompleteOptions< + TToolOptions extends ToolOptions = ToolOptions, + TStream extends boolean = false +> = Omit, 'connectorId' | 'functionCalling'>; + +/** + * Version of {@link ChatCompleteAPI} that got pre-bound to a set of static parameters + */ +export type BoundChatCompleteAPI = < + TToolOptions extends ToolOptions = ToolOptions, + TStream extends boolean = false +>( + options: UnboundChatCompleteOptions +) => ChatCompleteCompositeResponse; diff --git a/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts b/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts index ca69f39b273e5..3daa898ab2e1a 100644 --- a/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts +++ b/x-pack/packages/ai-infra/inference-common/src/chat_complete/index.ts @@ -13,6 +13,11 @@ export type { ChatCompleteStreamResponse, ChatCompleteResponse, } from './api'; +export type { + BoundChatCompleteAPI, + BoundChatCompleteOptions, + UnboundChatCompleteOptions, +} from './bound_api'; export { ChatCompletionEventType, type ChatCompletionMessageEvent, diff --git a/x-pack/packages/ai-infra/inference-common/src/output/bound_api.ts b/x-pack/packages/ai-infra/inference-common/src/output/bound_api.ts new file mode 100644 index 0000000000000..967dac20c0568 --- /dev/null +++ b/x-pack/packages/ai-infra/inference-common/src/output/bound_api.ts @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { OutputOptions, OutputCompositeResponse } from './api'; +import type { ToolSchema } from '../chat_complete/tool_schema'; + +/** + * Static options used to call the {@link BoundOutputAPI} + */ +export type BoundOutputOptions< + TId extends string = string, + TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined, + TStream extends boolean = false +> = Pick, 'connectorId' | 'functionCalling'>; + +/** + * Options used to call the {@link BoundOutputAPI} + */ +export type UnboundOutputOptions< + TId extends string = string, + TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined, + TStream extends boolean = false +> = Omit, 'connectorId' | 'functionCalling'>; + +/** + * Version of {@link OutputAPI} that got pre-bound to a set of static parameters + */ +export type BoundOutputAPI = < + TId extends string = string, + TOutputSchema extends ToolSchema | undefined = ToolSchema | undefined, + TStream extends boolean = false +>( + options: UnboundOutputOptions +) => OutputCompositeResponse; diff --git a/x-pack/packages/ai-infra/inference-common/src/output/index.ts b/x-pack/packages/ai-infra/inference-common/src/output/index.ts index a3039005b2f7c..d4e17967b50f5 100644 --- a/x-pack/packages/ai-infra/inference-common/src/output/index.ts +++ b/x-pack/packages/ai-infra/inference-common/src/output/index.ts @@ -12,6 +12,7 @@ export type { OutputResponse, OutputStreamResponse, } from './api'; +export type { BoundOutputAPI, BoundOutputOptions, UnboundOutputOptions } from './bound_api'; export { OutputEventType, type OutputCompleteEvent, diff --git a/x-pack/plugins/inference/README.md b/x-pack/plugins/inference/README.md index 935ae31bd6bc6..bba5b4cdcfc27 100644 --- a/x-pack/plugins/inference/README.md +++ b/x-pack/plugins/inference/README.md @@ -77,6 +77,25 @@ class MyPlugin { } ``` +### Binding common parameters + +It is also possible to bind a client to its configuration parameters, to avoid passing connectorId +to every call, for example, using the `bindTo` parameter when creating the client. + +```ts +const inferenceClient = myStartDeps.inference.getClient({ + request, + bindTo: { + connectorId: 'my-connector-id', + functionCalling: 'simulated', + } +}); + +const chatResponse = inferenceClient.chatComplete({ + messages: [{ role: MessageRole.User, content: 'Do something' }], +}); +``` + ## APIs ### `chatComplete` API: diff --git a/x-pack/plugins/inference/common/chat_complete/bind_chat_complete.test.ts b/x-pack/plugins/inference/common/chat_complete/bind_chat_complete.test.ts new file mode 100644 index 0000000000000..039fd0410d254 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/bind_chat_complete.test.ts @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { + BoundChatCompleteOptions, + ChatCompleteAPI, + MessageRole, + UnboundChatCompleteOptions, +} from '@kbn/inference-common'; +import { bindChatComplete } from './bind_chat_complete'; + +describe('bindChatComplete', () => { + let chatComplete: ChatCompleteAPI & jest.MockedFn; + + beforeEach(() => { + chatComplete = jest.fn(); + }); + + it('calls chatComplete with both bound and unbound params', async () => { + const bound: BoundChatCompleteOptions = { + connectorId: 'some-id', + functionCalling: 'native', + }; + + const unbound: UnboundChatCompleteOptions = { + messages: [{ role: MessageRole.User, content: 'hello there' }], + }; + + const boundApi = bindChatComplete(chatComplete, bound); + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + ...bound, + ...unbound, + }); + }); + + it('forwards the response from chatComplete', async () => { + const expectedReturnValue = Symbol('something'); + chatComplete.mockResolvedValue(expectedReturnValue as any); + + const boundApi = bindChatComplete(chatComplete, { connectorId: 'my-connector' }); + + const result = await boundApi({ + messages: [{ role: MessageRole.User, content: 'hello there' }], + }); + + expect(result).toEqual(expectedReturnValue); + }); + + it('only passes the expected parameters from the bound param object', async () => { + const bound = { + connectorId: 'some-id', + functionCalling: 'native', + foo: 'bar', + } as BoundChatCompleteOptions; + + const unbound: UnboundChatCompleteOptions = { + messages: [{ role: MessageRole.User, content: 'hello there' }], + }; + + const boundApi = bindChatComplete(chatComplete, bound); + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + connectorId: 'some-id', + functionCalling: 'native', + messages: unbound.messages, + }); + }); + + it('ignores mutations of the bound parameters after binding', async () => { + const bound: BoundChatCompleteOptions = { + connectorId: 'some-id', + functionCalling: 'native', + }; + + const unbound: UnboundChatCompleteOptions = { + messages: [{ role: MessageRole.User, content: 'hello there' }], + }; + + const boundApi = bindChatComplete(chatComplete, bound); + + bound.connectorId = 'some-other-id'; + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + connectorId: 'some-id', + functionCalling: 'native', + messages: unbound.messages, + }); + }); + + it('does not allow overriding bound parameters with the unbound object', async () => { + const bound: BoundChatCompleteOptions = { + connectorId: 'some-id', + functionCalling: 'native', + }; + + const unbound = { + messages: [{ role: MessageRole.User, content: 'hello there' }], + connectorId: 'overridden', + } as UnboundChatCompleteOptions; + + const boundApi = bindChatComplete(chatComplete, bound); + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + connectorId: 'some-id', + functionCalling: 'native', + messages: unbound.messages, + }); + }); +}); diff --git a/x-pack/plugins/inference/common/chat_complete/bind_chat_complete.ts b/x-pack/plugins/inference/common/chat_complete/bind_chat_complete.ts new file mode 100644 index 0000000000000..3030dee641223 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/bind_chat_complete.ts @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { + ChatCompleteAPI, + ChatCompleteOptions, + BoundChatCompleteAPI, + BoundChatCompleteOptions, + UnboundChatCompleteOptions, + ToolOptions, +} from '@kbn/inference-common'; + +/** + * Bind chatComplete to the provided parameters, + * returning a bound version of the API. + */ +export function bindChatComplete( + chatComplete: ChatCompleteAPI, + boundParams: BoundChatCompleteOptions +): BoundChatCompleteAPI; +export function bindChatComplete( + chatComplete: ChatCompleteAPI, + boundParams: BoundChatCompleteOptions +) { + const { connectorId, functionCalling } = boundParams; + return (unboundParams: UnboundChatCompleteOptions) => { + const params: ChatCompleteOptions = { + ...unboundParams, + connectorId, + functionCalling, + }; + return chatComplete(params); + }; +} diff --git a/x-pack/plugins/inference/common/chat_complete/index.ts b/x-pack/plugins/inference/common/chat_complete/index.ts new file mode 100644 index 0000000000000..9eaa850fc8195 --- /dev/null +++ b/x-pack/plugins/inference/common/chat_complete/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { bindChatComplete } from './bind_chat_complete'; diff --git a/x-pack/plugins/inference/common/index.ts b/x-pack/plugins/inference/common/index.ts index 19b24d53a389a..79433cbc71a68 100644 --- a/x-pack/plugins/inference/common/index.ts +++ b/x-pack/plugins/inference/common/index.ts @@ -12,6 +12,6 @@ export { export { generateFakeToolCallId } from './utils/generate_fake_tool_call_id'; -export { createOutputApi } from './create_output_api'; +export { createOutputApi } from './output'; export type { ChatCompleteRequestBody, GetConnectorsResponseBody } from './http_apis'; diff --git a/x-pack/plugins/inference/common/output/bind_output.test.ts b/x-pack/plugins/inference/common/output/bind_output.test.ts new file mode 100644 index 0000000000000..65741acbd8a3e --- /dev/null +++ b/x-pack/plugins/inference/common/output/bind_output.test.ts @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { BoundOutputOptions, OutputAPI, UnboundOutputOptions } from '@kbn/inference-common'; +import { bindOutput } from './bind_output'; + +describe('createScopedOutputAPI', () => { + let chatComplete: OutputAPI & jest.MockedFn; + + beforeEach(() => { + chatComplete = jest.fn(); + }); + + it('calls chatComplete with both bound and unbound params', async () => { + const bound: BoundOutputOptions = { + connectorId: 'some-id', + functionCalling: 'native', + }; + + const unbound: UnboundOutputOptions = { + id: 'foo', + input: 'hello there', + }; + + const boundApi = bindOutput(chatComplete, bound); + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + ...bound, + ...unbound, + }); + }); + + it('forwards the response from chatComplete', async () => { + const expectedReturnValue = Symbol('something'); + chatComplete.mockResolvedValue(expectedReturnValue as any); + + const boundApi = bindOutput(chatComplete, { connectorId: 'my-connector' }); + + const result = await boundApi({ + id: 'foo', + input: 'hello there', + }); + + expect(result).toEqual(expectedReturnValue); + }); + + it('only passes the expected parameters from the bound param object', async () => { + const bound = { + connectorId: 'some-id', + functionCalling: 'native', + foo: 'bar', + } as BoundOutputOptions; + + const unbound: UnboundOutputOptions = { + id: 'foo', + input: 'hello there', + }; + + const boundApi = bindOutput(chatComplete, bound); + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + connectorId: 'some-id', + functionCalling: 'native', + id: 'foo', + input: 'hello there', + }); + }); + + it('ignores mutations of the bound parameters after binding', async () => { + const bound: BoundOutputOptions = { + connectorId: 'some-id', + functionCalling: 'native', + }; + + const unbound: UnboundOutputOptions = { + id: 'foo', + input: 'hello there', + }; + + const boundApi = bindOutput(chatComplete, bound); + + bound.connectorId = 'some-other-id'; + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + connectorId: 'some-id', + functionCalling: 'native', + id: 'foo', + input: 'hello there', + }); + }); + + it('does not allow overriding bound parameters with the unbound object', async () => { + const bound: BoundOutputOptions = { + connectorId: 'some-id', + functionCalling: 'native', + }; + + const unbound = { + id: 'foo', + input: 'hello there', + connectorId: 'overridden', + } as UnboundOutputOptions; + + const boundApi = bindOutput(chatComplete, bound); + + await boundApi({ ...unbound }); + + expect(chatComplete).toHaveBeenCalledTimes(1); + expect(chatComplete).toHaveBeenCalledWith({ + connectorId: 'some-id', + functionCalling: 'native', + id: 'foo', + input: 'hello there', + }); + }); +}); diff --git a/x-pack/plugins/inference/common/output/bind_output.ts b/x-pack/plugins/inference/common/output/bind_output.ts new file mode 100644 index 0000000000000..45ac434d5ffd6 --- /dev/null +++ b/x-pack/plugins/inference/common/output/bind_output.ts @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { + OutputAPI, + OutputOptions, + BoundOutputAPI, + BoundOutputOptions, + UnboundOutputOptions, + ToolSchema, +} from '@kbn/inference-common'; + +/** + * Bind output to the provided parameters, + * returning a bound version of the API. + */ +export function bindOutput( + chatComplete: OutputAPI, + boundParams: BoundOutputOptions +): BoundOutputAPI; +export function bindOutput(chatComplete: OutputAPI, boundParams: BoundOutputOptions) { + const { connectorId, functionCalling } = boundParams; + return (unboundParams: UnboundOutputOptions) => { + const params: OutputOptions = { + ...unboundParams, + connectorId, + functionCalling, + }; + return chatComplete(params); + }; +} diff --git a/x-pack/plugins/inference/common/create_output_api.test.ts b/x-pack/plugins/inference/common/output/create_output_api.test.ts similarity index 100% rename from x-pack/plugins/inference/common/create_output_api.test.ts rename to x-pack/plugins/inference/common/output/create_output_api.test.ts diff --git a/x-pack/plugins/inference/common/create_output_api.ts b/x-pack/plugins/inference/common/output/create_output_api.ts similarity index 97% rename from x-pack/plugins/inference/common/create_output_api.ts rename to x-pack/plugins/inference/common/output/create_output_api.ts index e5dd2eeda2cbd..d263f733bf4ee 100644 --- a/x-pack/plugins/inference/common/create_output_api.ts +++ b/x-pack/plugins/inference/common/output/create_output_api.ts @@ -16,7 +16,7 @@ import { withoutTokenCountEvents, } from '@kbn/inference-common'; import { isObservable, map } from 'rxjs'; -import { ensureMultiTurn } from './utils/ensure_multi_turn'; +import { ensureMultiTurn } from '../utils/ensure_multi_turn'; export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI; export function createOutputApi(chatCompleteApi: ChatCompleteAPI) { diff --git a/x-pack/plugins/inference/common/output/index.ts b/x-pack/plugins/inference/common/output/index.ts new file mode 100644 index 0000000000000..4c6f053d6ed85 --- /dev/null +++ b/x-pack/plugins/inference/common/output/index.ts @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { createOutputApi } from './create_output_api'; +export { bindOutput } from './bind_output'; diff --git a/x-pack/plugins/inference/public/plugin.tsx b/x-pack/plugins/inference/public/plugin.tsx index f1023bc9c2546..614c2107c0a06 100644 --- a/x-pack/plugins/inference/public/plugin.tsx +++ b/x-pack/plugins/inference/public/plugin.tsx @@ -7,7 +7,7 @@ import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/public'; import type { Logger } from '@kbn/logging'; -import { createOutputApi } from '../common/create_output_api'; +import { createOutputApi } from '../common/output'; import type { GetConnectorsResponseBody } from '../common/http_apis'; import { createChatCompleteApi } from './chat_complete'; import type { diff --git a/x-pack/plugins/inference/scripts/util/kibana_client.ts b/x-pack/plugins/inference/scripts/util/kibana_client.ts index ad6c21cf4b248..ef6f1c4fdcdce 100644 --- a/x-pack/plugins/inference/scripts/util/kibana_client.ts +++ b/x-pack/plugins/inference/scripts/util/kibana_client.ts @@ -28,7 +28,7 @@ import { } from '@kbn/inference-common'; import type { ChatCompleteRequestBody } from '../../common/http_apis'; import type { InferenceConnector } from '../../common/connectors'; -import { createOutputApi } from '../../common/create_output_api'; +import { createOutputApi } from '../../common/output/create_output_api'; import { eventSourceStreamIntoObservable } from '../../server/util/event_source_stream_into_observable'; // eslint-disable-next-line spaced-comment diff --git a/x-pack/plugins/inference/server/chat_complete/api.ts b/x-pack/plugins/inference/server/chat_complete/api.ts index cf325e72ddf3a..13b1c8d87270c 100644 --- a/x-pack/plugins/inference/server/chat_complete/api.ts +++ b/x-pack/plugins/inference/server/chat_complete/api.ts @@ -16,14 +16,14 @@ import { type ToolOptions, ChatCompleteOptions, } from '@kbn/inference-common'; -import type { InferenceStartDependencies } from '../types'; +import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { getConnectorById } from '../util/get_connector_by_id'; import { getInferenceAdapter } from './adapters'; import { createInferenceExecutor, chunksIntoMessage, streamToResponse } from './utils'; interface CreateChatCompleteApiOptions { request: KibanaRequest; - actions: InferenceStartDependencies['actions']; + actions: ActionsPluginStart; logger: Logger; } diff --git a/x-pack/plugins/inference/server/index.ts b/x-pack/plugins/inference/server/index.ts index 60ce870020feb..128e90a58308d 100644 --- a/x-pack/plugins/inference/server/index.ts +++ b/x-pack/plugins/inference/server/index.ts @@ -15,7 +15,7 @@ import type { } from './types'; import { InferencePlugin } from './plugin'; -export type { InferenceClient } from './types'; +export type { InferenceClient, BoundInferenceClient } from './inference_client'; export type { InferenceServerSetup, InferenceServerStart }; export { naturalLanguageToEsql } from './tasks/nl_to_esql'; diff --git a/x-pack/plugins/inference/server/inference_client/bind_client.ts b/x-pack/plugins/inference/server/inference_client/bind_client.ts new file mode 100644 index 0000000000000..4600ed1364ed3 --- /dev/null +++ b/x-pack/plugins/inference/server/inference_client/bind_client.ts @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { BoundChatCompleteOptions } from '@kbn/inference-common'; +import { bindChatComplete } from '../../common/chat_complete'; +import { bindOutput } from '../../common/output'; +import type { InferenceClient, BoundInferenceClient } from './types'; + +export const bindClient = ( + unboundClient: InferenceClient, + boundParams: BoundChatCompleteOptions +): BoundInferenceClient => { + return { + ...unboundClient, + chatComplete: bindChatComplete(unboundClient.chatComplete, boundParams), + output: bindOutput(unboundClient.output, boundParams), + }; +}; diff --git a/x-pack/plugins/inference/server/inference_client/create_client.test.ts b/x-pack/plugins/inference/server/inference_client/create_client.test.ts new file mode 100644 index 0000000000000..98f5502cdfa55 --- /dev/null +++ b/x-pack/plugins/inference/server/inference_client/create_client.test.ts @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { createClient } from './create_client'; +import { loggerMock, type MockedLogger } from '@kbn/logging-mocks'; +import { httpServerMock } from '@kbn/core/server/mocks'; +import { actionsMock } from '@kbn/actions-plugin/server/mocks'; + +jest.mock('./inference_client'); +jest.mock('./bind_client'); +import { createInferenceClient } from './inference_client'; +import { bindClient } from './bind_client'; + +const bindClientMock = bindClient as jest.MockedFn; +const createInferenceClientMock = createInferenceClient as jest.MockedFn< + typeof createInferenceClient +>; + +describe('createClient', () => { + let logger: MockedLogger; + let actions: ReturnType; + let request: ReturnType; + + beforeEach(() => { + logger = loggerMock.create(); + actions = actionsMock.createStart(); + request = httpServerMock.createKibanaRequest(); + }); + + afterEach(() => { + bindClientMock.mockReset(); + createInferenceClientMock.mockReset(); + }); + + describe('when `bindTo` is not specified', () => { + it('calls createInferenceClient and return the client', () => { + const expectedResult = Symbol('expected') as any; + createInferenceClientMock.mockReturnValue(expectedResult); + + const result = createClient({ + request, + actions, + logger, + }); + + expect(createInferenceClientMock).toHaveBeenCalledTimes(1); + expect(createInferenceClientMock).toHaveBeenCalledWith({ request, actions, logger }); + + expect(bindClientMock).not.toHaveBeenCalled(); + + expect(result).toBe(expectedResult); + }); + + it('return a client with the expected type', async () => { + createInferenceClientMock.mockReturnValue({ + chatComplete: jest.fn(), + } as any); + + const client = createClient({ + request, + actions, + logger, + }); + + // type check on client.chatComplete + await client.chatComplete({ + messages: [], + connectorId: '.foo', + }); + }); + }); + + describe('when `bindTo` is specified', () => { + it('calls createInferenceClient and bindClient and forward the expected value', () => { + const createInferenceResult = Symbol('createInferenceResult') as any; + createInferenceClientMock.mockReturnValue(createInferenceResult); + + const bindClientResult = Symbol('bindClientResult') as any; + bindClientMock.mockReturnValue(bindClientResult); + + const result = createClient({ + request, + actions, + logger, + bindTo: { + connectorId: '.my-connector', + }, + }); + + expect(createInferenceClientMock).toHaveBeenCalledTimes(1); + expect(createInferenceClientMock).toHaveBeenCalledWith({ + request, + actions, + logger, + }); + + expect(bindClientMock).toHaveBeenCalledTimes(1); + expect(bindClientMock).toHaveBeenCalledWith(createInferenceResult, { + connectorId: '.my-connector', + }); + + expect(result).toBe(bindClientResult); + }); + + it('return a client with the expected type', async () => { + bindClientMock.mockReturnValue({ + chatComplete: jest.fn(), + } as any); + + const client = createClient({ + request, + actions, + logger, + bindTo: { + connectorId: '.foo', + }, + }); + + // type check on client.chatComplete + await client.chatComplete({ + messages: [], + }); + }); + }); +}); diff --git a/x-pack/plugins/inference/server/inference_client/create_client.ts b/x-pack/plugins/inference/server/inference_client/create_client.ts new file mode 100644 index 0000000000000..3507dd7fef8a8 --- /dev/null +++ b/x-pack/plugins/inference/server/inference_client/create_client.ts @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { Logger } from '@kbn/logging'; +import type { KibanaRequest } from '@kbn/core-http-server'; +import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import type { BoundChatCompleteOptions } from '@kbn/inference-common'; +import type { BoundInferenceClient, InferenceClient } from './types'; +import { createInferenceClient } from './inference_client'; +import { bindClient } from './bind_client'; + +interface UnboundOptions { + request: KibanaRequest; + actions: ActionsPluginStart; + logger: Logger; +} + +interface BoundOptions extends UnboundOptions { + bindTo: BoundChatCompleteOptions; +} + +export function createClient(options: UnboundOptions): InferenceClient; +export function createClient(options: BoundOptions): BoundInferenceClient; +export function createClient( + options: UnboundOptions | BoundOptions +): BoundInferenceClient | InferenceClient { + const { actions, request, logger } = options; + const client = createInferenceClient({ request, actions, logger }); + if ('bindTo' in options) { + return bindClient(client, options.bindTo); + } else { + return client; + } +} diff --git a/x-pack/plugins/inference/server/inference_client/index.ts b/x-pack/plugins/inference/server/inference_client/index.ts index 03da0e3da200f..9d56ebe7ff61a 100644 --- a/x-pack/plugins/inference/server/inference_client/index.ts +++ b/x-pack/plugins/inference/server/inference_client/index.ts @@ -5,28 +5,5 @@ * 2.0. */ -import type { Logger } from '@kbn/logging'; -import type { KibanaRequest } from '@kbn/core-http-server'; -import type { InferenceClient, InferenceStartDependencies } from '../types'; -import { createChatCompleteApi } from '../chat_complete'; -import { createOutputApi } from '../../common/create_output_api'; -import { getConnectorById } from '../util/get_connector_by_id'; - -export function createInferenceClient({ - request, - actions, - logger, -}: { request: KibanaRequest; logger: Logger } & Pick< - InferenceStartDependencies, - 'actions' ->): InferenceClient { - const chatComplete = createChatCompleteApi({ request, actions, logger }); - return { - chatComplete, - output: createOutputApi(chatComplete), - getConnectorById: async (connectorId: string) => { - const actionsClient = await actions.getActionsClientWithRequest(request); - return await getConnectorById({ connectorId, actionsClient }); - }, - }; -} +export { createClient } from './create_client'; +export type { InferenceClient, BoundInferenceClient } from './types'; diff --git a/x-pack/plugins/inference/server/inference_client/inference_client.ts b/x-pack/plugins/inference/server/inference_client/inference_client.ts new file mode 100644 index 0000000000000..f4c64ebdcce54 --- /dev/null +++ b/x-pack/plugins/inference/server/inference_client/inference_client.ts @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { Logger } from '@kbn/logging'; +import type { KibanaRequest } from '@kbn/core-http-server'; +import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import type { InferenceClient } from './types'; +import { createChatCompleteApi } from '../chat_complete'; +import { createOutputApi } from '../../common/output/create_output_api'; +import { getConnectorById } from '../util/get_connector_by_id'; + +export function createInferenceClient({ + request, + actions, + logger, +}: { + request: KibanaRequest; + logger: Logger; + actions: ActionsPluginStart; +}): InferenceClient { + const chatComplete = createChatCompleteApi({ request, actions, logger }); + return { + chatComplete, + output: createOutputApi(chatComplete), + getConnectorById: async (connectorId: string) => { + const actionsClient = await actions.getActionsClientWithRequest(request); + return await getConnectorById({ connectorId, actionsClient }); + }, + }; +} diff --git a/x-pack/plugins/inference/server/inference_client/types.ts b/x-pack/plugins/inference/server/inference_client/types.ts new file mode 100644 index 0000000000000..193ce83f6d7b6 --- /dev/null +++ b/x-pack/plugins/inference/server/inference_client/types.ts @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { + BoundChatCompleteAPI, + ChatCompleteAPI, + BoundOutputAPI, + OutputAPI, +} from '@kbn/inference-common'; +import type { InferenceConnector } from '../../common/connectors'; + +/** + * An inference client, scoped to a request, that can be used to interact with LLMs. + */ +export interface InferenceClient { + /** + * `chatComplete` requests the LLM to generate a response to + * a prompt or conversation, which might be plain text + * or a tool call, or a combination of both. + */ + chatComplete: ChatCompleteAPI; + /** + * `output` asks the LLM to generate a structured (JSON) + * response based on a schema and a prompt or conversation. + */ + output: OutputAPI; + /** + * `getConnectorById` returns an inference connector by id. + * Non-inference connectors will throw an error. + */ + getConnectorById: (id: string) => Promise; +} + +/** + * A version of the {@link InferenceClient} that is pre-bound to a set of parameters. + */ +export interface BoundInferenceClient { + /** + * `chatComplete` requests the LLM to generate a response to + * a prompt or conversation, which might be plain text + * or a tool call, or a combination of both. + */ + chatComplete: BoundChatCompleteAPI; + /** + * `output` asks the LLM to generate a structured (JSON) + * response based on a schema and a prompt or conversation. + */ + output: BoundOutputAPI; + /** + * `getConnectorById` returns an inference connector by id. + * Non-inference connectors will throw an error. + */ + getConnectorById: (id: string) => Promise; +} diff --git a/x-pack/plugins/inference/server/plugin.ts b/x-pack/plugins/inference/server/plugin.ts index 2b1a7be0a165c..0f7090f483339 100644 --- a/x-pack/plugins/inference/server/plugin.ts +++ b/x-pack/plugins/inference/server/plugin.ts @@ -7,10 +7,16 @@ import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server'; import type { Logger } from '@kbn/logging'; -import { createInferenceClient } from './inference_client'; +import { + type BoundInferenceClient, + createClient as createInferenceClient, + type InferenceClient, +} from './inference_client'; import { registerRoutes } from './routes'; import type { InferenceConfig } from './config'; -import type { +import { + InferenceBoundClientCreateOptions, + InferenceClientCreateOptions, InferenceServerSetup, InferenceServerStart, InferenceSetupDependencies, @@ -48,12 +54,12 @@ export class InferencePlugin start(core: CoreStart, pluginsStart: InferenceStartDependencies): InferenceServerStart { return { - getClient: ({ request }) => { + getClient: (options: T) => { return createInferenceClient({ - request, + ...options, actions: pluginsStart.actions, logger: this.logger.get('client'), - }); + }) as T extends InferenceBoundClientCreateOptions ? BoundInferenceClient : InferenceClient; }, }; } diff --git a/x-pack/plugins/inference/server/routes/chat_complete.ts b/x-pack/plugins/inference/server/routes/chat_complete.ts index e4e078e58c15a..b363c88352994 100644 --- a/x-pack/plugins/inference/server/routes/chat_complete.ts +++ b/x-pack/plugins/inference/server/routes/chat_complete.ts @@ -15,7 +15,7 @@ import type { } from '@kbn/core/server'; import { MessageRole, ToolCall, ToolChoiceType } from '@kbn/inference-common'; import type { ChatCompleteRequestBody } from '../../common/http_apis'; -import { createInferenceClient } from '../inference_client'; +import { createClient as createInferenceClient } from '../inference_client'; import { InferenceServerStart, InferenceStartDependencies } from '../types'; import { observableIntoEventSourceStream } from '../util/observable_into_event_source_stream'; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts index ce45d9a15e4b3..db3ac3b493481 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts @@ -14,7 +14,7 @@ import type { ToolOptions, OutputCompleteEvent, } from '@kbn/inference-common'; -import type { InferenceClient } from '../../types'; +import type { InferenceClient } from '../../inference_client'; export type NlToEsqlTaskEvent = | OutputCompleteEvent< diff --git a/x-pack/plugins/inference/server/types.ts b/x-pack/plugins/inference/server/types.ts index f538448372e36..8d6d1413f306a 100644 --- a/x-pack/plugins/inference/server/types.ts +++ b/x-pack/plugins/inference/server/types.ts @@ -10,8 +10,8 @@ import type { PluginSetupContract as ActionsPluginSetup, } from '@kbn/actions-plugin/server'; import type { KibanaRequest } from '@kbn/core-http-server'; -import { ChatCompleteAPI, OutputAPI } from '@kbn/inference-common'; -import { InferenceConnector } from '../common/connectors'; +import type { BoundChatCompleteOptions } from '@kbn/inference-common'; +import type { InferenceClient, BoundInferenceClient } from './inference_client'; /* eslint-disable @typescript-eslint/no-empty-interface*/ @@ -23,37 +23,74 @@ export interface InferenceStartDependencies { actions: ActionsPluginStart; } +/** + * Setup contract of the inference plugin. + */ export interface InferenceServerSetup {} -export interface InferenceClient { - /** - * `chatComplete` requests the LLM to generate a response to - * a prompt or conversation, which might be plain text - * or a tool call, or a combination of both. - */ - chatComplete: ChatCompleteAPI; +/** + * Options to create an inference client using the {@link InferenceServerStart.getClient} API. + */ +export interface InferenceUnboundClientCreateOptions { /** - * `output` asks the LLM to generate a structured (JSON) - * response based on a schema and a prompt or conversation. + * The request to scope the client to. */ - output: OutputAPI; + request: KibanaRequest; +} + +/** + * Options to create a bound inference client using the {@link InferenceServerStart.getClient} API. + */ +export interface InferenceBoundClientCreateOptions extends InferenceUnboundClientCreateOptions { /** - * `getConnectorById` returns an inference connector by id. - * Non-inference connectors will throw an error. + * The parameters to bind the client to. */ - getConnectorById: (id: string) => Promise; + bindTo: BoundChatCompleteOptions; } -interface InferenceClientCreateOptions { - request: KibanaRequest; -} +/** + * Options to create an inference client using the {@link InferenceServerStart.getClient} API. + */ +export type InferenceClientCreateOptions = + | InferenceUnboundClientCreateOptions + | InferenceBoundClientCreateOptions; +/** + * Start contract of the inference plugin, exposing APIs to interact with LLMs. + */ export interface InferenceServerStart { /** - * Creates an inference client, scoped to a request. + * Creates an {@link InferenceClient}, scoped to a request. + * + * @example + * ```ts + * const inferenceClient = myStartDeps.inference.getClient({ request }); + * + * const chatResponse = inferenceClient.chatComplete({ + * connectorId: 'my-connector-id', + * messages: [{ role: MessageRole.User, content: 'Do something' }], + * }); + * ``` + * + * It is also possible to bind a client to its configuration parameters, to avoid passing connectorId + * to every call, for example. Defining the `bindTo` parameter will return a {@link BoundInferenceClient} + * + * @example + * ```ts + * const inferenceClient = myStartDeps.inference.getClient({ + * request, + * bindTo: { + * connectorId: 'my-connector-id', + * functionCalling: 'simulated', + * } + * }); * - * @param options {@link InferenceClientCreateOptions} - * @returns {@link InferenceClient} + * const chatResponse = inferenceClient.chatComplete({ + * messages: [{ role: MessageRole.User, content: 'Do something' }], + * }); + * ``` */ - getClient: (options: InferenceClientCreateOptions) => InferenceClient; + getClient: ( + options: T + ) => T extends InferenceBoundClientCreateOptions ? BoundInferenceClient : InferenceClient; }