Skip to content

Commit

Permalink
feat(llm): add auth.allow_override option for llm auth functionality (#…
Browse files Browse the repository at this point in the history
…13493)

Add `allow_override` option to allow overriding the upstream model auth
parameter or header from the caller's request.
  • Loading branch information
oowl authored Aug 15, 2024
1 parent 9bc3deb commit 700b3b0
Show file tree
Hide file tree
Showing 18 changed files with 1,090 additions and 18 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/ai-proxy-add-allow-override-opt.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
message: |
**AI-proxy-plugin**: Add `allow_auth_override` option to allow overriding the upstream model auth parameter or header from the caller's request.
scope: Plugin
type: feature
1 change: 1 addition & 0 deletions kong/clustering/compat/removed_fields.lua
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ return {
"model.options.bedrock",
"auth.aws_access_key_id",
"auth.aws_secret_access_key",
"auth.allow_auth_override",
"model_name_header",
},
ai_prompt_decorator = {
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -472,13 +472,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
10 changes: 8 additions & 2 deletions kong/llm/drivers/azure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,23 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end


local query_table = kong.request.get_query()

-- technically min supported version
query_table["api-version"] = kong.request.get_query_arg("api-version")
or (conf.model.options and conf.model.options.azure_api_version)

if auth_param_name and auth_param_value and auth_param_location == "query" then
query_table[auth_param_name] = auth_param_value
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
end
end

kong.service.request.set_query(query_table)
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
Expand Up @@ -480,13 +480,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/llama2.lua
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/mistral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a pre-request hook
Expand Down
11 changes: 8 additions & 3 deletions kong/llm/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,18 @@ function _M.configure_request(conf)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_header_name and auth_header_value then
kong.service.request.set_header(auth_header_name, auth_header_value)
local exist_value = kong.request.get_header(auth_header_name)
if exist_value == nil or not conf.auth.allow_auth_override then
kong.service.request.set_header(auth_header_name, auth_header_value)
end
end

if auth_param_name and auth_param_value and auth_param_location == "query" then
local query_table = kong.request.get_query()
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
if query_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
query_table[auth_param_name] = auth_param_value
kong.service.request.set_query(query_table)
end
end

-- if auth_param_location is "form", it will have already been set in a global pre-request hook
Expand Down
4 changes: 3 additions & 1 deletion kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,9 @@ function _M.pre_request(conf, request_table)
local auth_param_location = conf.auth and conf.auth.param_location

if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then
request_table[auth_param_name] = auth_param_value
if request_table[auth_param_name] == nil or not conf.auth.allow_auth_override then
request_table[auth_param_name] = auth_param_value
end
end

-- retrieve the plugin name
Expand Down
10 changes: 10 additions & 0 deletions kong/llm/schemas/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ local auth_schema = {
required = false,
encrypted = true,
referenceable = true }},
{ allow_auth_override = {
type = "boolean",
description = "If enabled, the authorization header or parameter can be overridden in the request by the value configured in the plugin.",
required = false,
default = true }},
}
}

Expand Down Expand Up @@ -237,6 +242,11 @@ return {
{ logging = logging_schema },
},
entity_checks = {
{ conditional = { if_field = "model.provider",
if_match = { one_of = { "bedrock", "gemini" } },
then_field = "auth.allow_auth_override",
then_match = { eq = false },
then_err = "bedrock and gemini only support auth.allow_auth_override = false" }},
{ mutually_required = { "auth.header_name", "auth.header_value" }, },
{ mutually_required = { "auth.param_name", "auth.param_value", "auth.param_location" }, },

Expand Down
5 changes: 5 additions & 0 deletions spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
header_value = "value",
gcp_service_account_json = '{"service": "account"}',
gcp_use_service_account = true,
allow_auth_override = false,
},
model = {
name = "any-model-name",
Expand Down Expand Up @@ -526,6 +527,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
-- gemini fields
expected.config.auth.gcp_service_account_json = nil
expected.config.auth.gcp_use_service_account = nil
expected.config.auth.allow_auth_override = nil
expected.config.model.options.gemini = nil

-- bedrock fields
Expand Down Expand Up @@ -562,6 +564,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
header_value = "value",
gcp_service_account_json = '{"service": "account"}',
gcp_use_service_account = true,
allow_auth_override = false,
},
model = {
name = "any-model-name",
Expand Down Expand Up @@ -625,6 +628,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
header_value = "value",
gcp_service_account_json = '{"service": "account"}',
gcp_use_service_account = true,
allow_auth_override = false,
},
model = {
name = "any-model-name",
Expand Down Expand Up @@ -720,6 +724,7 @@ describe("CP/DP config compat transformations #" .. strategy, function()
-- bedrock fields
expected.config.auth.aws_access_key_id = nil
expected.config.auth.aws_secret_access_key = nil
expected.config.auth.allow_auth_override = nil
expected.config.model.options.bedrock = nil

do_assert(uuid(), "3.7.0", expected)
Expand Down
55 changes: 55 additions & 0 deletions spec/03-plugins/38-ai-proxy/00-config_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -289,4 +289,59 @@ describe(PLUGIN_NAME .. ": (schema)", function()
assert.is_truthy(ok)
end)

it("bedrock model can not support ath.allowed_auth_override", function()
local config = {
route_type = "llm/v1/chat",
auth = {
param_name = "apikey",
param_value = "key",
param_location = "query",
header_name = "Authorization",
header_value = "Bearer token",
allow_auth_override = true,
},
model = {
name = "bedrock",
provider = "bedrock",
options = {
max_tokens = 256,
temperature = 1.0,
upstream_url = "http://nowhere",
},
},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.is_truthy(err)
end)

it("gemini model can not support ath.allowed_auth_override", function()
local config = {
route_type = "llm/v1/chat",
auth = {
param_name = "apikey",
param_value = "key",
param_location = "query",
header_name = "Authorization",
header_value = "Bearer token",
allow_auth_override = true,
},
model = {
name = "gemini",
provider = "gemini",
options = {
max_tokens = 256,
temperature = 1.0,
upstream_url = "http://nowhere",
},
},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.is_truthy(err)
end)
end)
Loading

1 comment on commit 700b3b0

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bazel Build

Docker image available kong/kong:700b3b07c2124b759fd403c509d58bf5c4da4c7a
Artifacts available https://github.com/Kong/kong/actions/runs/10398115050

Please sign in to comment.