Skip to content

Commit

Permalink
feat(ai-proxy): allow to use mistral.ai cloud service by omitting ups…
Browse files Browse the repository at this point in the history
…tream_url (#13481)

AG-95
  • Loading branch information
fffonion authored and ProBrian committed Aug 13, 2024
1 parent 9be2736 commit 7636e1b
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 26 deletions.
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/ai-proxy-mistral-ai.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: '**ai-proxy**: Allowed mistral provider to use mistral.ai managed service by omitting upstream_url'
type: feature
scope: Plugin
16 changes: 16 additions & 0 deletions kong/clustering/compat/checkers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ local compatible_checkers = {

has_update = true
end

if config.model.provider == "mistral" and (
not config.model.options or
config.model.options == ngx.null or
not config.model.options.upstream_url or
config.model.options.upstream_url == ngx.null) then

log_warn_message('configures ' .. plugin.name .. ' plugin with' ..
' mistral provider uses fallback upstream_url for managed serivice' ..
dp_version, log_suffix)

config.model.options = config.model.options or {}
config.model.options.upstream_url = "https://api.mistral.ai:443"
has_update = true
end

end

if plugin.name == 'ai-request-transformer' then
Expand Down
14 changes: 5 additions & 9 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,11 @@ function _M.configure_request(conf)
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = conf.model.options
and conf.model.options.upstream_path
or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
or "/"

if not parsed_url.path then
return nil, fmt("operation %s is not supported for anthropic provider", conf.route_type)
end
parsed_url.path = (conf.model.options and
conf.model.options.upstream_path)
or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path)
or "/"
end

-- if the path is read from a URL capture, ensure that it is valid
Expand Down
8 changes: 4 additions & 4 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,10 @@ function _M.configure_request(conf)
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = conf.model.options
and conf.model.options.upstream_path
or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
parsed_url.path = (conf.model.options and
conf.model.options.upstream_path)
or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path)
or "/"
end

Expand Down
13 changes: 12 additions & 1 deletion kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,19 @@ end

-- returns err or nil
function _M.configure_request(conf)
local parsed_url

-- mistral shared operation paths
local parsed_url = socket_url.parse(conf.model.options.upstream_url)
if (conf.model.options and conf.model.options.upstream_url) then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = (conf.model.options and
conf.model.options.upstream_path)
or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path)
or "/"
end

-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = (parsed_url.path and string_gsub(parsed_url.path, "^/*", "/")) or "/"
Expand Down
15 changes: 5 additions & 10 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -191,17 +191,12 @@ function _M.configure_request(conf)
if (conf.model.options and conf.model.options.upstream_url) then
parsed_url = socket_url.parse(conf.model.options.upstream_url)
else
local path = conf.model.options
and conf.model.options.upstream_path
or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
or "/"
if not path then
return nil, fmt("operation %s is not supported for openai provider", conf.route_type)
end

parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
parsed_url.path = path
parsed_url.path = (conf.model.options and
conf.model.options.upstream_path)
or (ai_shared.operation_map[DRIVER_NAME][conf.route_type] and
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path)
or "/"
end

-- if the path is read from a URL capture, ensure that it is valid
Expand Down
1 change: 1 addition & 0 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ _M.upstream_url_format = {
gemini = "https://generativelanguage.googleapis.com",
gemini_vertex = "https://%s",
bedrock = "https://bedrock-runtime.%s.amazonaws.com",
mistral = "https://api.mistral.ai:443"
}

_M.operation_map = {
Expand Down
2 changes: 1 addition & 1 deletion kong/llm/schemas/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ return {
then_err = "must set %s for azure provider" }},

{ conditional_at_least_one_of = { if_field = "model.provider",
if_match = { one_of = { "mistral", "llama2" } },
if_match = { one_of = { "llama2" } },
then_at_least_one_of = { "model.options.upstream_url" },
then_err = "must set %s for self-hosted providers/models" }},

Expand Down
6 changes: 5 additions & 1 deletion spec/03-plugins/38-ai-proxy/00-config_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ describe(PLUGIN_NAME .. ": (schema)", function()


for i, v in ipairs(SELF_HOSTED_MODELS) do
it("requires upstream_url when using self-hosted " .. v .. " model", function()
local op = it
if v == "mistral" then -- mistral.ai now has managed service too!
op = pending
end
op("requires upstream_url when using self-hosted " .. v .. " model", function()
local config = {
route_type = "llm/v1/chat",
auth = {
Expand Down

0 comments on commit 7636e1b

Please sign in to comment.