From 12e2481601d5b0b886f5d9330ea596c91cec2d05 Mon Sep 17 00:00:00 2001 From: Wangchong Zhou Date: Fri, 9 Aug 2024 17:57:37 +0800 Subject: [PATCH] feat(ai-proxy): allow to use mistral.ai cloud service by omitting upstream_url (#13481) AG-95 --- .../unreleased/kong/ai-proxy-mistral-ai.yml | 3 +++ kong/clustering/compat/checkers.lua | 16 ++++++++++++++++ kong/llm/drivers/anthropic.lua | 14 +++++--------- kong/llm/drivers/cohere.lua | 8 ++++---- kong/llm/drivers/mistral.lua | 13 ++++++++++++- kong/llm/drivers/openai.lua | 15 +++++---------- kong/llm/drivers/shared.lua | 1 + kong/llm/schemas/init.lua | 2 +- spec/03-plugins/38-ai-proxy/00-config_spec.lua | 6 +++++- 9 files changed, 52 insertions(+), 26 deletions(-) create mode 100644 changelog/unreleased/kong/ai-proxy-mistral-ai.yml diff --git a/changelog/unreleased/kong/ai-proxy-mistral-ai.yml b/changelog/unreleased/kong/ai-proxy-mistral-ai.yml new file mode 100644 index 000000000000..6c558ba41051 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-mistral-ai.yml @@ -0,0 +1,3 @@ +message: '**ai-proxy**: Allowed mistral provider to use mistral.ai managed service by omitting upstream_url' +type: feature +scope: Plugin diff --git a/kong/clustering/compat/checkers.lua b/kong/clustering/compat/checkers.lua index 55dcbbc2bd4a..308c6ee35175 100644 --- a/kong/clustering/compat/checkers.lua +++ b/kong/clustering/compat/checkers.lua @@ -67,6 +67,22 @@ local compatible_checkers = { has_update = true end + + if config.model.provider == "mistral" and ( + not config.model.options or + config.model.options == ngx.null or + not config.model.options.upstream_url or + config.model.options.upstream_url == ngx.null) then + + log_warn_message('configures ' .. plugin.name .. ' plugin with' .. + ' mistral provider uses fallback upstream_url for managed serivice' .. + dp_version, log_suffix) + + config.model.options = config.model.options or {} + config.model.options.upstream_url = "https://api.mistral.ai:443" + has_update = true + end + end if plugin.name == 'ai-request-transformer' then diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 7161804cd935..332f21878098 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -447,15 +447,11 @@ function _M.configure_request(conf) parsed_url = socket_url.parse(conf.model.options.upstream_url) else parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = conf.model.options - and conf.model.options.upstream_path - or ai_shared.operation_map[DRIVER_NAME][conf.route_type] - and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path - or "/" - - if not parsed_url.path then - return nil, fmt("operation %s is not supported for anthropic provider", conf.route_type) - end + parsed_url.path = (conf.model.options and + conf.model.options.upstream_path) + or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path) + or "/" end -- if the path is read from a URL capture, ensure that it is valid diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 28c8c64eaac0..d25764164086 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -459,10 +459,10 @@ function _M.configure_request(conf) parsed_url = socket_url.parse(conf.model.options.upstream_url) else parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = conf.model.options - and conf.model.options.upstream_path - or ai_shared.operation_map[DRIVER_NAME][conf.route_type] - and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + parsed_url.path = (conf.model.options and + conf.model.options.upstream_path) + or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path) or "/" end diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index 566ad903f6fc..d1d2303b6919 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -144,8 +144,19 @@ end -- returns err or nil function _M.configure_request(conf) + local parsed_url + -- mistral shared operation paths - local parsed_url = socket_url.parse(conf.model.options.upstream_url) + if (conf.model.options and conf.model.options.upstream_url) then + parsed_url = socket_url.parse(conf.model.options.upstream_url) + else + parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) + parsed_url.path = (conf.model.options and + conf.model.options.upstream_path) + or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path) + or "/" + end -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = (parsed_url.path and string_gsub(parsed_url.path, "^/*", "/")) or "/" diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 331ad3b5e7ea..52df1910586e 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -191,17 +191,12 @@ function _M.configure_request(conf) if (conf.model.options and conf.model.options.upstream_url) then parsed_url = socket_url.parse(conf.model.options.upstream_url) else - local path = conf.model.options - and conf.model.options.upstream_path - or ai_shared.operation_map[DRIVER_NAME][conf.route_type] - and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path - or "/" - if not path then - return nil, fmt("operation %s is not supported for openai provider", conf.route_type) - end - parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = path + parsed_url.path = (conf.model.options and + conf.model.options.upstream_path) + or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path) + or "/" end -- if the path is read from a URL capture, ensure that it is valid diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index b9fa994934b5..b4fa0c854246 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -79,6 +79,7 @@ _M.upstream_url_format = { gemini = "https://generativelanguage.googleapis.com", gemini_vertex = "https://%s", bedrock = "https://bedrock-runtime.%s.amazonaws.com", + mistral = "https://api.mistral.ai:443" } _M.operation_map = { diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua index 2efcb6b4108a..0fcc3a058a31 100644 --- a/kong/llm/schemas/init.lua +++ b/kong/llm/schemas/init.lua @@ -271,7 +271,7 @@ return { then_err = "must set %s for azure provider" }}, { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { one_of = { "mistral", "llama2" } }, + if_match = { one_of = { "llama2" } }, then_at_least_one_of = { "model.options.upstream_url" }, then_err = "must set %s for self-hosted providers/models" }}, diff --git a/spec/03-plugins/38-ai-proxy/00-config_spec.lua b/spec/03-plugins/38-ai-proxy/00-config_spec.lua index 516f5a2080e7..34cddb74d619 100644 --- a/spec/03-plugins/38-ai-proxy/00-config_spec.lua +++ b/spec/03-plugins/38-ai-proxy/00-config_spec.lua @@ -22,7 +22,11 @@ describe(PLUGIN_NAME .. ": (schema)", function() for i, v in ipairs(SELF_HOSTED_MODELS) do - it("requires upstream_url when using self-hosted " .. v .. " model", function() + local op = it + if v == "mistral" then -- mistral.ai now has managed service too! + op = pending + end + op("requires upstream_url when using self-hosted " .. v .. " model", function() local config = { route_type = "llm/v1/chat", auth = {