From 02a3992edb8021b850e9c66c72f4f5da0d988b82 Mon Sep 17 00:00:00 2001
From: Pierre Gayvallet
Date: Wed, 28 Aug 2024 15:02:06 +0200
Subject: [PATCH] Inference plugin: Add Gemini model adapter (#191292)

## Summary

Add the `gemini` model adapter for the `inference` plugin.

Minor changes had to be made to the associated connector.

Also updates the CODEOWNERS file to add the `@elastic/appex-ai-infra` team as (one of the) owners of the genAI connectors.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
---
 .github/CODEOWNERS | 18 +-
 .../connector_types.test.ts.snap | 241 +++++++++++
 .../common/chat_complete/tool_schema.ts | 30 +-
 .../gemini/gemini_adapter.test.mocks.ts | 16 +
 .../adapters/gemini/gemini_adapter.test.ts | 396 ++++++++++++++++++
 .../adapters/gemini/gemini_adapter.ts | 213 ++++++++++
 .../chat_complete/adapters/gemini/index.ts | 8 +
 .../gemini/process_vertex_stream.test.ts | 155 +++++++
 .../adapters/gemini/process_vertex_stream.ts | 70 ++++
 .../chat_complete/adapters/gemini/types.ts | 38 ++
 .../adapters/get_inference_adapter.test.ts | 9 +-
 .../adapters/get_inference_adapter.ts | 8 +-
 .../chat_complete/adapters/openai/index.ts | 4 +-
 .../utils/generate_fake_tool_call_id.ts | 12 +
 .../server/chat_complete/utils/index.ts | 1 +
 .../inference/server/routes/chat_complete.ts | 2 +-
 .../stack_connectors/common/gemini/schema.ts | 8 +
 .../connector_types/gemini/gemini.test.ts | 16 +
 .../server/connector_types/gemini/gemini.ts | 72 +++-
 19 files changed, 1261 insertions(+), 56 deletions(-)
 create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.mocks.ts
 create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts
 create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts
 create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/gemini/index.ts
 create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.test.ts
 create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.ts
 create mode 100644 x-pack/plugins/inference/server/chat_complete/adapters/gemini/types.ts
 create mode 100644 x-pack/plugins/inference/server/chat_complete/utils/generate_fake_tool_call_id.ts

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index faac357fc0a7a..c6beae627cb53 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1557,18 +1557,18 @@ x-pack/test/security_solution_cypress/cypress/tasks/expandable_flyout @elastic/

 ## Generative AI owner connectors

 # OpenAI
-/x-pack/plugins/stack_connectors/public/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/server/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/common/openai @elastic/security-generative-ai @elastic/obs-ai-assistant
+/x-pack/plugins/stack_connectors/public/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/server/connector_types/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra
+/x-pack/plugins/stack_connectors/common/openai @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra

 # Bedrock
-/x-pack/plugins/stack_connectors/public/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant
-/x-pack/plugins/stack_connectors/server/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant -/x-pack/plugins/stack_connectors/common/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant +/x-pack/plugins/stack_connectors/public/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra +/x-pack/plugins/stack_connectors/server/connector_types/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra +/x-pack/plugins/stack_connectors/common/bedrock @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra # Gemini -/x-pack/plugins/stack_connectors/public/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant -/x-pack/plugins/stack_connectors/server/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant -/x-pack/plugins/stack_connectors/common/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant +/x-pack/plugins/stack_connectors/public/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra +/x-pack/plugins/stack_connectors/server/connector_types/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra +/x-pack/plugins/stack_connectors/common/gemini @elastic/security-generative-ai @elastic/obs-ai-assistant @elastic/appex-ai-infra ## Defend Workflows owner connectors /x-pack/plugins/stack_connectors/public/connector_types/sentinelone @elastic/security-defend-workflows diff --git a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap index d534170f74d0f..ae1289b2a9e2e 100644 --- a/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap +++ b/x-pack/plugins/actions/server/integration_tests/__snapshots__/connector_types.test.ts.snap @@ -4318,6 +4318,27 @@ Object { ], "type": "array", }, + "systemInstruction": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, "temperature": Object { "flags": Object { "default": [Function], @@ -4344,6 +4365,95 @@ Object { ], "type": "number", }, + "toolConfig": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "allowedFunctionNames": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, + "mode": Object { + "flags": Object { + "error": [Function], + }, + "matches": Array [ + Object { + "schema": Object { + "allow": Array [ + "AUTO", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "ANY", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object 
{ + "allow": Array [ + "NONE", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + ], + "type": "alternatives", + }, + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "object", + }, "tools": Object { "flags": Object { "default": [Function], @@ -4464,6 +4574,27 @@ Object { ], "type": "array", }, + "systemInstruction": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, "temperature": Object { "flags": Object { "default": [Function], @@ -4610,6 +4741,27 @@ Object { ], "type": "array", }, + "systemInstruction": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, "temperature": Object { "flags": Object { "default": [Function], @@ -4636,6 +4788,95 @@ Object { ], "type": "number", }, + "toolConfig": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "keys": Object { + "allowedFunctionNames": Object { + "flags": Object { + "default": [Function], + "error": [Function], + "presence": "optional", + }, + "items": Array [ + Object { + "flags": Object { + "error": [Function], + "presence": "optional", + }, + "rules": Array [ + Object { + "args": Object { + "method": [Function], + }, + "name": "custom", + }, + ], + "type": "string", + }, + ], + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "array", + }, + "mode": Object { + "flags": Object { + "error": [Function], + }, + "matches": Array [ + Object { + "schema": Object { + "allow": Array [ + "AUTO", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "ANY", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + Object { + "schema": Object { + "allow": Array [ + "NONE", + ], + "flags": Object { + "error": [Function], + "only": true, + }, + "type": "any", + }, + }, + ], + "type": "alternatives", + }, + }, + "metas": Array [ + Object { + "x-oas-optional": true, + }, + ], + "type": "object", + }, "tools": Object { "flags": Object { "default": [Function], diff --git a/x-pack/plugins/inference/common/chat_complete/tool_schema.ts b/x-pack/plugins/inference/common/chat_complete/tool_schema.ts index 5ca3e0ab57a49..b23c03aaad775 100644 --- a/x-pack/plugins/inference/common/chat_complete/tool_schema.ts +++ b/x-pack/plugins/inference/common/chat_complete/tool_schema.ts @@ -5,7 +5,7 @@ * 2.0. 
*/ -import { Required, ValuesType, UnionToIntersection } from 'utility-types'; +import { Required, ValuesType } from 'utility-types'; interface ToolSchemaFragmentBase { description?: string; @@ -13,7 +13,7 @@ interface ToolSchemaFragmentBase { interface ToolSchemaTypeObject extends ToolSchemaFragmentBase { type: 'object'; - properties: Record; + properties: Record; required?: string[] | readonly string[]; } @@ -35,28 +35,18 @@ interface ToolSchemaTypeNumber extends ToolSchemaFragmentBase { enum?: string[] | readonly string[]; } -interface ToolSchemaAnyOf extends ToolSchemaFragmentBase { - anyOf: ToolSchemaType[]; -} - -interface ToolSchemaAllOf extends ToolSchemaFragmentBase { - allOf: ToolSchemaType[]; -} - interface ToolSchemaTypeArray extends ToolSchemaFragmentBase { type: 'array'; items: Exclude; } -type ToolSchemaType = +export type ToolSchemaType = | ToolSchemaTypeObject | ToolSchemaTypeString | ToolSchemaTypeBoolean | ToolSchemaTypeNumber | ToolSchemaTypeArray; -type ToolSchemaFragment = ToolSchemaType | ToolSchemaAnyOf | ToolSchemaAllOf; - type FromToolSchemaObject = Required< { [key in keyof TToolSchemaObject['properties']]?: FromToolSchema< @@ -79,17 +69,9 @@ type FromToolSchemaString = ? ValuesType : string; -type FromToolSchemaAnyOf = FromToolSchema< - ValuesType ->; - -type FromToolSchemaAllOf = UnionToIntersection< - FromToolSchema> ->; - export type ToolSchema = ToolSchemaTypeObject; -export type FromToolSchema = +export type FromToolSchema = TToolSchema extends ToolSchemaTypeObject ? FromToolSchemaObject : TToolSchema extends ToolSchemaTypeArray @@ -100,8 +82,4 @@ export type FromToolSchema = ? number : TToolSchema extends ToolSchemaTypeString ? FromToolSchemaString - : TToolSchema extends ToolSchemaAnyOf - ? FromToolSchemaAnyOf - : TToolSchema extends ToolSchemaAllOf - ? FromToolSchemaAllOf : never; diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.mocks.ts b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.mocks.ts new file mode 100644 index 0000000000000..4b9d0b4648985 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.mocks.ts @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export const processVertexStreamMock = jest.fn(); + +jest.doMock('./process_vertex_stream', () => { + const actual = jest.requireActual('./process_vertex_stream'); + return { + ...actual, + processVertexStream: processVertexStreamMock, + }; +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts new file mode 100644 index 0000000000000..3fe8c917c0015 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts @@ -0,0 +1,396 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { processVertexStreamMock } from './gemini_adapter.test.mocks'; +import { PassThrough } from 'stream'; +import { noop, tap, lastValueFrom, toArray, Subject } from 'rxjs'; +import type { InferenceExecutor } from '../../utils/inference_executor'; +import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream'; +import { MessageRole } from '../../../../common/chat_complete'; +import { ToolChoiceType } from '../../../../common/chat_complete/tools'; +import { geminiAdapter } from './gemini_adapter'; + +describe('geminiAdapter', () => { + const executorMock = { + invoke: jest.fn(), + } as InferenceExecutor & { invoke: jest.MockedFn }; + + beforeEach(() => { + executorMock.invoke.mockReset(); + processVertexStreamMock.mockReset().mockImplementation(() => tap(noop)); + }); + + function getCallParams() { + const params = executorMock.invoke.mock.calls[0][0].subActionParams as Record; + return { + messages: params.messages, + tools: params.tools, + toolConfig: params.toolConfig, + systemInstruction: params.systemInstruction, + }; + } + + describe('#chatComplete()', () => { + beforeEach(() => { + executorMock.invoke.mockImplementation(async () => { + return { + actionId: '', + status: 'ok', + data: new PassThrough(), + }; + }); + }); + + it('calls `executor.invoke` with the right fixed parameters', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(executorMock.invoke).toHaveBeenCalledWith({ + subAction: 'invokeStream', + subActionParams: { + messages: [ + { + parts: [{ text: 'question' }], + role: 'user', + }, + ], + tools: [], + temperature: 0, + stopSequences: ['\n\nHuman:'], + }, + }); + }); + + it('correctly format tools', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + tools: { + myFunction: { + description: 'myFunction', + }, + myFunctionWithArgs: { + description: 'myFunctionWithArgs', + schema: { + type: 'object', + properties: { + foo: { + type: 'string', + description: 'foo', + }, + }, + required: ['foo'], + }, + }, + }, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { tools } = getCallParams(); + expect(tools).toEqual([ + { + functionDeclarations: [ + { + description: 'myFunction', + name: 'myFunction', + parameters: { + properties: {}, + type: 'OBJECT', + }, + }, + { + description: 'myFunctionWithArgs', + name: 'myFunctionWithArgs', + parameters: { + properties: { + foo: { + description: 'foo', + enum: undefined, + type: 'STRING', + }, + }, + required: ['foo'], + type: 'OBJECT', + }, + }, + ], + }, + ]); + }); + + it('correctly format messages', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + { + role: MessageRole.Assistant, + content: 'answer', + }, + { + role: MessageRole.User, + content: 'another question', + }, + { + role: MessageRole.Assistant, + content: null, + toolCalls: [ + { + function: { + name: 'my_function', + arguments: { + foo: 'bar', + }, + }, + toolCallId: '0', + }, + ], + }, + { + role: MessageRole.Tool, + toolCallId: '0', + response: { + bar: 'foo', + }, + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { messages } = getCallParams(); + expect(messages).toEqual([ + { + parts: [ + { + text: 'question', + }, + ], + 
role: 'user', + }, + { + parts: [ + { + text: 'answer', + }, + ], + role: 'assistant', + }, + { + parts: [ + { + text: 'another question', + }, + ], + role: 'user', + }, + { + parts: [ + { + functionCall: { + args: { + foo: 'bar', + }, + name: 'my_function', + }, + }, + ], + role: 'assistant', + }, + { + parts: [ + { + functionResponse: { + name: '0', + response: { + bar: 'foo', + }, + }, + }, + ], + role: 'user', + }, + ]); + }); + + it('groups messages from the same user', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + { + role: MessageRole.User, + content: 'another question', + }, + { + role: MessageRole.Assistant, + content: 'answer', + }, + { + role: MessageRole.Assistant, + content: null, + toolCalls: [ + { + function: { + name: 'my_function', + arguments: { + foo: 'bar', + }, + }, + toolCallId: '0', + }, + ], + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { messages } = getCallParams(); + expect(messages).toEqual([ + { + parts: [ + { + text: 'question', + }, + { + text: 'another question', + }, + ], + role: 'user', + }, + { + parts: [ + { + text: 'answer', + }, + { + functionCall: { + args: { + foo: 'bar', + }, + name: 'my_function', + }, + }, + ], + role: 'assistant', + }, + ]); + }); + + it('correctly format system message', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + system: 'Some system message', + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { systemInstruction } = getCallParams(); + expect(systemInstruction).toEqual('Some system message'); + }); + + it('correctly format tool choice', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + toolChoice: ToolChoiceType.required, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { toolConfig } = getCallParams(); + expect(toolConfig).toEqual({ mode: 'ANY' }); + }); + + it('correctly format tool choice for named function', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + toolChoice: { function: 'foobar' }, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { toolConfig } = getCallParams(); + expect(toolConfig).toEqual({ mode: 'ANY', allowedFunctionNames: ['foobar'] }); + }); + + it('process response events via processVertexStream', async () => { + const source$ = new Subject>(); + + const tapFn = jest.fn(); + processVertexStreamMock.mockImplementation(() => tap(tapFn)); + + executorMock.invoke.mockImplementation(async () => { + return { + actionId: '', + status: 'ok', + data: observableIntoEventSourceStream(source$), + }; + }); + + const response$ = geminiAdapter.chatComplete({ + executor: executorMock, + messages: [ + { + role: MessageRole.User, + content: 'question', + }, + ], + }); + + source$.next({ chunk: 1 }); + source$.next({ chunk: 2 }); + source$.complete(); + + const allChunks = await lastValueFrom(response$.pipe(toArray())); + + expect(allChunks).toEqual([{ chunk: 1 }, { chunk: 2 }]); + + expect(tapFn).toHaveBeenCalledTimes(2); + expect(tapFn).toHaveBeenCalledWith({ chunk: 1 }); + expect(tapFn).toHaveBeenCalledWith({ chunk: 2 }); + }); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts 
b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts new file mode 100644 index 0000000000000..8f6d02da0e3f8 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts @@ -0,0 +1,213 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import * as Gemini from '@google/generative-ai'; +import { from, map, switchMap } from 'rxjs'; +import { Readable } from 'stream'; +import type { InferenceConnectorAdapter } from '../../types'; +import { Message, MessageRole } from '../../../../common/chat_complete'; +import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools'; +import type { ToolSchema, ToolSchemaType } from '../../../../common/chat_complete/tool_schema'; +import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable'; +import { processVertexStream } from './process_vertex_stream'; +import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types'; + +export const geminiAdapter: InferenceConnectorAdapter = { + chatComplete: ({ executor, system, messages, toolChoice, tools }) => { + return from( + executor.invoke({ + subAction: 'invokeStream', + subActionParams: { + messages: messagesToGemini({ messages }), + systemInstruction: system, + tools: toolsToGemini(tools), + toolConfig: toolChoiceToConfig(toolChoice), + temperature: 0, + stopSequences: ['\n\nHuman:'], + }, + }) + ).pipe( + switchMap((response) => { + const readable = response.data as Readable; + return eventSourceStreamIntoObservable(readable); + }), + map((line) => { + return JSON.parse(line) as GenerateContentResponseChunk; + }), + processVertexStream() + ); + }, +}; + +function toolChoiceToConfig(toolChoice: ToolOptions['toolChoice']): GeminiToolConfig | undefined { + if (toolChoice === ToolChoiceType.required) { + return { + mode: 'ANY', + }; + } else if (toolChoice === ToolChoiceType.none) { + return { + mode: 'NONE', + }; + } else if (toolChoice === ToolChoiceType.auto) { + return { + mode: 'AUTO', + }; + } else if (toolChoice) { + return { + mode: 'ANY', + allowedFunctionNames: [toolChoice.function], + }; + } + return undefined; +} + +function toolsToGemini(tools: ToolOptions['tools']): Gemini.Tool[] { + return tools + ? [ + { + functionDeclarations: Object.entries(tools ?? {}).map( + ([toolName, { description, schema }]) => { + return { + name: toolName, + description, + parameters: schema + ? 
toolSchemaToGemini({ schema }) + : { + type: Gemini.FunctionDeclarationSchemaType.OBJECT, + properties: {}, + }, + }; + } + ), + }, + ] + : []; +} + +function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.FunctionDeclarationSchema { + const convertSchemaType = ({ + def, + }: { + def: ToolSchemaType; + }): Gemini.FunctionDeclarationSchemaProperty => { + switch (def.type) { + case 'array': + return { + type: Gemini.FunctionDeclarationSchemaType.ARRAY, + description: def.description, + items: convertSchemaType({ def: def.items }) as Gemini.FunctionDeclarationSchema, + }; + case 'object': + return { + type: Gemini.FunctionDeclarationSchemaType.OBJECT, + description: def.description, + required: def.required as string[], + properties: Object.entries(def.properties).reduce< + Record + >((properties, [key, prop]) => { + properties[key] = convertSchemaType({ def: prop }) as Gemini.FunctionDeclarationSchema; + return properties; + }, {}), + }; + case 'string': + return { + type: Gemini.FunctionDeclarationSchemaType.STRING, + description: def.description, + enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined, + }; + case 'boolean': + return { + type: Gemini.FunctionDeclarationSchemaType.BOOLEAN, + description: def.description, + enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined, + }; + case 'number': + return { + type: Gemini.FunctionDeclarationSchemaType.NUMBER, + description: def.description, + enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined, + }; + } + }; + + return { + type: Gemini.FunctionDeclarationSchemaType.OBJECT, + required: schema.required as string[], + properties: Object.entries(schema.properties).reduce< + Record + >((properties, [key, def]) => { + properties[key] = convertSchemaType({ def }); + return properties; + }, {}), + }; +} + +function messagesToGemini({ messages }: { messages: Message[] }): GeminiMessage[] { + return messages.map(messageToGeminiMapper()).reduce((output, message) => { + // merging consecutive messages from the same user, as Gemini requires multi-turn messages + const previousMessage = output.length ? output[output.length - 1] : undefined; + if (previousMessage?.role === message.role) { + previousMessage.parts.push(...message.parts); + } else { + output.push(message); + } + return output; + }, []); +} + +function messageToGeminiMapper() { + return (message: Message): GeminiMessage => { + const role = message.role; + + switch (role) { + case MessageRole.Assistant: + const assistantMessage: GeminiMessage = { + role: 'assistant', + parts: [ + ...(message.content ? [{ text: message.content }] : []), + ...(message.toolCalls ?? []).map((toolCall) => { + return { + functionCall: { + name: toolCall.function.name, + args: ('arguments' in toolCall.function + ? 
toolCall.function.arguments + : {}) as object, + }, + }; + }), + ], + }; + return assistantMessage; + + case MessageRole.User: + const userMessage: GeminiMessage = { + role: 'user', + parts: [ + { + text: message.content, + }, + ], + }; + return userMessage; + + case MessageRole.Tool: + // tool responses are provided as user messages + const toolMessage: GeminiMessage = { + role: 'user', + parts: [ + { + functionResponse: { + name: message.toolCallId, + response: message.response as object, + }, + }, + ], + }; + return toolMessage; + } + }; +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/gemini/index.ts b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/index.ts new file mode 100644 index 0000000000000..abd7c3e552c0f --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { geminiAdapter } from './gemini_adapter'; diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.test.ts new file mode 100644 index 0000000000000..78e0da0a384b8 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.test.ts @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { TestScheduler } from 'rxjs/testing'; +import { ChatCompletionEventType } from '../../../../common/chat_complete'; +import { processVertexStream } from './process_vertex_stream'; +import type { GenerateContentResponseChunk } from './types'; + +describe('processVertexStream', () => { + const getTestScheduler = () => + new TestScheduler((actual, expected) => { + expect(actual).toEqual(expected); + }); + + it('completes when the source completes', () => { + getTestScheduler().run(({ expectObservable, hot }) => { + const source$ = hot('----|'); + + const processed$ = source$.pipe(processVertexStream()); + + expectObservable(processed$).toBe('----|'); + }); + }); + + it('emits a chunk event when the source emits content', () => { + getTestScheduler().run(({ expectObservable, hot }) => { + const chunk: GenerateContentResponseChunk = { + candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'some chunk' }] } }], + }; + + const source$ = hot('--a', { a: chunk }); + + const processed$ = source$.pipe(processVertexStream()); + + expectObservable(processed$).toBe('--a', { + a: { + content: 'some chunk', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + }); + }); + }); + + it('emits a chunk event when the source emits a function call', () => { + getTestScheduler().run(({ expectObservable, hot }) => { + const chunk: GenerateContentResponseChunk = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ functionCall: { name: 'func1', args: { arg1: true } } }], + }, + }, + ], + }; + + const source$ = hot('--a', { a: chunk }); + + const processed$ = source$.pipe(processVertexStream()); + + expectObservable(processed$).toBe('--a', { + a: { + content: '', + tool_calls: [ + { + index: 0, + toolCallId: expect.any(String), + function: { name: 'func1', arguments: JSON.stringify({ arg1: true }) }, + }, + ], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + }); + }); + }); + + it('emits a token count event when the source emits content with usageMetadata', () => { + getTestScheduler().run(({ expectObservable, hot }) => { + const chunk: GenerateContentResponseChunk = { + candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'last chunk' }] } }], + usageMetadata: { + candidatesTokenCount: 1, + promptTokenCount: 2, + totalTokenCount: 3, + }, + }; + + const source$ = hot('--a', { a: chunk }); + + const processed$ = source$.pipe(processVertexStream()); + + expectObservable(processed$).toBe('--(ab)', { + a: { + tokens: { + completion: 1, + prompt: 2, + total: 3, + }, + type: ChatCompletionEventType.ChatCompletionTokenCount, + }, + b: { + content: 'last chunk', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + }); + }); + }); + + it('emits for multiple chunks', () => { + getTestScheduler().run(({ expectObservable, hot }) => { + const chunkA: GenerateContentResponseChunk = { + candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk A' }] } }], + }; + const chunkB: GenerateContentResponseChunk = { + candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk B' }] } }], + }; + const chunkC: GenerateContentResponseChunk = { + candidates: [{ index: 0, content: { role: 'model', parts: [{ text: 'chunk C' }] } }], + }; + + const source$ = hot('-a--b---c-|', { + a: chunkA, + b: chunkB, + c: chunkC, + }); + + const processed$ = source$.pipe(processVertexStream()); + + expectObservable(processed$).toBe('-a--b---c-|', { + a: { + content: 'chunk A', + tool_calls: [], + type: 
ChatCompletionEventType.ChatCompletionChunk, + }, + b: { + content: 'chunk B', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + c: { + content: 'chunk C', + tool_calls: [], + type: ChatCompletionEventType.ChatCompletionChunk, + }, + }); + }); + }); +}); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.ts b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.ts new file mode 100644 index 0000000000000..ec3fa8e82eed6 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/process_vertex_stream.ts @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Observable } from 'rxjs'; +import { + ChatCompletionChunkEvent, + ChatCompletionTokenCountEvent, + ChatCompletionEventType, +} from '../../../../common/chat_complete'; +import { generateFakeToolCallId } from '../../utils'; +import type { GenerateContentResponseChunk } from './types'; + +export function processVertexStream() { + return (source: Observable) => + new Observable((subscriber) => { + function handleNext(value: GenerateContentResponseChunk) { + // completion: only present on last chunk + if (value.usageMetadata) { + subscriber.next({ + type: ChatCompletionEventType.ChatCompletionTokenCount, + tokens: { + prompt: value.usageMetadata.promptTokenCount, + completion: value.usageMetadata.candidatesTokenCount, + total: value.usageMetadata.totalTokenCount, + }, + }); + } + + const contentPart = value.candidates?.[0].content.parts[0]; + const completion = contentPart?.text; + const toolCall = contentPart?.functionCall; + + if (completion || toolCall) { + subscriber.next({ + type: ChatCompletionEventType.ChatCompletionChunk, + content: completion ?? '', + tool_calls: toolCall + ? [ + { + index: 0, + toolCallId: generateFakeToolCallId(), + function: { name: toolCall.name, arguments: JSON.stringify(toolCall.args) }, + }, + ] + : [], + }); + } + } + + source.subscribe({ + next: (value) => { + try { + handleNext(value); + } catch (error) { + subscriber.error(error); + } + }, + error: (err) => { + subscriber.error(err); + }, + complete: () => { + subscriber.complete(); + }, + }); + }); +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/gemini/types.ts b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/types.ts new file mode 100644 index 0000000000000..7d00057d8b801 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/adapters/gemini/types.ts @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { GenerateContentResponse, Part } from '@google/generative-ai'; + +export interface GenerateContentResponseUsageMetadata { + promptTokenCount: number; + candidatesTokenCount: number; + totalTokenCount: number; +} + +/** + * Actual type for chunks, as the type from the google package is missing the + * usage metadata. 
+ */ +export type GenerateContentResponseChunk = GenerateContentResponse & { + usageMetadata?: GenerateContentResponseUsageMetadata; +}; + +/** + * We need to use the connector's format, not directly Gemini's... + * In practice, 'parts' get mapped to 'content' + * + * See x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts + */ +export interface GeminiMessage { + role: 'assistant' | 'user'; + parts: Part[]; +} + +export interface GeminiToolConfig { + mode: 'AUTO' | 'ANY' | 'NONE'; + allowedFunctionNames?: string[]; +} diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts index 272ad76538898..9e0b0da6d5894 100644 --- a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts +++ b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.test.ts @@ -8,17 +8,18 @@ import { InferenceConnectorType } from '../../../common/connectors'; import { getInferenceAdapter } from './get_inference_adapter'; import { openAIAdapter } from './openai'; +import { geminiAdapter } from './gemini'; describe('getInferenceAdapter', () => { it('returns the openAI adapter for OpenAI type', () => { expect(getInferenceAdapter(InferenceConnectorType.OpenAI)).toBe(openAIAdapter); }); - it('returns undefined for Bedrock type', () => { - expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(undefined); + it('returns the gemini adapter for Gemini type', () => { + expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(geminiAdapter); }); - it('returns undefined for Gemini type', () => { - expect(getInferenceAdapter(InferenceConnectorType.Gemini)).toBe(undefined); + it('returns undefined for Bedrock type', () => { + expect(getInferenceAdapter(InferenceConnectorType.Bedrock)).toBe(undefined); }); }); diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts index a62ec8b795608..0538d828a473a 100644 --- a/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts +++ b/x-pack/plugins/inference/server/chat_complete/adapters/get_inference_adapter.ts @@ -8,6 +8,7 @@ import { InferenceConnectorType } from '../../../common/connectors'; import type { InferenceConnectorAdapter } from '../types'; import { openAIAdapter } from './openai'; +import { geminiAdapter } from './gemini'; export const getInferenceAdapter = ( connectorType: InferenceConnectorType @@ -16,11 +17,10 @@ export const getInferenceAdapter = ( case InferenceConnectorType.OpenAI: return openAIAdapter; - case InferenceConnectorType.Bedrock: - // not implemented yet - break; - case InferenceConnectorType.Gemini: + return geminiAdapter; + + case InferenceConnectorType.Bedrock: // not implemented yet break; } diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.ts b/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.ts index 80fa9bfb781f5..d72ceb2020e8a 100644 --- a/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.ts +++ b/x-pack/plugins/inference/server/chat_complete/adapters/openai/index.ts @@ -25,7 +25,7 @@ import type { ToolOptions } from '../../../../common/chat_complete/tools'; import { createTokenLimitReachedError } from '../../../../common/chat_complete/errors'; import { createInferenceInternalError } from '../../../../common/errors'; import { 
eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable'; -import { InferenceConnectorAdapter } from '../../types'; +import type { InferenceConnectorAdapter } from '../../types'; export const openAIAdapter: InferenceConnectorAdapter = { chatComplete: ({ executor, system, messages, toolChoice, tools }) => { @@ -76,6 +76,7 @@ export const openAIAdapter: InferenceConnectorAdapter = { const delta = chunk.choices[0].delta; return { + type: ChatCompletionEventType.ChatCompletionChunk, content: delta.content ?? '', tool_calls: delta.tool_calls?.map((toolCall) => { @@ -88,7 +89,6 @@ export const openAIAdapter: InferenceConnectorAdapter = { index: toolCall.index, }; }) ?? [], - type: ChatCompletionEventType.ChatCompletionChunk, }; }) ); diff --git a/x-pack/plugins/inference/server/chat_complete/utils/generate_fake_tool_call_id.ts b/x-pack/plugins/inference/server/chat_complete/utils/generate_fake_tool_call_id.ts new file mode 100644 index 0000000000000..1da675b61c4ec --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/utils/generate_fake_tool_call_id.ts @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { v4 } from 'uuid'; + +export function generateFakeToolCallId() { + return v4().substr(0, 6); +} diff --git a/x-pack/plugins/inference/server/chat_complete/utils/index.ts b/x-pack/plugins/inference/server/chat_complete/utils/index.ts index dea2ac65f4755..d9344164eff46 100644 --- a/x-pack/plugins/inference/server/chat_complete/utils/index.ts +++ b/x-pack/plugins/inference/server/chat_complete/utils/index.ts @@ -12,3 +12,4 @@ export { type InferenceExecutor, } from './inference_executor'; export { chunksIntoMessage } from './chunks_into_message'; +export { generateFakeToolCallId } from './generate_fake_tool_call_id'; diff --git a/x-pack/plugins/inference/server/routes/chat_complete.ts b/x-pack/plugins/inference/server/routes/chat_complete.ts index 6b5aea7b71696..bfa95fdbb9213 100644 --- a/x-pack/plugins/inference/server/routes/chat_complete.ts +++ b/x-pack/plugins/inference/server/routes/chat_complete.ts @@ -57,7 +57,7 @@ const chatCompleteBodySchema: Type = schema.object({ schema.object({ role: schema.literal(MessageRole.Assistant), content: schema.oneOf([schema.string(), schema.literal(null)]), - toolCalls: toolCallSchema, + toolCalls: schema.maybe(toolCallSchema), }), schema.object({ role: schema.literal(MessageRole.User), diff --git a/x-pack/plugins/stack_connectors/common/gemini/schema.ts b/x-pack/plugins/stack_connectors/common/gemini/schema.ts index 543070c705907..43bf7337626e3 100644 --- a/x-pack/plugins/stack_connectors/common/gemini/schema.ts +++ b/x-pack/plugins/stack_connectors/common/gemini/schema.ts @@ -57,16 +57,24 @@ export const RunActionRawResponseSchema = schema.any(); export const InvokeAIActionParamsSchema = schema.object({ messages: schema.any(), + systemInstruction: schema.maybe(schema.string()), model: schema.maybe(schema.string()), temperature: schema.maybe(schema.number()), stopSequences: schema.maybe(schema.arrayOf(schema.string())), signal: schema.maybe(schema.any()), timeout: schema.maybe(schema.number()), tools: schema.maybe(schema.arrayOf(schema.any())), + toolConfig: schema.maybe( + schema.object({ + mode: schema.oneOf([schema.literal('AUTO'), schema.literal('ANY'), schema.literal('NONE')]), 
+ allowedFunctionNames: schema.maybe(schema.arrayOf(schema.string())), + }) + ), }); export const InvokeAIRawActionParamsSchema = schema.object({ messages: schema.any(), + systemInstruction: schema.maybe(schema.string()), model: schema.maybe(schema.string()), temperature: schema.maybe(schema.number()), stopSequences: schema.maybe(schema.arrayOf(schema.string())), diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts index 949ae0a6c1bd2..94dd7aa0d153c 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.test.ts @@ -239,6 +239,10 @@ describe('GeminiConnector', () => { content: 'What is the capital of France?', }, ], + toolConfig: { + mode: 'ANY' as const, + allowedFunctionNames: ['foo', 'bar'], + }, }; it('the API call is successful with correct request parameters', async () => { @@ -260,6 +264,12 @@ describe('GeminiConnector', () => { temperature: 0, maxOutputTokens: 8192, }, + tool_config: { + function_calling_config: { + mode: 'ANY', + allowed_function_names: ['foo', 'bar'], + }, + }, safety_settings: [ { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' }, ], @@ -299,6 +309,12 @@ describe('GeminiConnector', () => { temperature: 0, maxOutputTokens: 8192, }, + tool_config: { + function_calling_config: { + mode: 'ANY', + allowed_function_names: ['foo', 'bar'], + }, + }, safety_settings: [ { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' }, ], diff --git a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts index 75f7458d3d6b3..895dfe66d6de4 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/gemini/gemini.ts @@ -64,6 +64,12 @@ interface Payload { temperature: number; maxOutputTokens: number; }; + tool_config?: { + function_calling_config: { + mode: 'AUTO' | 'ANY' | 'NONE'; + allowed_function_names?: string[]; + }; + }; safety_settings: Array<{ category: string; threshold: string }>; } @@ -278,12 +284,22 @@ export class GeminiConnector extends SubActionConnector { } public async invokeAI( - { messages, model, temperature = 0, signal, timeout }: InvokeAIActionParams, + { + messages, + systemInstruction, + model, + temperature = 0, + signal, + timeout, + toolConfig, + }: InvokeAIActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { const res = await this.runApi( { - body: JSON.stringify(formatGeminiPayload(messages, temperature)), + body: JSON.stringify( + formatGeminiPayload({ messages, temperature, toolConfig, systemInstruction }) + ), model, signal, timeout, @@ -295,12 +311,23 @@ export class GeminiConnector extends SubActionConnector { } public async invokeAIRaw( - { messages, model, temperature = 0, signal, timeout, tools }: InvokeAIRawActionParams, + { + messages, + model, + temperature = 0, + signal, + timeout, + tools, + systemInstruction, + }: InvokeAIRawActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { const res = await this.runApi( { - body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }), + body: JSON.stringify({ + ...formatGeminiPayload({ messages, temperature, systemInstruction }), + tools, + }), model, signal, timeout, @@ -323,18 +350,23 @@ export 
class GeminiConnector extends SubActionConnector { public async invokeStream( { messages, + systemInstruction, model, stopSequences, temperature = 0, signal, timeout, tools, + toolConfig, }: InvokeAIActionParams, connectorUsageCollector: ConnectorUsageCollector ): Promise { return (await this.streamAPI( { - body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }), + body: JSON.stringify({ + ...formatGeminiPayload({ messages, temperature, toolConfig, systemInstruction }), + tools, + }), model, stopSequences, signal, @@ -346,16 +378,36 @@ export class GeminiConnector extends SubActionConnector { } /** Format the json body to meet Gemini payload requirements */ -const formatGeminiPayload = ( - data: Array<{ role: string; content: string; parts: MessagePart[] }>, - temperature: number -): Payload => { +const formatGeminiPayload = ({ + messages, + systemInstruction, + temperature, + toolConfig, +}: { + messages: Array<{ role: string; content: string; parts: MessagePart[] }>; + systemInstruction?: string; + toolConfig?: InvokeAIActionParams['toolConfig']; + temperature: number; +}): Payload => { const payload: Payload = { contents: [], generation_config: { temperature, maxOutputTokens: DEFAULT_TOKEN_LIMIT, }, + ...(systemInstruction + ? { system_instruction: { role: 'user', parts: [{ text: systemInstruction }] } } + : {}), + ...(toolConfig + ? { + tool_config: { + function_calling_config: { + mode: toolConfig.mode, + allowed_function_names: toolConfig.allowedFunctionNames, + }, + }, + } + : {}), safety_settings: [ { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, @@ -366,7 +418,7 @@ const formatGeminiPayload = ( }; let previousRole: string | null = null; - for (const row of data) { + for (const row of messages) { const correctRole = row.role === 'assistant' ? 'model' : 'user'; // if data is already preformatted by ActionsClientGeminiChatModel if (row.parts) {
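
As a reference for reviewers, a minimal usage sketch of the new adapter, assuming an `executor` already resolved for a Gemini connector (the call shape mirrors `gemini_adapter.test.ts`; this snippet is illustrative and not part of the patch):

```ts
import { lastValueFrom, toArray } from 'rxjs';
import { MessageRole } from '../../../../common/chat_complete';
import { ToolChoiceType } from '../../../../common/chat_complete/tools';
import type { InferenceExecutor } from '../../utils/inference_executor';
import { geminiAdapter } from './gemini_adapter';

// Assumed: an executor already wired to a Gemini stack connector.
declare const executor: InferenceExecutor;

async function example() {
  const events$ = geminiAdapter.chatComplete({
    executor,
    system: 'Some system message', // forwarded as `systemInstruction`
    toolChoice: ToolChoiceType.required, // mapped to toolConfig: { mode: 'ANY' }
    messages: [{ role: MessageRole.User, content: 'question' }],
  });

  // The stream emits ChatCompletionChunk and ChatCompletionTokenCount events.
  return lastValueFrom(events$.pipe(toArray()));
}
```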
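And a sketch of the request body `formatGeminiPayload` now produces when both a system instruction and a tool config are set. The field values are taken from the connector unit tests above, except the system instruction text, which is an illustrative placeholder; the REST payload uses the snake_case keys Gemini's API expects:

```ts
// Illustrative result of:
// formatGeminiPayload({ messages, temperature: 0,
//   systemInstruction: 'Be terse', // placeholder value
//   toolConfig: { mode: 'ANY', allowedFunctionNames: ['foo', 'bar'] } })
const payload = {
  contents: [/* rows mapped to { role: 'user' | 'model', parts } */],
  generation_config: { temperature: 0, maxOutputTokens: 8192 },
  system_instruction: { role: 'user', parts: [{ text: 'Be terse' }] },
  tool_config: {
    function_calling_config: {
      mode: 'ANY',
      allowed_function_names: ['foo', 'bar'],
    },
  },
  safety_settings: [
    { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_ONLY_HIGH' },
  ],
};
```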