Skip to content

Commit

Permalink
fix(ai-proxy): cloud identity (sdk) now used in ai transformer plugins
Browse files Browse the repository at this point in the history
Co-authored-by: Wangchong Zhou <[email protected]>
  • Loading branch information
tysoekong and fffonion committed Aug 15, 2024
1 parent 7eb1c95 commit 6324708
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: |
**AI-Transformer-Plugins**: Fixed a bug where cloud identity authentication
was not used in `ai-request-transformer` and `ai-response-transformer` plugins.
scope: Plugin
type: bugfix
51 changes: 39 additions & 12 deletions kong/llm/drivers/bedrock.lua
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ function _M.to_format(request_table, model_info, route_type)
return response_object, content_type, nil
end

function _M.subrequest(body, conf, http_opts, return_res_table)
function _M.subrequest(body, conf, http_opts, return_res_table, identity_interface)
-- use shared/standard subrequest routine
local body_string, err

Expand All @@ -322,25 +322,52 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
end

-- may be overridden
local url = (conf.model.options and conf.model.options.upstream_url)
or fmt(
"%s%s",
ai_shared.upstream_url_format[DRIVER_NAME],
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
)
local f_url = conf.model.options and conf.model.options.upstream_url

if not f_url then -- upstream_url override is not set
local uri = fmt(ai_shared.upstream_url_format[DRIVER_NAME], identity_interface.interface.config.region)
local path = fmt(
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path,
conf.model.name,
"converse")

local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method
f_url = fmt("%s%s", uri, path)
end

local parsed_url = socket_url.parse(f_url)
local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method

-- do the IAM auth and signature headers
identity_interface.interface.config.signatureVersion = "v4"
identity_interface.interface.config.endpointPrefix = "bedrock"

local r = {
headers = {},
method = method,
path = parsed_url.path,
host = parsed_url.host,
port = tonumber(parsed_url.port) or 443,
body = cjson.encode(body),
}

local signature, err = signer(identity_interface.interface.config, r)
if not signature then
return nil, "failed to sign AWS request: " .. (err or "NONE")
end

local headers = {
["Accept"] = "application/json",
["Content-Type"] = "application/json",
}

if conf.auth and conf.auth.header_name then
headers[conf.auth.header_name] = conf.auth.header_value
headers["Authorization"] = signature.headers["Authorization"]
if signature.headers["X-Amz-Security-Token"] then
headers["X-Amz-Security-Token"] = signature.headers["X-Amz-Security-Token"]
end
if signature.headers["X-Amz-Date"] then
headers["X-Amz-Date"] = signature.headers["X-Amz-Date"]
end

local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
local res, err, httpc = ai_shared.http_request(f_url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end
Expand Down
23 changes: 20 additions & 3 deletions kong/llm/drivers/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ function _M.to_format(request_table, model_info, route_type)
return response_object, content_type, nil
end

function _M.subrequest(body, conf, http_opts, return_res_table)
function _M.subrequest(body, conf, http_opts, return_res_table, identity_interface)
-- use shared/standard subrequest routine
local body_string, err

Expand Down Expand Up @@ -292,7 +292,24 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
["Content-Type"] = "application/json",
}

if conf.auth and conf.auth.header_name then
if identity_interface and identity_interface.interface then
if identity_interface.interface:needsRefresh() then
-- HACK: A bug in lua-resty-gcp tries to re-load the environment
-- variable every time, which fails in nginx
-- Create a whole new interface instead.
-- Memory leaks are mega unlikely because this should only
-- happen about once an hour, and the old one will be
-- cleaned up anyway.
local service_account_json = identity_interface.interface.service_account_json
local identity_interface_new = identity_interface.interface:new(service_account_json)
identity_interface.interface.token = identity_interface_new.token

kong.log.debug("gcp identity token for ", kong.plugin.get_id(), " has been refreshed")
end

headers["Authorization"] = "Bearer " .. identity_interface.interface.token

elseif conf.auth and conf.auth.header_name then
headers[conf.auth.header_name] = conf.auth.header_value
end

