Skip to content

Commit

Permalink
feat(ai-proxy): re-factor cloud-auth interface to share between many …
Browse files Browse the repository at this point in the history
…plugins
  • Loading branch information
tysoekong committed Aug 15, 2024
1 parent b45233d commit 7eb1c95
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 58 deletions.
85 changes: 85 additions & 0 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ local log_entry_keys = {

local openai_override = os.getenv("OPENAI_TEST_PORT")

---- IDENTITY SETTINGS
local GCP_SERVICE_ACCOUNT do
GCP_SERVICE_ACCOUNT = os.getenv("GCP_SERVICE_ACCOUNT")
end

local GCP = require("resty.gcp.request.credentials.accesstoken")
local aws_config = require "resty.aws.config" -- reads environment variables whilst available
local AWS = require("resty.aws")
local AWS_REGION do
AWS_REGION = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
end

local AZURE_TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default"
local AZURE_TOKEN_VERSION = "v2.0"
----

_M._CONST = {
["SSE_TERMINATOR"] = "[DONE]",
}
Expand Down Expand Up @@ -229,6 +245,75 @@ local function handle_stream_event(event_table, model_info, route_type)
end
end

---
-- Manages cloud SDKs, for using "workload identity" authentications,
-- that are tied to this specific plugin in-memory.
--
-- This allows users to run different authentication configurations
-- between different AI Plugins.
--
-- @param {table} this_cache self - stores all the SDK instances
-- @param {table} plugin_config the configuration to cache against and also provide SDK settings with
-- @return {table} self
_M.cloud_identity_function = function(this_cache, plugin_config)
if plugin_config.model.provider == "gemini" and
plugin_config.auth and
plugin_config.auth.gcp_use_service_account then

ngx.log(ngx.NOTICE, "loading gcp sdk for plugin ", kong.plugin.get_id())

local service_account_json = (plugin_config.auth and plugin_config.auth.gcp_service_account_json) or GCP_SERVICE_ACCOUNT

local ok, gcp_auth = pcall(GCP.new, nil, service_account_json)
if ok and gcp_auth then
-- store our item for the next time we need it
gcp_auth.service_account_json = service_account_json
this_cache[plugin_config] = { interface = gcp_auth, error = nil }
return this_cache[plugin_config]
end

return { interface = nil, error = "cloud-authentication with GCP failed" }

elseif plugin_config.model.provider == "bedrock" then
ngx.log(ngx.NOTICE, "loading aws sdk for plugin ", kong.plugin.get_id())
local aws

local region = plugin_config.model.options
and plugin_config.model.options.bedrock
and plugin_config.model.options.bedrock.aws_region
or AWS_REGION

if not region then
return { interface = nil, error = "AWS region not specified anywhere" }
end

local access_key_set = (plugin_config.auth and plugin_config.auth.aws_access_key_id)
or aws_config.global.AWS_ACCESS_KEY_ID
local secret_key_set = plugin_config.auth and plugin_config.auth.aws_secret_access_key
or aws_config.global.AWS_SECRET_ACCESS_KEY

aws = AWS({
-- if any of these are nil, they either use the SDK default or
-- are deliberately null so that a different auth chain is used
region = region,
})

if access_key_set and secret_key_set then
-- Override credential config according to plugin config, if set
local creds = aws:Credentials {
accessKeyId = access_key_set,
secretAccessKey = secret_key_set,
}

aws.config.credentials = creds
end

this_cache[plugin_config] = { interface = aws, error = nil }

return this_cache[plugin_config]
end
end

---
-- Splits a HTTPS data chunk or frame into individual
-- SSE-format messages, see:
Expand Down
60 changes: 2 additions & 58 deletions kong/llm/proxy/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -48,64 +48,7 @@ local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported t

local _KEYBASTION = setmetatable({}, {
__mode = "k",
__index = function(this_cache, plugin_config)
if plugin_config.model.provider == "gemini" and
plugin_config.auth and
plugin_config.auth.gcp_use_service_account then

ngx.log(ngx.NOTICE, "loading gcp sdk for plugin ", kong.plugin.get_id())

local service_account_json = (plugin_config.auth and plugin_config.auth.gcp_service_account_json) or GCP_SERVICE_ACCOUNT

local ok, gcp_auth = pcall(GCP.new, nil, service_account_json)
if ok and gcp_auth then
-- store our item for the next time we need it
gcp_auth.service_account_json = service_account_json
this_cache[plugin_config] = { interface = gcp_auth, error = nil }
return this_cache[plugin_config]
end

return { interface = nil, error = "cloud-authentication with GCP failed" }

elseif plugin_config.model.provider == "bedrock" then
ngx.log(ngx.NOTICE, "loading aws sdk for plugin ", kong.plugin.get_id())
local aws

local region = plugin_config.model.options
and plugin_config.model.options.bedrock
and plugin_config.model.options.bedrock.aws_region
or AWS_REGION

if not region then
return { interface = nil, error = "AWS region not specified anywhere" }
end

local access_key_set = (plugin_config.auth and plugin_config.auth.aws_access_key_id)
or aws_config.global.AWS_ACCESS_KEY_ID
local secret_key_set = plugin_config.auth and plugin_config.auth.aws_secret_access_key
or aws_config.global.AWS_SECRET_ACCESS_KEY

aws = AWS({
-- if any of these are nil, they either use the SDK default or
-- are deliberately null so that a different auth chain is used
region = region,
})

if access_key_set and secret_key_set then
-- Override credential config according to plugin config, if set
local creds = aws:Credentials {
accessKeyId = access_key_set,
secretAccessKey = secret_key_set,
}

aws.config.credentials = creds
end

this_cache[plugin_config] = { interface = aws, error = nil }

return this_cache[plugin_config]
end
end,
__index = ai_shared.cloud_identity_function,
})


Expand Down Expand Up @@ -521,6 +464,7 @@ function _M:access(conf)

-- get the provider's cached identity interface - nil may come back, which is fine
local identity_interface = _KEYBASTION[conf]

if identity_interface and identity_interface.error then
llm_state.set_response_transformer_skipped()
kong.log.err("error authenticating with cloud-provider, ", identity_interface.error)
Expand Down

0 comments on commit 7eb1c95

Please sign in to comment.