diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 02ee704bd67c..80f26c077867 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -59,6 +59,22 @@ local log_entry_keys = { local openai_override = os.getenv("OPENAI_TEST_PORT") +---- IDENTITY SETTINGS +local GCP_SERVICE_ACCOUNT do + GCP_SERVICE_ACCOUNT = os.getenv("GCP_SERVICE_ACCOUNT") +end + +local GCP = require("resty.gcp.request.credentials.accesstoken") +local aws_config = require "resty.aws.config" -- reads environment variables whilst available +local AWS = require("resty.aws") +local AWS_REGION do + AWS_REGION = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") +end + +local AZURE_TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default" +local AZURE_TOKEN_VERSION = "v2.0" +---- + _M._CONST = { ["SSE_TERMINATOR"] = "[DONE]", } @@ -229,6 +245,75 @@ local function handle_stream_event(event_table, model_info, route_type) end end +--- +-- Manages cloud SDKs, for using "workload identity" authentications, +-- that are tied to this specific plugin in-memory. +-- +-- This allows users to run different authentication configurations +-- between different AI Plugins. +-- +-- @param {table} this_cache self - stores all the SDK instances +-- @param {table} plugin_config the configuration to cache against and also provide SDK settings with +-- @return {table} self +_M.cloud_identity_function = function(this_cache, plugin_config) + if plugin_config.model.provider == "gemini" and + plugin_config.auth and + plugin_config.auth.gcp_use_service_account then + + ngx.log(ngx.NOTICE, "loading gcp sdk for plugin ", kong.plugin.get_id()) + + local service_account_json = (plugin_config.auth and plugin_config.auth.gcp_service_account_json) or GCP_SERVICE_ACCOUNT + + local ok, gcp_auth = pcall(GCP.new, nil, service_account_json) + if ok and gcp_auth then + -- store our item for the next time we need it + gcp_auth.service_account_json = service_account_json + this_cache[plugin_config] = { interface = gcp_auth, error = nil } + return this_cache[plugin_config] + end + + return { interface = nil, error = "cloud-authentication with GCP failed" } + + elseif plugin_config.model.provider == "bedrock" then + ngx.log(ngx.NOTICE, "loading aws sdk for plugin ", kong.plugin.get_id()) + local aws + + local region = plugin_config.model.options + and plugin_config.model.options.bedrock + and plugin_config.model.options.bedrock.aws_region + or AWS_REGION + + if not region then + return { interface = nil, error = "AWS region not specified anywhere" } + end + + local access_key_set = (plugin_config.auth and plugin_config.auth.aws_access_key_id) + or aws_config.global.AWS_ACCESS_KEY_ID + local secret_key_set = plugin_config.auth and plugin_config.auth.aws_secret_access_key + or aws_config.global.AWS_SECRET_ACCESS_KEY + + aws = AWS({ + -- if any of these are nil, they either use the SDK default or + -- are deliberately null so that a different auth chain is used + region = region, + }) + + if access_key_set and secret_key_set then + -- Override credential config according to plugin config, if set + local creds = aws:Credentials { + accessKeyId = access_key_set, + secretAccessKey = secret_key_set, + } + + aws.config.credentials = creds + end + + this_cache[plugin_config] = { interface = aws, error = nil } + + return this_cache[plugin_config] + end +end + --- -- Splits a HTTPS data chunk or frame into individual -- SSE-format messages, see: diff --git a/kong/llm/proxy/handler.lua b/kong/llm/proxy/handler.lua index 69b97e71f862..314181c6a9e7 100644 --- a/kong/llm/proxy/handler.lua +++ b/kong/llm/proxy/handler.lua @@ -48,64 +48,7 @@ local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported t local _KEYBASTION = setmetatable({}, { __mode = "k", - __index = function(this_cache, plugin_config) - if plugin_config.model.provider == "gemini" and - plugin_config.auth and - plugin_config.auth.gcp_use_service_account then - - ngx.log(ngx.NOTICE, "loading gcp sdk for plugin ", kong.plugin.get_id()) - - local service_account_json = (plugin_config.auth and plugin_config.auth.gcp_service_account_json) or GCP_SERVICE_ACCOUNT - - local ok, gcp_auth = pcall(GCP.new, nil, service_account_json) - if ok and gcp_auth then - -- store our item for the next time we need it - gcp_auth.service_account_json = service_account_json - this_cache[plugin_config] = { interface = gcp_auth, error = nil } - return this_cache[plugin_config] - end - - return { interface = nil, error = "cloud-authentication with GCP failed" } - - elseif plugin_config.model.provider == "bedrock" then - ngx.log(ngx.NOTICE, "loading aws sdk for plugin ", kong.plugin.get_id()) - local aws - - local region = plugin_config.model.options - and plugin_config.model.options.bedrock - and plugin_config.model.options.bedrock.aws_region - or AWS_REGION - - if not region then - return { interface = nil, error = "AWS region not specified anywhere" } - end - - local access_key_set = (plugin_config.auth and plugin_config.auth.aws_access_key_id) - or aws_config.global.AWS_ACCESS_KEY_ID - local secret_key_set = plugin_config.auth and plugin_config.auth.aws_secret_access_key - or aws_config.global.AWS_SECRET_ACCESS_KEY - - aws = AWS({ - -- if any of these are nil, they either use the SDK default or - -- are deliberately null so that a different auth chain is used - region = region, - }) - - if access_key_set and secret_key_set then - -- Override credential config according to plugin config, if set - local creds = aws:Credentials { - accessKeyId = access_key_set, - secretAccessKey = secret_key_set, - } - - aws.config.credentials = creds - end - - this_cache[plugin_config] = { interface = aws, error = nil } - - return this_cache[plugin_config] - end - end, + __index = ai_shared.cloud_identity_function, }) @@ -521,6 +464,7 @@ function _M:access(conf) -- get the provider's cached identity interface - nil may come back, which is fine local identity_interface = _KEYBASTION[conf] + if identity_interface and identity_interface.error then llm_state.set_response_transformer_skipped() kong.log.err("error authenticating with cloud-provider, ", identity_interface.error)