Expand Down Expand Up @@ -413,7 +430,7 @@ function _M.configure_request(conf, identity_interface)
local identity_interface_new = identity_interface:new(service_account_json)
identity_interface.token = identity_interface_new.token

kong.log.notice("gcp identity token for ", kong.plugin.get_id(), " has been refreshed")
kong.log.debug("gcp identity token for ", kong.plugin.get_id(), " has been refreshed")
end

kong.service.request.set_header("Authorization", "Bearer " .. identity_interface.token)
Expand Down
7 changes: 2 additions & 5 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ local AWS = require("resty.aws")
local AWS_REGION do
AWS_REGION = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
end

local AZURE_TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default"
local AZURE_TOKEN_VERSION = "v2.0"
----

_M._CONST = {
Expand Down Expand Up @@ -260,7 +257,7 @@ _M.cloud_identity_function = function(this_cache, plugin_config)
plugin_config.auth and
plugin_config.auth.gcp_use_service_account then

ngx.log(ngx.NOTICE, "loading gcp sdk for plugin ", kong.plugin.get_id())
ngx.log(ngx.DEBUG, "loading gcp sdk for plugin ", kong.plugin.get_id())

local service_account_json = (plugin_config.auth and plugin_config.auth.gcp_service_account_json) or GCP_SERVICE_ACCOUNT

Expand All @@ -275,7 +272,7 @@ _M.cloud_identity_function = function(this_cache, plugin_config)
return { interface = nil, error = "cloud-authentication with GCP failed" }

elseif plugin_config.model.provider == "bedrock" then
ngx.log(ngx.NOTICE, "loading aws sdk for plugin ", kong.plugin.get_id())
ngx.log(ngx.DEBUG, "loading aws sdk for plugin ", kong.plugin.get_id())
local aws

local region = plugin_config.model.options
Expand Down
10 changes: 6 additions & 4 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ do
local ai_request

-- mistral, cohere, titan (via Bedrock) don't support system commands
if self.driver == "bedrock" then
if self.conf.model.provider == "bedrock" then
for _, p in ipairs(self.driver.bedrock_unsupported_system_role_patterns) do
if request.model:find(p) then
if self.conf.model.name:find(p) then
ai_request = {
messages = {
[1] = {
Expand Down Expand Up @@ -147,7 +147,7 @@ do
ai_shared.pre_request(self.conf, ai_request)

-- send it to the ai service
local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false)
local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface)
if err then
return nil, "failed to introspect request with AI service: " .. err
end
Expand Down Expand Up @@ -225,13 +225,15 @@ do
--- Instantiate a new LLM driver instance.
-- @tparam table conf Configuration table
-- @tparam table http_opts HTTP options table
-- @tparam table [optional] cloud-authentication identity interface
-- @treturn[1] table A new LLM driver instance
-- @treturn[2] nil
-- @treturn[2] string An error message if instantiation failed
function _M.new_driver(conf, http_opts)
function _M.new_driver(conf, http_opts, identity_interface)
local self = {
conf = conf or {},
http_opts = http_opts or {},
identity_interface = identity_interface, -- 'or nil'
}
setmetatable(self, LLM)

Expand Down
13 changes: 0 additions & 13 deletions kong/llm/proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@ local kong_utils = require("kong.tools.gzip")
local buffer = require "string.buffer"
local strip = require("kong.tools.utils").strip

-- cloud auth/sdk providers
local GCP_SERVICE_ACCOUNT do
GCP_SERVICE_ACCOUNT = os.getenv("GCP_SERVICE_ACCOUNT")
end

local GCP = require("resty.gcp.request.credentials.accesstoken")
local aws_config = require "resty.aws.config" -- reads environment variables whilst available
local AWS = require("resty.aws")
local AWS_REGION do
AWS_REGION = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
end
--


local EMPTY = require("kong.tools.table").EMPTY

Expand Down
19 changes: 18 additions & 1 deletion kong/plugins/ai-request-transformer/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@ local kong_meta = require "kong.meta"
local fmt = string.format
local llm = require("kong.llm")
local llm_state = require("kong.llm.state")
local ai_shared = require("kong.llm.drivers.shared")
--

_M.PRIORITY = 777
_M.VERSION = kong_meta.version

-- Lazily-populated cache of cloud identity (SDK) interfaces, keyed by the
-- plugin's llm configuration table (see `_KEYBASTION[conf.llm]` in access()).
-- On a cache miss, __index invokes ai_shared.cloud_identity_function(cache, conf)
-- to build the interface; __mode = "k" makes keys weak so an entry can be
-- garbage-collected once its configuration table is no longer referenced.
local _KEYBASTION = setmetatable({}, {
__mode = "k",
__index = ai_shared.cloud_identity_function,
})

local function bad_request(msg)
kong.log.info(msg)
return kong.response.exit(400, { error = { message = msg } })
Expand Down Expand Up @@ -40,14 +46,25 @@ local function create_http_opts(conf)
end

function _M:access(conf)
local kong_ctx_shared = kong.ctx.shared

kong.service.request.enable_buffering()
llm_state.should_disable_ai_proxy_response_transform()

-- get cloud identity SDK, if required
local identity_interface = _KEYBASTION[conf.llm]

if identity_interface and identity_interface.error then
kong_ctx_shared.skip_response_transformer = true
kong.log.err("error authenticating with ", conf.model.provider, " using native provider auth, ", identity_interface.error)
return kong.response.exit(500, "LLM request failed before proxying")
end

-- first find the configured LLM interface and driver
local http_opts = create_http_opts(conf)
conf.llm.__plugin_id = conf.__plugin_id
conf.llm.__key__ = conf.__key__
local ai_driver, err = llm.new_driver(conf.llm, http_opts)
local ai_driver, err = llm.new_driver(conf.llm, http_opts, identity_interface)

if not ai_driver then
return internal_server_error(err)
Expand Down
19 changes: 18 additions & 1 deletion kong/plugins/ai-response-transformer/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ local fmt = string.format
local kong_utils = require("kong.tools.gzip")
local llm = require("kong.llm")
local llm_state = require("kong.llm.state")
local ai_shared = require("kong.llm.drivers.shared")
--

_M.PRIORITY = 769
_M.VERSION = kong_meta.version

-- Lazily-populated cache of cloud identity (SDK) interfaces, keyed by the
-- plugin's llm configuration table (see `_KEYBASTION[conf.llm]` in access()).
-- On a cache miss, __index invokes ai_shared.cloud_identity_function(cache, conf)
-- to build the interface; __mode = "k" makes keys weak so an entry can be
-- garbage-collected once its configuration table is no longer referenced.
local _KEYBASTION = setmetatable({}, {
__mode = "k",
__index = ai_shared.cloud_identity_function,
})

local function bad_request(msg)
kong.log.info(msg)
return kong.response.exit(400, { error = { message = msg } })
Expand Down Expand Up @@ -99,14 +105,25 @@ end


function _M:access(conf)
local kong_ctx_shared = kong.ctx.shared

kong.service.request.enable_buffering()
llm_state.disable_ai_proxy_response_transform()

-- get cloud identity SDK, if required
local identity_interface = _KEYBASTION[conf.llm]

if identity_interface and identity_interface.error then
kong_ctx_shared.skip_response_transformer = true
kong.log.err("error authenticating with ", conf.model.provider, " using native provider auth, ", identity_interface.error)
return kong.response.exit(500, "LLM request failed before proxying")
end

-- first find the configured LLM interface and driver
local http_opts = create_http_opts(conf)
conf.llm.__plugin_id = conf.__plugin_id
conf.llm.__key__ = conf.__key__
local ai_driver, err = llm.new_driver(conf.llm, http_opts)
local ai_driver, err = llm.new_driver(conf.llm, http_opts, identity_interface)

if not ai_driver then
return internal_server_error(err)
Expand Down

0 comments on commit 6324708

Please sign in to comment.