From 812825354c93e3dfc059cb740e07ed6df5503574 Mon Sep 17 00:00:00 2001
From: Wangchong Zhou
Date: Fri, 9 Aug 2024 16:23:52 +0800
Subject: [PATCH] feat(ai-proxy): allow to use mistral.ai cloud service by omitting upstream_url

---
 kong/clustering/compat/checkers.lua            | 16 ++++++++++++++++
 kong/llm/drivers/mistral.lua                   | 18 +++++++++++++++++-
 kong/llm/drivers/shared.lua                    |  1 +
 kong/llm/schemas/init.lua                      |  2 +-
 spec/03-plugins/38-ai-proxy/00-config_spec.lua |  6 +++++-
 5 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/kong/clustering/compat/checkers.lua b/kong/clustering/compat/checkers.lua
index 55dcbbc2bd4ad..308c6ee351756 100644
--- a/kong/clustering/compat/checkers.lua
+++ b/kong/clustering/compat/checkers.lua
@@ -67,6 +67,22 @@ local compatible_checkers = {
             has_update = true
           end
 
+
+          if config.model.provider == "mistral" and (
+            not config.model.options or
+            config.model.options == ngx.null or
+            not config.model.options.upstream_url or
+            config.model.options.upstream_url == ngx.null) then
+
+            log_warn_message('configures ' .. plugin.name .. ' plugin with' ..
+                             ' mistral provider uses fallback upstream_url for managed service' ..
+                             dp_version, log_suffix)
+
+            config.model.options = config.model.options or {}
+            config.model.options.upstream_url = "https://api.mistral.ai:443"
+            has_update = true
+          end
+
         end
 
         if plugin.name == 'ai-request-transformer' then
diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua
index 566ad903f6fc1..43b059ec80f09 100644
--- a/kong/llm/drivers/mistral.lua
+++ b/kong/llm/drivers/mistral.lua
@@ -144,8 +144,24 @@ end
 
 -- returns err or nil
 function _M.configure_request(conf)
+  local parsed_url
+
   -- mistral shared operation paths
-  local parsed_url = socket_url.parse(conf.model.options.upstream_url)
+  if (conf.model.options and conf.model.options.upstream_url) then
+    parsed_url = socket_url.parse(conf.model.options.upstream_url)
+  else
+    local path = conf.model.options
+                 and conf.model.options.upstream_path
+                 or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
+                 and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
+                 or "/"
+    if not path then
+      return nil, fmt("operation %s is not supported for mistral provider", conf.route_type)
+    end
+
+    parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
+    parsed_url.path = path
+  end
 
   -- if the path is read from a URL capture, ensure that it is valid
   parsed_url.path = (parsed_url.path and string_gsub(parsed_url.path, "^/*", "/")) or "/"
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index 15d9ce7e62f26..f134cf68d8fcd 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -79,6 +79,7 @@ _M.upstream_url_format = {
   gemini = "https://generativelanguage.googleapis.com",
   gemini_vertex = "https://%s",
   bedrock = "https://bedrock-runtime.%s.amazonaws.com",
+  mistral = "https://api.mistral.ai:443"
 }
 
 _M.operation_map = {
diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua
index 2efcb6b4108a6..0fcc3a058a311 100644
--- a/kong/llm/schemas/init.lua
+++ b/kong/llm/schemas/init.lua
@@ -271,7 +271,7 @@ return {
                                       then_err = "must set %s for azure provider" }},
 
     { conditional_at_least_one_of = { if_field = "model.provider",
-                                      if_match = { one_of = { "mistral", "llama2" } },
+                                      if_match = { one_of = { "llama2" } },
                                       then_at_least_one_of = { "model.options.upstream_url" },
                                       then_err = "must set %s for self-hosted providers/models" }},
 
diff --git a/spec/03-plugins/38-ai-proxy/00-config_spec.lua b/spec/03-plugins/38-ai-proxy/00-config_spec.lua
index 516f5a2080e76..34cddb74d6195 100644
--- a/spec/03-plugins/38-ai-proxy/00-config_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/00-config_spec.lua
@@ -22,7 +22,11 @@ describe(PLUGIN_NAME .. ": (schema)", function()
 
   for i, v in ipairs(SELF_HOSTED_MODELS) do
 
-    it("requires upstream_url when using self-hosted " .. v .. " model", function()
+    local op = it
+    if v == "mistral" then -- mistral.ai now has managed service too!
+      op = pending
+    end
+    op("requires upstream_url when using self-hosted " .. v .. " model", function()
       local config = {
         route_type = "llm/v1/chat",
         auth = {
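
Note (reviewer illustration, not part of the patch): a minimal sketch of how the new
fallback could be exercised from a schema spec. It assumes the spec.helpers
validate_plugin_config_schema helper and the kong.plugins.ai-proxy.schema module used
by the existing 00-config_spec.lua; the auth values, model name and mistral_format
shown here are illustrative placeholders, not requirements introduced by this change.

  -- with this patch a mistral config may omit model.options.upstream_url;
  -- the driver then falls back to ai_shared.upstream_url_format["mistral"],
  -- i.e. https://api.mistral.ai:443
  local schema_def = require("kong.plugins.ai-proxy.schema")
  local validate   = require("spec.helpers").validate_plugin_config_schema

  local config = {
    route_type = "llm/v1/chat",
    auth = {
      header_name  = "Authorization",
      header_value = "Bearer <mistral-api-key>",   -- placeholder credential
    },
    model = {
      name     = "mistral-tiny",                   -- illustrative model name
      provider = "mistral",
      options  = {
        mistral_format = "openai",
        -- no upstream_url: the managed mistral.ai endpoint should be used
      },
    },
  }

  local entity, err = validate(config, schema_def)
  assert(entity ~= nil,
         "expected mistral config without upstream_url to validate, got: " .. tostring(err))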