diff --git a/oas_docs/output/kibana.serverless.staging.yaml b/oas_docs/output/kibana.serverless.staging.yaml index a7ab200940aef..b2ed77b3d5f50 100644 --- a/oas_docs/output/kibana.serverless.staging.yaml +++ b/oas_docs/output/kibana.serverless.staging.yaml @@ -22906,6 +22906,7 @@ components: enum: - OpenAI - Azure OpenAI + - Other type: string Security_AI_Assistant_API_Reader: additionalProperties: true diff --git a/oas_docs/output/kibana.serverless.yaml b/oas_docs/output/kibana.serverless.yaml index a7ab200940aef..b2ed77b3d5f50 100644 --- a/oas_docs/output/kibana.serverless.yaml +++ b/oas_docs/output/kibana.serverless.yaml @@ -22906,6 +22906,7 @@ components: enum: - OpenAI - Azure OpenAI + - Other type: string Security_AI_Assistant_API_Reader: additionalProperties: true diff --git a/oas_docs/output/kibana.staging.yaml b/oas_docs/output/kibana.staging.yaml index e4ba9c48a3b46..d59c0d1b040d2 100644 --- a/oas_docs/output/kibana.staging.yaml +++ b/oas_docs/output/kibana.staging.yaml @@ -30731,6 +30731,7 @@ components: enum: - OpenAI - Azure OpenAI + - Other type: string Security_AI_Assistant_API_Reader: additionalProperties: true diff --git a/oas_docs/output/kibana.yaml b/oas_docs/output/kibana.yaml index e4ba9c48a3b46..d59c0d1b040d2 100644 --- a/oas_docs/output/kibana.yaml +++ b/oas_docs/output/kibana.yaml @@ -30731,6 +30731,7 @@ components: enum: - OpenAI - Azure OpenAI + - Other type: string Security_AI_Assistant_API_Reader: additionalProperties: true diff --git a/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/ess/elastic_assistant_api_2023_10_31.bundled.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/ess/elastic_assistant_api_2023_10_31.bundled.schema.yaml index bc7674a3f730d..e946c357ebffb 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/ess/elastic_assistant_api_2023_10_31.bundled.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/ess/elastic_assistant_api_2023_10_31.bundled.schema.yaml @@ -1194,6 +1194,7 @@ components: enum: - OpenAI - Azure OpenAI + - Other type: string Reader: additionalProperties: true diff --git a/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/serverless/elastic_assistant_api_2023_10_31.bundled.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/serverless/elastic_assistant_api_2023_10_31.bundled.schema.yaml index 7d5487abb2211..e8b3f0f3dc7a5 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/serverless/elastic_assistant_api_2023_10_31.bundled.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/docs/openapi/serverless/elastic_assistant_api_2023_10_31.bundled.schema.yaml @@ -1194,6 +1194,7 @@ components: enum: - OpenAI - Azure OpenAI + - Other type: string Reader: additionalProperties: true diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.gen.ts index 1ba701474b1f8..1dad26e1628db 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.gen.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.gen.ts @@ -46,7 +46,7 @@ export const Reader = z.object({}).catchall(z.unknown()); * Provider */ export type Provider = z.infer; -export const Provider = z.enum(['OpenAI', 'Azure OpenAI']); +export const Provider = z.enum(['OpenAI', 'Azure OpenAI', 'Other']); export type ProviderEnum = typeof Provider.enum; export const ProviderEnum = Provider.enum; diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.schema.yaml index f6a8189182474..20423236f7423 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/conversations/common_attributes.schema.yaml @@ -34,6 +34,7 @@ components: enum: - OpenAI - Azure OpenAI + - Other MessageRole: type: string diff --git a/x-pack/packages/kbn-elastic-assistant/impl/connectorland/helpers.tsx b/x-pack/packages/kbn-elastic-assistant/impl/connectorland/helpers.tsx index 2bbc74af5a45a..99550f1cafe75 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/connectorland/helpers.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/connectorland/helpers.tsx @@ -18,6 +18,7 @@ import { PRECONFIGURED_CONNECTOR } from './translations'; enum OpenAiProviderType { OpenAi = 'OpenAI', AzureAi = 'Azure OpenAI', + Other = 'Other', } interface GenAiConfig { diff --git a/x-pack/plugins/actions/server/usage/actions_telemetry.test.ts b/x-pack/plugins/actions/server/usage/actions_telemetry.test.ts index b4f6d785584a4..26c37b36566e4 100644 --- a/x-pack/plugins/actions/server/usage/actions_telemetry.test.ts +++ b/x-pack/plugins/actions/server/usage/actions_telemetry.test.ts @@ -1025,15 +1025,17 @@ describe('actions telemetry', () => { '.d3security': 2, '.gen-ai__Azure OpenAI': 3, '.gen-ai__OpenAI': 1, + '.gen-ai__Other': 1, }; const { countByType, countGenAiProviderTypes } = getCounts(aggs); expect(countByType).toEqual({ __d3security: 2, - '__gen-ai': 4, + '__gen-ai': 5, }); expect(countGenAiProviderTypes).toEqual({ 'Azure OpenAI': 3, OpenAI: 1, + Other: 1, }); }); }); diff --git a/x-pack/plugins/actions/server/usage/types.ts b/x-pack/plugins/actions/server/usage/types.ts index d9fe796c2b4e0..6bdfe316c76e2 100644 --- a/x-pack/plugins/actions/server/usage/types.ts +++ b/x-pack/plugins/actions/server/usage/types.ts @@ -51,6 +51,7 @@ export const byGenAiProviderTypeSchema: MakeSchemaFrom['count_by_t // Known providers: ['Azure OpenAI']: { type: 'long' }, ['OpenAI']: { type: 'long' }, + ['Other']: { type: 'long' }, }; export const byServiceProviderTypeSchema: MakeSchemaFrom['count_active_email_connectors_by_service_type'] = diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts index 3ca1b8edb5036..9a77e645686dd 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.test.ts @@ -65,5 +65,17 @@ describe('Utils', () => { const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(true); }); + + it('should return `true` when apiProvider of OpenAiProviderType.Other is specified', async () => { + const connector = { + actionTypeId: '.gen-ai', + config: { + apiUrl: OPENAI_CHAT_URL, + apiProvider: OpenAiProviderType.Other, + }, + } as unknown as Connector; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(true); + }); }); }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/plugins/elastic_assistant/server/routes/utils.ts index ea05fc814ec69..0fb51c7364809 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/utils.ts @@ -203,19 +203,25 @@ export const isOpenSourceModel = (connector?: Connector): boolean => { } const llmType = getLlmType(connector.actionTypeId); - const connectorApiUrl = connector.config?.apiUrl - ? (connector.config.apiUrl as string) - : undefined; + const isOpenAiType = llmType === 'openai'; + + if (!isOpenAiType) { + return false; + } const connectorApiProvider = connector.config?.apiProvider ? (connector.config?.apiProvider as OpenAiProviderType) : undefined; + if (connectorApiProvider === OpenAiProviderType.Other) { + return true; + } - const isOpenAiType = llmType === 'openai'; - const isOpenAI = - isOpenAiType && - (!connectorApiUrl || - connectorApiUrl === OPENAI_CHAT_URL || - connectorApiProvider === OpenAiProviderType.AzureAi); + const connectorApiUrl = connector.config?.apiUrl + ? (connector.config.apiUrl as string) + : undefined; - return isOpenAiType && !isOpenAI; + return ( + !!connectorApiUrl && + connectorApiUrl !== OPENAI_CHAT_URL && + connectorApiProvider !== OpenAiProviderType.AzureAi + ); }; diff --git a/x-pack/plugins/search_playground/common/types.ts b/x-pack/plugins/search_playground/common/types.ts index c239858b5b459..e2a0ae34c2ef3 100644 --- a/x-pack/plugins/search_playground/common/types.ts +++ b/x-pack/plugins/search_playground/common/types.ts @@ -57,6 +57,7 @@ export enum APIRoutes { export enum LLMs { openai = 'openai', openai_azure = 'openai_azure', + openai_other = 'openai_other', bedrock = 'bedrock', gemini = 'gemini', } diff --git a/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts b/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts index d661084306583..ebce3883a471b 100644 --- a/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts +++ b/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts @@ -15,9 +15,10 @@ jest.mock('./use_load_connectors', () => ({ })); const mockConnectors = [ - { id: 'connectorId1', title: 'OpenAI Connector', type: LLMs.openai }, - { id: 'connectorId2', title: 'OpenAI Azure Connector', type: LLMs.openai_azure }, - { id: 'connectorId2', title: 'Bedrock Connector', type: LLMs.bedrock }, + { id: 'connectorId1', name: 'OpenAI Connector', type: LLMs.openai }, + { id: 'connectorId2', name: 'OpenAI Azure Connector', type: LLMs.openai_azure }, + { id: 'connectorId2', name: 'Bedrock Connector', type: LLMs.bedrock }, + { id: 'connectorId3', name: 'OpenAI OSS Model Connector', type: LLMs.openai_other }, ]; const mockUseLoadConnectors = (data: any) => { (useLoadConnectors as jest.Mock).mockReturnValue({ data }); @@ -36,7 +37,7 @@ describe('useLLMsModels Hook', () => { expect(result.current).toEqual([ { connectorId: 'connectorId1', - connectorName: undefined, + connectorName: 'OpenAI Connector', connectorType: LLMs.openai, disabled: false, icon: expect.any(Function), @@ -48,7 +49,7 @@ describe('useLLMsModels Hook', () => { }, { connectorId: 'connectorId1', - connectorName: undefined, + connectorName: 'OpenAI Connector', connectorType: LLMs.openai, disabled: false, icon: expect.any(Function), @@ -60,7 +61,7 @@ describe('useLLMsModels Hook', () => { }, { connectorId: 'connectorId1', - connectorName: undefined, + connectorName: 'OpenAI Connector', connectorType: LLMs.openai, disabled: false, icon: expect.any(Function), @@ -72,19 +73,19 @@ describe('useLLMsModels Hook', () => { }, { connectorId: 'connectorId2', - connectorName: undefined, + connectorName: 'OpenAI Azure Connector', connectorType: LLMs.openai_azure, disabled: false, icon: expect.any(Function), - id: 'connectorId2Azure OpenAI ', - name: 'Azure OpenAI ', + id: 'connectorId2OpenAI Azure Connector (Azure OpenAI)', + name: 'OpenAI Azure Connector (Azure OpenAI)', showConnectorName: false, value: undefined, promptTokenLimit: undefined, }, { connectorId: 'connectorId2', - connectorName: undefined, + connectorName: 'Bedrock Connector', connectorType: LLMs.bedrock, disabled: false, icon: expect.any(Function), @@ -96,7 +97,7 @@ describe('useLLMsModels Hook', () => { }, { connectorId: 'connectorId2', - connectorName: undefined, + connectorName: 'Bedrock Connector', connectorType: LLMs.bedrock, disabled: false, icon: expect.any(Function), @@ -106,6 +107,18 @@ describe('useLLMsModels Hook', () => { value: 'anthropic.claude-3-5-sonnet-20240620-v1:0', promptTokenLimit: 200000, }, + { + connectorId: 'connectorId3', + connectorName: 'OpenAI OSS Model Connector', + connectorType: LLMs.openai_other, + disabled: false, + icon: expect.any(Function), + id: 'connectorId3OpenAI OSS Model Connector (OpenAI Compatible Service)', + name: 'OpenAI OSS Model Connector (OpenAI Compatible Service)', + showConnectorName: false, + value: undefined, + promptTokenLimit: undefined, + }, ]); }); diff --git a/x-pack/plugins/search_playground/public/hooks/use_llms_models.ts b/x-pack/plugins/search_playground/public/hooks/use_llms_models.ts index 7a9b01e085a6d..3d5cee7719f10 100644 --- a/x-pack/plugins/search_playground/public/hooks/use_llms_models.ts +++ b/x-pack/plugins/search_playground/public/hooks/use_llms_models.ts @@ -34,11 +34,22 @@ const mapLlmToModels: Record< }, [LLMs.openai_azure]: { icon: OpenAILogo, - getModels: (connectorName, includeName) => [ + getModels: (connectorName) => [ { label: i18n.translate('xpack.searchPlayground.openAIAzureModel', { - defaultMessage: 'Azure OpenAI {name}', - values: { name: includeName ? `(${connectorName})` : '' }, + defaultMessage: '{name} (Azure OpenAI)', + values: { name: connectorName }, + }), + }, + ], + }, + [LLMs.openai_other]: { + icon: OpenAILogo, + getModels: (connectorName) => [ + { + label: i18n.translate('xpack.searchPlayground.otherOpenAIModel', { + defaultMessage: '{name} (OpenAI Compatible Service)', + values: { name: connectorName }, }), }, ], diff --git a/x-pack/plugins/search_playground/public/hooks/use_load_connectors.test.ts b/x-pack/plugins/search_playground/public/hooks/use_load_connectors.test.ts index 3a68d91fd0246..eb2f36eb62e5f 100644 --- a/x-pack/plugins/search_playground/public/hooks/use_load_connectors.test.ts +++ b/x-pack/plugins/search_playground/public/hooks/use_load_connectors.test.ts @@ -71,6 +71,12 @@ describe('useLoadConnectors', () => { actionTypeId: '.bedrock', isMissingSecrets: false, }, + { + id: '5', + actionTypeId: '.gen-ai', + isMissingSecrets: false, + config: { apiProvider: OpenAiProviderType.Other }, + }, ]; mockedLoadConnectors.mockResolvedValue(connectors); @@ -106,6 +112,16 @@ describe('useLoadConnectors', () => { title: 'Bedrock', type: 'bedrock', }, + { + actionTypeId: '.gen-ai', + config: { + apiProvider: 'Other', + }, + id: '5', + isMissingSecrets: false, + title: 'OpenAI Other', + type: 'openai_other', + }, ]); }); }); diff --git a/x-pack/plugins/search_playground/public/hooks/use_load_connectors.ts b/x-pack/plugins/search_playground/public/hooks/use_load_connectors.ts index 94bb2da37b1ed..3d2a3e8c90b86 100644 --- a/x-pack/plugins/search_playground/public/hooks/use_load_connectors.ts +++ b/x-pack/plugins/search_playground/public/hooks/use_load_connectors.ts @@ -63,6 +63,20 @@ const connectorTypeToLLM: Array<{ type: LLMs.openai, }), }, + { + actionId: OPENAI_CONNECTOR_ID, + actionProvider: OpenAiProviderType.Other, + match: (connector) => + connector.actionTypeId === OPENAI_CONNECTOR_ID && + (connector as OpenAIConnector)?.config?.apiProvider === OpenAiProviderType.Other, + transform: (connector) => ({ + ...connector, + title: i18n.translate('xpack.searchPlayground.openAIOtherConnectorTitle', { + defaultMessage: 'OpenAI Other', + }), + type: LLMs.openai_other, + }), + }, { actionId: BEDROCK_CONNECTOR_ID, match: (connector) => connector.actionTypeId === BEDROCK_CONNECTOR_ID, diff --git a/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts b/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts index cbc696a50085e..614d00dc16e66 100644 --- a/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts +++ b/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts @@ -152,4 +152,41 @@ describe('getChatParams', () => { ) ).rejects.toThrow('Invalid connector id'); }); + + it('returns the correct chat model and uses the default model when not specified in the params', async () => { + mockActionsClient.get.mockResolvedValue({ + id: '2', + actionTypeId: OPENAI_CONNECTOR_ID, + config: { defaultModel: 'local' }, + }); + + const result = await getChatParams( + { + connectorId: '2', + prompt: 'How does it work?', + citations: false, + }, + { actions, request, logger } + ); + + expect(Prompt).toHaveBeenCalledWith('How does it work?', { + citations: false, + context: true, + type: 'openai', + }); + expect(QuestionRewritePrompt).toHaveBeenCalledWith({ + type: 'openai', + }); + expect(ActionsClientChatOpenAI).toHaveBeenCalledWith({ + logger: expect.anything(), + model: 'local', + connectorId: '2', + actionsClient: expect.anything(), + signal: expect.anything(), + traceId: 'test-uuid', + temperature: 0.2, + maxRetries: 0, + }); + expect(result.chatPrompt).toContain('How does it work?'); + }); }); diff --git a/x-pack/plugins/search_playground/server/lib/get_chat_params.ts b/x-pack/plugins/search_playground/server/lib/get_chat_params.ts index d2c4bb1afaa9d..34f902e0d1ca2 100644 --- a/x-pack/plugins/search_playground/server/lib/get_chat_params.ts +++ b/x-pack/plugins/search_playground/server/lib/get_chat_params.ts @@ -57,7 +57,7 @@ export const getChatParams = async ( actionsClient, logger, connectorId, - model, + model: model || connector?.config?.defaultModel, traceId: uuidv4(), signal: abortSignal, temperature: getDefaultArguments().temperature, diff --git a/x-pack/plugins/security_solution/public/attack_discovery/use_attack_discovery/helpers.ts b/x-pack/plugins/security_solution/public/attack_discovery/use_attack_discovery/helpers.ts index f800651985217..97eb132bdaaeb 100644 --- a/x-pack/plugins/security_solution/public/attack_discovery/use_attack_discovery/helpers.ts +++ b/x-pack/plugins/security_solution/public/attack_discovery/use_attack_discovery/helpers.ts @@ -18,6 +18,7 @@ import { isEmpty } from 'lodash/fp'; enum OpenAiProviderType { OpenAi = 'OpenAI', AzureAi = 'Azure OpenAI', + Other = 'Other', } interface GenAiConfig { diff --git a/x-pack/plugins/stack_connectors/common/openai/constants.ts b/x-pack/plugins/stack_connectors/common/openai/constants.ts index c57720d9847af..3d629360d03f3 100644 --- a/x-pack/plugins/stack_connectors/common/openai/constants.ts +++ b/x-pack/plugins/stack_connectors/common/openai/constants.ts @@ -27,6 +27,7 @@ export enum SUB_ACTION { export enum OpenAiProviderType { OpenAi = 'OpenAI', AzureAi = 'Azure OpenAI', + Other = 'Other', } export const DEFAULT_TIMEOUT_MS = 120000; diff --git a/x-pack/plugins/stack_connectors/common/openai/schema.ts b/x-pack/plugins/stack_connectors/common/openai/schema.ts index f62ee1f35174c..8a08da157b163 100644 --- a/x-pack/plugins/stack_connectors/common/openai/schema.ts +++ b/x-pack/plugins/stack_connectors/common/openai/schema.ts @@ -21,6 +21,12 @@ export const ConfigSchema = schema.oneOf([ defaultModel: schema.string({ defaultValue: DEFAULT_OPENAI_MODEL }), headers: schema.maybe(schema.recordOf(schema.string(), schema.string())), }), + schema.object({ + apiProvider: schema.oneOf([schema.literal(OpenAiProviderType.Other)]), + apiUrl: schema.string(), + defaultModel: schema.string(), + headers: schema.maybe(schema.recordOf(schema.string(), schema.string())), + }), ]); export const SecretsSchema = schema.object({ apiKey: schema.string() }); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/lib/gen_ai/use_get_dashboard.test.ts b/x-pack/plugins/stack_connectors/public/connector_types/lib/gen_ai/use_get_dashboard.test.ts index 8ca9b97292fa3..18bcdc6232792 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/lib/gen_ai/use_get_dashboard.test.ts +++ b/x-pack/plugins/stack_connectors/public/connector_types/lib/gen_ai/use_get_dashboard.test.ts @@ -53,6 +53,7 @@ describe('useGetDashboard', () => { it.each([ ['Azure OpenAI', 'openai'], ['OpenAI', 'openai'], + ['Other', 'openai'], ['Bedrock', 'bedrock'], ])( 'fetches the %p dashboard and sets the dashboard URL with %p', diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.test.tsx index 03d41dd01caa9..2c8eaf8a76257 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.test.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.test.tsx @@ -50,6 +50,17 @@ const azureConnector = { apiKey: 'thats-a-nice-looking-key', }, }; +const otherOpenAiConnector = { + ...openAiConnector, + config: { + apiUrl: 'https://localhost/oss-llm', + apiProvider: OpenAiProviderType.Other, + defaultModel: 'local-model', + }, + secrets: { + apiKey: 'thats-a-nice-looking-key', + }, +}; const navigateToUrl = jest.fn(); @@ -93,6 +104,24 @@ describe('ConnectorFields renders', () => { expect(getAllByTestId('azure-ai-api-keys-doc')[0]).toBeInTheDocument(); }); + test('other open ai connector fields are rendered', async () => { + const { getAllByTestId } = render( + + {}} /> + + ); + expect(getAllByTestId('config.apiUrl-input')[0]).toBeInTheDocument(); + expect(getAllByTestId('config.apiUrl-input')[0]).toHaveValue( + otherOpenAiConnector.config.apiUrl + ); + expect(getAllByTestId('config.apiProvider-select')[0]).toBeInTheDocument(); + expect(getAllByTestId('config.apiProvider-select')[0]).toHaveValue( + otherOpenAiConnector.config.apiProvider + ); + expect(getAllByTestId('other-ai-api-doc')[0]).toBeInTheDocument(); + expect(getAllByTestId('other-ai-api-keys-doc')[0]).toBeInTheDocument(); + }); + describe('Dashboard link', () => { it('Does not render if isEdit is false and dashboardUrl is defined', async () => { const { queryByTestId } = render( diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.tsx b/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.tsx index c940ad76e3643..27cbb9a4dac08 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/connector.tsx @@ -24,6 +24,8 @@ import * as i18n from './translations'; import { azureAiConfig, azureAiSecrets, + otherOpenAiConfig, + otherOpenAiSecrets, openAiConfig, openAiSecrets, providerOptions, @@ -85,6 +87,14 @@ const ConnectorFields: React.FC = ({ readOnly, isEdi secretsFormSchema={azureAiSecrets} /> )} + {config != null && config.apiProvider === OpenAiProviderType.Other && ( + + )} {isEdit && ( + {`${i18n.OTHER_OPENAI} ${i18n.DOCUMENTATION}`} + + ), + }} + /> + ), + }, + { + id: 'defaultModel', + label: i18n.DEFAULT_MODEL_LABEL, + helpText: ( + + ), + }, +]; + export const openAiSecrets: SecretsFieldSchema[] = [ { id: 'apiKey', @@ -142,6 +177,31 @@ export const azureAiSecrets: SecretsFieldSchema[] = [ }, ]; +export const otherOpenAiSecrets: SecretsFieldSchema[] = [ + { + id: 'apiKey', + label: i18n.API_KEY_LABEL, + isPasswordField: true, + helpText: ( + + {`${i18n.OTHER_OPENAI} ${i18n.DOCUMENTATION}`} + + ), + }} + /> + ), + }, +]; + export const providerOptions = [ { value: OpenAiProviderType.OpenAi, @@ -153,4 +213,9 @@ export const providerOptions = [ text: i18n.AZURE_AI, label: i18n.AZURE_AI, }, + { + value: OpenAiProviderType.Other, + text: i18n.OTHER_OPENAI, + label: i18n.OTHER_OPENAI, + }, ]; diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx index 09a2652ad8f1d..7539cc6bf6373 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx @@ -37,7 +37,7 @@ describe('Gen AI Params Fields renders', () => { expect(getByTestId('bodyJsonEditor')).toHaveProperty('value', '{"message": "test"}'); expect(getByTestId('bodyAddVariableButton')).toBeInTheDocument(); }); - test.each([OpenAiProviderType.OpenAi, OpenAiProviderType.AzureAi])( + test.each([OpenAiProviderType.OpenAi, OpenAiProviderType.AzureAi, OpenAiProviderType.Other])( 'useEffect handles the case when subAction and subActionParams are undefined and apiProvider is %p', (apiProvider) => { const actionParams = { @@ -79,6 +79,9 @@ describe('Gen AI Params Fields renders', () => { if (apiProvider === OpenAiProviderType.AzureAi) { expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY_AZURE }, 0); } + if (apiProvider === OpenAiProviderType.Other) { + expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY }, 0); + } } ); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/translations.ts b/x-pack/plugins/stack_connectors/public/connector_types/openai/translations.ts index 4c72866c6ece4..55815faac1c8e 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/translations.ts +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/translations.ts @@ -47,6 +47,10 @@ export const AZURE_AI = i18n.translate('xpack.stackConnectors.components.genAi.a defaultMessage: 'Azure OpenAI', }); +export const OTHER_OPENAI = i18n.translate('xpack.stackConnectors.components.genAi.otherAi', { + defaultMessage: 'Other (OpenAI Compatible Service)', +}); + export const DOCUMENTATION = i18n.translate( 'xpack.stackConnectors.components.genAi.documentation', { diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/index.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/index.ts index f8a3a3d32ddb2..5bf0ba6c3a562 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/index.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/index.ts @@ -53,7 +53,11 @@ export const configValidator = (configObject: Config, validatorServices: Validat const { apiProvider } = configObject; - if (apiProvider !== OpenAiProviderType.OpenAi && apiProvider !== OpenAiProviderType.AzureAi) { + if ( + apiProvider !== OpenAiProviderType.OpenAi && + apiProvider !== OpenAiProviderType.AzureAi && + apiProvider !== OpenAiProviderType.Other + ) { throw new Error( `API Provider is not supported${ apiProvider && (apiProvider as OpenAiProviderType).length ? `: ${apiProvider}` : `` diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.test.ts new file mode 100644 index 0000000000000..33722314f5422 --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.test.ts @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { sanitizeRequest, getRequestWithStreamOption } from './other_openai_utils'; + +describe('Other (OpenAI Compatible Service) Utils', () => { + describe('sanitizeRequest', () => { + it('sets stream to false when stream is set to true in the body', () => { + const body = { + model: 'mistral', + stream: true, + messages: [ + { + role: 'user', + content: 'This is a test', + }, + ], + }; + + const sanitizedBodyString = sanitizeRequest(JSON.stringify(body)); + expect(sanitizedBodyString).toEqual( + `{\"model\":\"mistral\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}` + ); + }); + + it('sets stream to false when stream is not defined in the body', () => { + const body = { + model: 'mistral', + messages: [ + { + role: 'user', + content: 'This is a test', + }, + ], + }; + + const sanitizedBodyString = sanitizeRequest(JSON.stringify(body)); + expect(sanitizedBodyString).toEqual( + `{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":false}` + ); + }); + + it('sets stream to false when stream is set to false in the body', () => { + const body = { + model: 'mistral', + stream: false, + messages: [ + { + role: 'user', + content: 'This is a test', + }, + ], + }; + + const sanitizedBodyString = sanitizeRequest(JSON.stringify(body)); + expect(sanitizedBodyString).toEqual( + `{\"model\":\"mistral\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}` + ); + }); + + it('does nothing when body is malformed JSON', () => { + const bodyString = `{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],,}`; + + const sanitizedBodyString = sanitizeRequest(bodyString); + expect(sanitizedBodyString).toEqual(bodyString); + }); + }); + + describe('getRequestWithStreamOption', () => { + it('sets stream parameter when stream is not defined in the body', () => { + const body = { + model: 'mistral', + messages: [ + { + role: 'user', + content: 'This is a test', + }, + ], + }; + + const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), true); + expect(sanitizedBodyString).toEqual( + `{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":true}` + ); + }); + + it('overrides stream parameter if defined in body', () => { + const body = { + model: 'mistral', + stream: true, + messages: [ + { + role: 'user', + content: 'This is a test', + }, + ], + }; + + const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), false); + expect(sanitizedBodyString).toEqual( + `{\"model\":\"mistral\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}` + ); + }); + + it('does nothing when body is malformed JSON', () => { + const bodyString = `{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],,}`; + + const sanitizedBodyString = getRequestWithStreamOption(bodyString, false); + expect(sanitizedBodyString).toEqual(bodyString); + }); + }); +}); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.ts new file mode 100644 index 0000000000000..8288e0dba9ad1 --- /dev/null +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.ts @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/** + * Sanitizes the Other (OpenAI Compatible Service) request body to set stream to false + * so users cannot specify a streaming response when the framework + * is not prepared to handle streaming + * + * The stream parameter is accepted in the ChatCompletion + * API and the Completion API only + */ +export const sanitizeRequest = (body: string): string => { + return getRequestWithStreamOption(body, false); +}; + +/** + * Intercepts the Other (OpenAI Compatible Service) request body to set the stream parameter + * + * The stream parameter is accepted in the ChatCompletion + * API and the Completion API only + */ +export const getRequestWithStreamOption = (body: string, stream: boolean): string => { + try { + const jsonBody = JSON.parse(body); + if (jsonBody) { + jsonBody.stream = stream; + } + + return JSON.stringify(jsonBody); + } catch (err) { + // swallow the error + } + + return body; +}; diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts index 9dffaab3e5e00..142f3a319eeb6 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts @@ -19,8 +19,14 @@ import { sanitizeRequest as azureAiSanitizeRequest, getRequestWithStreamOption as azureAiGetRequestWithStreamOption, } from './azure_openai_utils'; +import { + sanitizeRequest as otherOpenAiSanitizeRequest, + getRequestWithStreamOption as otherOpenAiGetRequestWithStreamOption, +} from './other_openai_utils'; + jest.mock('./openai_utils'); jest.mock('./azure_openai_utils'); +jest.mock('./other_openai_utils'); describe('Utils', () => { const azureAiUrl = @@ -38,6 +44,7 @@ describe('Utils', () => { describe('sanitizeRequest', () => { const mockOpenAiSanitizeRequest = openAiSanitizeRequest as jest.Mock; const mockAzureAiSanitizeRequest = azureAiSanitizeRequest as jest.Mock; + const mockOtherOpenAiSanitizeRequest = otherOpenAiSanitizeRequest as jest.Mock; beforeEach(() => { jest.clearAllMocks(); }); @@ -50,24 +57,36 @@ describe('Utils', () => { DEFAULT_OPENAI_MODEL ); expect(mockAzureAiSanitizeRequest).not.toHaveBeenCalled(); + expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled(); + }); + + it('calls other_openai_utils sanitizeRequest when provider is Other OpenAi', () => { + sanitizeRequest(OpenAiProviderType.Other, OPENAI_CHAT_URL, bodyString, DEFAULT_OPENAI_MODEL); + expect(mockOtherOpenAiSanitizeRequest).toHaveBeenCalledWith(bodyString); + expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled(); + expect(mockAzureAiSanitizeRequest).not.toHaveBeenCalled(); }); it('calls azure_openai_utils sanitizeRequest when provider is AzureAi', () => { sanitizeRequest(OpenAiProviderType.AzureAi, azureAiUrl, bodyString); expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString); expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled(); + expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled(); }); it('does not call any helper fns when provider is unrecognized', () => { sanitizeRequest('foo', OPENAI_CHAT_URL, bodyString); expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled(); expect(mockAzureAiSanitizeRequest).not.toHaveBeenCalled(); + expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled(); }); }); describe('getRequestWithStreamOption', () => { const mockOpenAiGetRequestWithStreamOption = openAiGetRequestWithStreamOption as jest.Mock; const mockAzureAiGetRequestWithStreamOption = azureAiGetRequestWithStreamOption as jest.Mock; + const mockOtherOpenAiGetRequestWithStreamOption = + otherOpenAiGetRequestWithStreamOption as jest.Mock; beforeEach(() => { jest.clearAllMocks(); }); @@ -88,6 +107,15 @@ describe('Utils', () => { DEFAULT_OPENAI_MODEL ); expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled(); + expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); + }); + + it('calls other_openai_utils getRequestWithStreamOption when provider is Other OpenAi', () => { + getRequestWithStreamOption(OpenAiProviderType.Other, OPENAI_CHAT_URL, bodyString, true); + + expect(mockOtherOpenAiGetRequestWithStreamOption).toHaveBeenCalledWith(bodyString, true); + expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); + expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled(); }); it('calls azure_openai_utils getRequestWithStreamOption when provider is AzureAi', () => { @@ -99,6 +127,7 @@ describe('Utils', () => { true ); expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); + expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); }); it('does not call any helper fns when provider is unrecognized', () => { @@ -110,6 +139,7 @@ describe('Utils', () => { ); expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled(); + expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); }); }); @@ -127,6 +157,19 @@ describe('Utils', () => { }); }); + it('returns correct axios options when provider is other openai and stream is false', () => { + expect(getAxiosOptions(OpenAiProviderType.Other, 'api-abc', false)).toEqual({ + headers: { Authorization: `Bearer api-abc`, ['content-type']: 'application/json' }, + }); + }); + + it('returns correct axios options when provider is other openai and stream is true', () => { + expect(getAxiosOptions(OpenAiProviderType.Other, 'api-abc', true)).toEqual({ + headers: { Authorization: `Bearer api-abc`, ['content-type']: 'application/json' }, + responseType: 'stream', + }); + }); + it('returns correct axios options when provider is azure openai and stream is false', () => { expect(getAxiosOptions(OpenAiProviderType.AzureAi, 'api-abc', false)).toEqual({ headers: { ['api-key']: `api-abc`, ['content-type']: 'application/json' }, diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts index 811dfd4ce63b4..3028433656503 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts @@ -16,6 +16,10 @@ import { sanitizeRequest as azureAiSanitizeRequest, getRequestWithStreamOption as azureAiGetRequestWithStreamOption, } from './azure_openai_utils'; +import { + sanitizeRequest as otherOpenAiSanitizeRequest, + getRequestWithStreamOption as otherOpenAiGetRequestWithStreamOption, +} from './other_openai_utils'; export const sanitizeRequest = ( provider: string, @@ -28,6 +32,8 @@ export const sanitizeRequest = ( return openAiSanitizeRequest(url, body, defaultModel!); case OpenAiProviderType.AzureAi: return azureAiSanitizeRequest(url, body); + case OpenAiProviderType.Other: + return otherOpenAiSanitizeRequest(body); default: return body; } @@ -42,7 +48,7 @@ export function getRequestWithStreamOption( ): string; export function getRequestWithStreamOption( - provider: OpenAiProviderType.AzureAi, + provider: OpenAiProviderType.AzureAi | OpenAiProviderType.Other, url: string, body: string, stream: boolean @@ -68,6 +74,8 @@ export function getRequestWithStreamOption( return openAiGetRequestWithStreamOption(url, body, stream, defaultModel!); case OpenAiProviderType.AzureAi: return azureAiGetRequestWithStreamOption(url, body, stream); + case OpenAiProviderType.Other: + return otherOpenAiGetRequestWithStreamOption(body, stream); default: return body; } @@ -81,6 +89,7 @@ export const getAxiosOptions = ( const responseType = stream ? { responseType: 'stream' as ResponseType } : {}; switch (provider) { case OpenAiProviderType.OpenAi: + case OpenAiProviderType.Other: return { headers: { Authorization: `Bearer ${apiKey}`, ['content-type']: 'application/json' }, ...responseType, diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts index 87dacaf4e6f17..1362b7610e2cd 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.test.ts @@ -20,6 +20,9 @@ import { RunActionResponseSchema, StreamingResponseSchema } from '../../../commo import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard'; import { PassThrough, Transform } from 'stream'; import { ConnectorUsageCollector } from '@kbn/actions-plugin/server/types'; + +const DEFAULT_OTHER_OPENAI_MODEL = 'local-model'; + jest.mock('../lib/gen_ai/create_gen_ai_dashboard'); const mockTee = jest.fn(); @@ -713,6 +716,431 @@ describe('OpenAIConnector', () => { }); }); + describe('Other OpenAI', () => { + const connector = new OpenAIConnector({ + configurationUtilities: actionsConfigMock.create(), + connector: { id: '1', type: OPENAI_CONNECTOR_ID }, + config: { + apiUrl: 'http://localhost:1234/v1/chat/completions', + apiProvider: OpenAiProviderType.Other, + defaultModel: DEFAULT_OTHER_OPENAI_MODEL, + headers: { + 'X-My-Custom-Header': 'foo', + Authorization: 'override', + }, + }, + secrets: { apiKey: '123' }, + logger, + services: actionsMock.createServices(), + }); + + const sampleOpenAiBody = { + model: DEFAULT_OTHER_OPENAI_MODEL, + messages: [ + { + role: 'user', + content: 'Hello world', + }, + ], + }; + + beforeEach(() => { + // @ts-ignore + connector.request = mockRequest; + jest.clearAllMocks(); + }); + + describe('runApi', () => { + it('the Other OpenAI API call is successful with correct parameters', async () => { + const response = await connector.runApi( + { body: JSON.stringify(sampleOpenAiBody) }, + connectorUsageCollector + ); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith( + { + ...mockDefaults, + url: 'http://localhost:1234/v1/chat/completions', + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: false, + model: DEFAULT_OTHER_OPENAI_MODEL, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + }, + connectorUsageCollector + ); + expect(response).toEqual(mockResponse.data); + }); + + it('overrides stream parameter if set in the body', async () => { + const body = { + model: 'llama-3.1', + messages: [ + { + role: 'user', + content: 'Hello world', + }, + ], + }; + const response = await connector.runApi( + { + body: JSON.stringify({ + ...body, + stream: true, + }), + }, + connectorUsageCollector + ); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith( + { + ...mockDefaults, + url: 'http://localhost:1234/v1/chat/completions', + data: JSON.stringify({ + ...body, + stream: false, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + }, + connectorUsageCollector + ); + expect(response).toEqual(mockResponse.data); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + connector.request = mockError; + + await expect( + connector.runApi({ body: JSON.stringify(sampleOpenAiBody) }, connectorUsageCollector) + ).rejects.toThrow('API Error'); + }); + }); + + describe('streamApi', () => { + it('the Other OpenAI API call is successful with correct parameters when stream = false', async () => { + const response = await connector.streamApi( + { + body: JSON.stringify(sampleOpenAiBody), + stream: false, + }, + connectorUsageCollector + ); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith( + { + url: 'http://localhost:1234/v1/chat/completions', + method: 'post', + responseSchema: RunActionResponseSchema, + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: false, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + }, + connectorUsageCollector + ); + expect(response).toEqual(mockResponse.data); + }); + + it('the Other OpenAI API call is successful with correct parameters when stream = true', async () => { + const response = await connector.streamApi( + { + body: JSON.stringify(sampleOpenAiBody), + stream: true, + }, + connectorUsageCollector + ); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith( + { + responseType: 'stream', + url: 'http://localhost:1234/v1/chat/completions', + method: 'post', + responseSchema: StreamingResponseSchema, + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: true, + model: DEFAULT_OTHER_OPENAI_MODEL, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + }, + connectorUsageCollector + ); + expect(response).toEqual({ + headers: { 'Content-Type': 'dont-compress-this' }, + ...mockResponse.data, + }); + }); + + it('overrides stream parameter if set in the body with explicit stream parameter', async () => { + const body = { + model: 'llama-3.1', + messages: [ + { + role: 'user', + content: 'Hello world', + }, + ], + }; + const response = await connector.streamApi( + { + body: JSON.stringify({ + ...body, + stream: false, + }), + stream: true, + }, + connectorUsageCollector + ); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith( + { + responseType: 'stream', + url: 'http://localhost:1234/v1/chat/completions', + method: 'post', + responseSchema: StreamingResponseSchema, + data: JSON.stringify({ + ...body, + stream: true, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + }, + connectorUsageCollector + ); + expect(response).toEqual({ + headers: { 'Content-Type': 'dont-compress-this' }, + ...mockResponse.data, + }); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + connector.request = mockError; + + await expect( + connector.streamApi( + { body: JSON.stringify(sampleOpenAiBody), stream: true }, + connectorUsageCollector + ) + ).rejects.toThrow('API Error'); + }); + }); + + describe('invokeStream', () => { + const mockStream = ( + dataToStream: string[] = [ + 'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}', + ] + ) => { + const streamMock = createStreamMock(); + dataToStream.forEach((chunk) => { + streamMock.write(chunk); + }); + streamMock.complete(); + mockRequest = jest.fn().mockResolvedValue({ ...mockResponse, data: streamMock.transform }); + return mockRequest; + }; + beforeEach(() => { + // @ts-ignore + connector.request = mockStream(); + }); + + it('the API call is successful with correct request parameters', async () => { + await connector.invokeStream(sampleOpenAiBody, connectorUsageCollector); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith( + { + url: 'http://localhost:1234/v1/chat/completions', + method: 'post', + responseSchema: StreamingResponseSchema, + responseType: 'stream', + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: true, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + }, + connectorUsageCollector + ); + }); + + it('signal is properly passed to streamApi', async () => { + const signal = jest.fn(); + await connector.invokeStream({ ...sampleOpenAiBody, signal }, connectorUsageCollector); + + expect(mockRequest).toHaveBeenCalledWith( + { + url: 'http://localhost:1234/v1/chat/completions', + method: 'post', + responseSchema: StreamingResponseSchema, + responseType: 'stream', + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: true, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + signal, + }, + connectorUsageCollector + ); + }); + + it('timeout is properly passed to streamApi', async () => { + const timeout = 180000; + await connector.invokeStream({ ...sampleOpenAiBody, timeout }, connectorUsageCollector); + + expect(mockRequest).toHaveBeenCalledWith( + { + url: 'http://localhost:1234/v1/chat/completions', + method: 'post', + responseSchema: StreamingResponseSchema, + responseType: 'stream', + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: true, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + timeout, + }, + connectorUsageCollector + ); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + connector.request = mockError; + + await expect( + connector.invokeStream(sampleOpenAiBody, connectorUsageCollector) + ).rejects.toThrow('API Error'); + }); + + it('responds with a readable stream', async () => { + // @ts-ignore + connector.request = mockStream(); + const response = await connector.invokeStream(sampleOpenAiBody, connectorUsageCollector); + expect(response instanceof PassThrough).toEqual(true); + }); + }); + + describe('invokeAI', () => { + it('the API call is successful with correct parameters', async () => { + const response = await connector.invokeAI(sampleOpenAiBody, connectorUsageCollector); + expect(mockRequest).toBeCalledTimes(1); + expect(mockRequest).toHaveBeenCalledWith( + { + ...mockDefaults, + url: 'http://localhost:1234/v1/chat/completions', + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: false, + model: DEFAULT_OTHER_OPENAI_MODEL, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + }, + connectorUsageCollector + ); + expect(response.message).toEqual(mockResponseString); + expect(response.usage.total_tokens).toEqual(9); + }); + + it('signal is properly passed to runApi', async () => { + const signal = jest.fn(); + await connector.invokeAI({ ...sampleOpenAiBody, signal }, connectorUsageCollector); + + expect(mockRequest).toHaveBeenCalledWith( + { + ...mockDefaults, + url: 'http://localhost:1234/v1/chat/completions', + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: false, + model: DEFAULT_OTHER_OPENAI_MODEL, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + signal, + }, + connectorUsageCollector + ); + }); + + it('timeout is properly passed to runApi', async () => { + const timeout = 180000; + await connector.invokeAI({ ...sampleOpenAiBody, timeout }, connectorUsageCollector); + + expect(mockRequest).toHaveBeenCalledWith( + { + ...mockDefaults, + url: 'http://localhost:1234/v1/chat/completions', + data: JSON.stringify({ + ...sampleOpenAiBody, + stream: false, + model: DEFAULT_OTHER_OPENAI_MODEL, + }), + headers: { + Authorization: 'Bearer 123', + 'X-My-Custom-Header': 'foo', + 'content-type': 'application/json', + }, + timeout, + }, + connectorUsageCollector + ); + }); + + it('errors during API calls are properly handled', async () => { + // @ts-ignore + connector.request = mockError; + + await expect(connector.invokeAI(sampleOpenAiBody, connectorUsageCollector)).rejects.toThrow( + 'API Error' + ); + }); + }); + }); + describe('AzureAI', () => { const connector = new OpenAIConnector({ configurationUtilities: actionsConfigMock.create(), diff --git a/x-pack/plugins/telemetry_collection_xpack/schema/xpack_plugins.json b/x-pack/plugins/telemetry_collection_xpack/schema/xpack_plugins.json index 0de2cbd77db7b..0e5d4156d9760 100644 --- a/x-pack/plugins/telemetry_collection_xpack/schema/xpack_plugins.json +++ b/x-pack/plugins/telemetry_collection_xpack/schema/xpack_plugins.json @@ -73,6 +73,9 @@ }, "[OpenAI]": { "type": "long" + }, + "[Other]": { + "type": "long" } } }, diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts index 05dfc61dd59e3..8a47b6a882456 100644 --- a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts @@ -147,7 +147,7 @@ export default function genAiTest({ getService }: FtrProviderContext) { statusCode: 400, error: 'Bad Request', message: - 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected at least one defined value but got [undefined]\n- [1.apiProvider]: expected at least one defined value but got [undefined]', + 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected at least one defined value but got [undefined]\n- [1.apiProvider]: expected at least one defined value but got [undefined]\n- [2.apiProvider]: expected at least one defined value but got [undefined]', }); }); }); @@ -168,7 +168,7 @@ export default function genAiTest({ getService }: FtrProviderContext) { statusCode: 400, error: 'Bad Request', message: - 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected value to equal [Azure OpenAI]\n- [1.apiUrl]: expected value of type [string] but got [undefined]', + 'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected value to equal [Azure OpenAI]\n- [1.apiUrl]: expected value of type [string] but got [undefined]\n- [2.apiProvider]: expected value to equal [Other]', }); }); });