Skip to content

Commit

Permalink
Add stream param for inference APIs (#198646)
Browse files Browse the repository at this point in the history
## Summary

Fix #198644

Add a `stream` parameter to the `chatComplete` and `output` APIs,
defaulting to `false`, to switch between "full content response as
promise" and "event observable" responses.

Note: at the moment, in non-stream mode, the implementation is simply
constructing the response from the observable. It should be possible
later to improve this by having the LLM adapters handle the
stream/no-stream logic, but this is out of scope of the current PR.

### Normal mode
```ts
const response = await chatComplete({
  connectorId: 'my-connector',
  system: "You are a helpful assistant",
  messages: [
     { role: MessageRole.User, content: "Some question?"},
  ]
});

const { content, toolCalls } = response;
// do something
```

### Stream mode
```ts
const events$ = chatComplete({
  stream: true,
  connectorId: 'my-connector',
  system: "You are a helpful assistant",
  messages: [
     { role: MessageRole.User, content: "Some question?"},
  ]
});

events$.subscribe((event) => {
   // do something
});

```

---------

Co-authored-by: kibanamachine <[email protected]>
Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
3 people authored Nov 5, 2024
1 parent 6b77e05 commit fe16822
Show file tree
Hide file tree
Showing 26 changed files with 1,050 additions and 228 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ export function useObservabilityAIAssistantContext({
},
metric: {
type: 'object',
properties: {},
},
gauge: {
type: 'object',
properties: {},
},
pie: {
type: 'object',
Expand Down Expand Up @@ -158,6 +160,7 @@ export function useObservabilityAIAssistantContext({
},
table: {
type: 'object',
properties: {},
},
tagcloud: {
type: 'object',
Expand Down
8 changes: 7 additions & 1 deletion x-pack/packages/ai-infra/inference-common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ export {
type ToolChoice,
type ChatCompleteAPI,
type ChatCompleteOptions,
type ChatCompletionResponse,
type ChatCompleteCompositeResponse,
type ChatCompletionTokenCountEvent,
type ChatCompletionEvent,
type ChatCompletionChunkEvent,
type ChatCompletionChunkToolCall,
type ChatCompletionMessageEvent,
type ChatCompleteStreamResponse,
type ChatCompleteResponse,
type ChatCompletionTokenCount,
withoutTokenCountEvents,
withoutChunkEvents,
isChatCompletionMessageEvent,
Expand All @@ -48,7 +51,10 @@ export {
export {
OutputEventType,
type OutputAPI,
type OutputOptions,
type OutputResponse,
type OutputCompositeResponse,
type OutputStreamResponse,
type OutputCompleteEvent,
type OutputUpdateEvent,
type Output,
Expand Down
93 changes: 83 additions & 10 deletions x-pack/packages/ai-infra/inference-common/src/chat_complete/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,35 @@
*/

import type { Observable } from 'rxjs';
import type { ToolOptions } from './tools';
import type { ToolCallsOf, ToolOptions } from './tools';
import type { Message } from './messages';
import type { ChatCompletionEvent } from './events';
import type { ChatCompletionEvent, ChatCompletionTokenCount } from './events';

/**
* Request a completion from the LLM based on a prompt or conversation.
*
* @example using the API to get an event observable.
* By default, The complete LLM response will be returned as a promise.
*
* @example using the API in default mode to get promise of the LLM response.
* ```ts
* const response = await chatComplete({
* connectorId: 'my-connector',
* system: "You are a helpful assistant",
* messages: [
* { role: MessageRole.User, content: "Some question?"},
* ]
* });
*
* const { content, tokens, toolCalls } = response;
* ```
*
* Use `stream: true` to return an observable returning the full set
* of events in real time.
*
* @example using the API in stream mode to get an event observable.
* ```ts
* const events$ = chatComplete({
* stream: true,
* connectorId: 'my-connector',
* system: "You are a helpful assistant",
* messages: [
Expand All @@ -24,20 +43,44 @@ import type { ChatCompletionEvent } from './events';
* { role: MessageRole.User, content: "Another question?"},
* ]
* });
*
* // using the observable
* events$.pipe(withoutTokenCountEvents()).subscribe((event) => {
* if (isChatCompletionChunkEvent(event)) {
* // do something with the chunk event
* } else {
* // do something with the message event
* }
* });
* ```
*/
export type ChatCompleteAPI = <TToolOptions extends ToolOptions = ToolOptions>(
options: ChatCompleteOptions<TToolOptions>
) => ChatCompletionResponse<TToolOptions>;
export type ChatCompleteAPI = <
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
>(
options: ChatCompleteOptions<TToolOptions, TStream>
) => ChatCompleteCompositeResponse<TToolOptions, TStream>;

/**
* Options used to call the {@link ChatCompleteAPI}
*/
export type ChatCompleteOptions<TToolOptions extends ToolOptions = ToolOptions> = {
export type ChatCompleteOptions<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = {
/**
* The ID of the connector to use.
* Must be a genAI compatible connector, or an error will be thrown.
* Must be an inference connector, or an error will be thrown.
*/
connectorId: string;
/**
* Set to true to enable streaming, which will change the API response type from
* a single {@link ChatCompleteResponse} promise
* to a {@link ChatCompleteStreamResponse} event observable.
*
* Defaults to false.
*/
stream?: TStream;
/**
* Optional system message for the LLM.
*/
Expand All @@ -53,14 +96,44 @@ export type ChatCompleteOptions<TToolOptions extends ToolOptions = ToolOptions>
} & TToolOptions;

/**
* Response from the {@link ChatCompleteAPI}.
* Composite response type from the {@link ChatCompleteAPI},
* which can be either an observable or a promise depending on
* whether API was called with stream mode enabled or not.
*/
export type ChatCompleteCompositeResponse<
TToolOptions extends ToolOptions = ToolOptions,
TStream extends boolean = false
> = TStream extends true
? ChatCompleteStreamResponse<TToolOptions>
: Promise<ChatCompleteResponse<TToolOptions>>;

/**
* Response from the {@link ChatCompleteAPI} when streaming is enabled.
*
* Observable of {@link ChatCompletionEvent}
*/
export type ChatCompletionResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
export type ChatCompleteStreamResponse<TToolOptions extends ToolOptions = ToolOptions> = Observable<
ChatCompletionEvent<TToolOptions>
>;

/**
* Response from the {@link ChatCompleteAPI} when streaming is not enabled.
*/
export interface ChatCompleteResponse<TToolOptions extends ToolOptions = ToolOptions> {
/**
* The text content of the LLM response.
*/
content: string;
/**
* The eventual tool calls performed by the LLM.
*/
toolCalls: ToolCallsOf<TToolOptions>['toolCalls'];
/**
* Token counts
*/
tokens?: ChatCompletionTokenCount;
}

/**
* Define the function calling mode when using inference APIs.
* - native will use the LLM's native function calling (requires the LLM to have native support)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,38 @@ export type ChatCompletionChunkEvent =
tool_calls: ChatCompletionChunkToolCall[];
};

/**
* Token count structure for the chatComplete API.
*/
export interface ChatCompletionTokenCount {
/**
* Input token count
*/
prompt: number;
/**
* Output token count
*/
completion: number;
/**
* Total token count
*/
total: number;
}

/**
* Token count event, send only once, usually (but not necessarily)
* before the message event
*/
export type ChatCompletionTokenCountEvent =
InferenceTaskEventBase<ChatCompletionEventType.ChatCompletionTokenCount> & {
tokens: {
/**
* Input token count
*/
prompt: number;
/**
* Output token count
*/
completion: number;
/**
* Total token count
*/
total: number;
};
/**
* The token count structure
*/
tokens: ChatCompletionTokenCount;
};

/**
* Events emitted from the {@link ChatCompletionResponse} observable
* Events emitted from the {@link ChatCompleteResponse} observable
* returned from the {@link ChatCompleteAPI}.
*
* The chatComplete API returns 3 type of events:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
*/

export type {
ChatCompletionResponse,
ChatCompleteCompositeResponse,
ChatCompleteAPI,
ChatCompleteOptions,
FunctionCallingMode,
ChatCompleteStreamResponse,
ChatCompleteResponse,
} from './api';
export {
ChatCompletionEventType,
Expand All @@ -18,6 +20,7 @@ export {
type ChatCompletionEvent,
type ChatCompletionChunkToolCall,
type ChatCompletionTokenCountEvent,
type ChatCompletionTokenCount,
} from './events';
export {
MessageRole,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ interface ToolSchemaFragmentBase {
description?: string;
}

interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
export interface ToolSchemaTypeObject extends ToolSchemaFragmentBase {
type: 'object';
properties?: Record<string, ToolSchemaType>;
properties: Record<string, ToolSchemaType>;
required?: string[] | readonly string[];
}

Expand All @@ -40,6 +40,9 @@ interface ToolSchemaTypeArray extends ToolSchemaFragmentBase {
items: Exclude<ToolSchemaType, ToolSchemaTypeArray>;
}

/**
* A tool schema property's possible types.
*/
export type ToolSchemaType =
| ToolSchemaTypeObject
| ToolSchemaTypeString
Expand Down
Loading

0 comments on commit fe16822

Please sign in to comment.